プログラミングRustをちらちら読んでいたら並列化の例としてあったので、いまさらマンデルブロー集合を実装した。
元のコードがRustなのでなるべく型安全な感じに書いてみたつもり。
- 他の言語でfind相当のものがcountUntilという名前で、findは別の操作を行う関数だった。紛らわしい。
- PNGファイルの出力はimageformatsというライブラリを使った。ヘッダから自動で画像形式を推定してくれる、手軽に使えるので便利。
- nogcというライブラリを使えば型変換をnogcでできるようだが、結局
write_image
がGC必要とするので用いなかった。 - tupleの中の要素はinoutと組み合わせられないのか、
stack based variables can be inout
だそう。
ソースコードは以下。
import core.stdc.stdlib : exit; import std.complex; import std.conv : to, ConvException; import std.parallelism : parallel; import std.range : chunks; import std.stdio; import std.typecons : Nullable, Tuple, tuple; import imageformats : write_image; Nullable!uint escapeTime(uint limit)(Complex!double c) @nogc nothrow pure @safe { auto z = complex(0.0, 0.0); foreach (i; 0 .. limit) { z = z * z + c; if ((z.re * z.re + z.im * z.im) > 4.0) return typeof(return)(i); } return (typeof(return)).init; } alias Pair(T) = Nullable!(Tuple!(T, T)); Pair!T parsePair(T, char separator)(string s) pure @safe if (__traits(isArithmetic, T)) { import std.algorithm : countUntil; immutable index = s.countUntil(separator); if (index == -1) return (typeof(return)).init; try { T l = s[0 .. index].to!T; T r = s[index + 1 .. $].to!T; return typeof(return)(tuple(l, r)); } catch (ConvException) return (typeof(return)).init; } @safe unittest { assert(parsePair!(int, ',')("").isNull); assert(parsePair!(int, ',')("10,").isNull); assert(parsePair!(int, ',')("10,20").get() == tuple(10, 20)); assert(parsePair!(int, ',')("10,20xy").isNull); assert(parsePair!(double, 'x')("0.5x").isNull); assert(parsePair!(double, 'x')("0.5x1.5").get() == tuple(0.5, 1.5)); } Nullable!(Complex!double) parseComplex(string s) pure @safe { immutable pair = parsePair!(double, ',')(s); if (pair.isNull) return (typeof(return)).init; return typeof(return)(complex(pair[0], pair[1])); } @safe unittest { assert(parseComplex("1.25,-0.0625").get() == complex(1.25, -0.0625)); assert(parseComplex(",-0.0625").isNull); } Complex!double pixelToPoint(Tuple!(size_t, size_t) bounds, Tuple!(size_t, size_t) pixel, Complex!double upperLeft, Complex!double lowerRight) @nogc pure nothrow @safe { immutable width = lowerRight.re - upperLeft.re; immutable height = upperLeft.im - lowerRight.im; return complex(upperLeft.re + pixel[0] * width / bounds[0], upperLeft.im - pixel[1] * height / bounds[1]); } @nogc @safe unittest { assert(pixelToPoint(tuple!(size_t, size_t)(100, 100), tuple!(size_t, size_t)(25, 75), complex(-1.0, 1.0), complex(1.0, -1.0)) == complex(-0.5, -0.5)); } void render(ref ubyte[] pixels, Tuple!(size_t, size_t) bounds, Complex!double upperLeft, Complex!double lowerRight) @nogc @safe in { assert(pixels.length == bounds[0] * bounds[1]); } do { foreach (row; 0 .. bounds[1]) { foreach (column; 0 .. bounds[0]) { immutable point = pixelToPoint(bounds, tuple(column, row), upperLeft, lowerRight); immutable count = escapeTime!(255)(point); pixels[row * bounds[0] + column] = count.isNull ? 0 : cast(ubyte)(255 - count.get()); /* ensure count.get <= 255 */ } } } void writeImage(string filename, const ubyte[] pixels, Tuple!(size_t, size_t) bounds) { write_image(filename, bounds[0], bounds[1], pixels); } version(unittest) { void main() {} } else { void main(string[] args) { if (args.length != 5) { stderr.writeln("Usage: mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT"); stderr.writefln("Example %s mandel.png 1000x750 -1.20,0.35 -1,0.20", args[0]); exit(1); } immutable bounds = parsePair!(size_t, 'x')(args[2]).get(); immutable upperLeft = parseComplex(args[3]).get(); immutable lowerRight = parseComplex(args[4]).get(); auto pixels = new ubyte[bounds[0] * bounds[1]]; auto bands = pixels.chunks(bounds[0]); foreach (top, band; parallel(bands)) { auto bandBounds = tuple(cast() bounds[0], 1UL); immutable bandUpperLeft = pixelToPoint(bounds, tuple(0UL, top), upperLeft, lowerRight); immutable bandLowerRight = pixelToPoint(bounds, tuple(cast() bounds[0], top + 1UL), upperLeft, lowerRight); render(band, bandBounds, bandUpperLeft, bandLowerRight); } writeImage(args[1], pixels, bounds); } }
追記
imageformatsのnogc版のimagefmtというライブラリがあった。
https://github.com/lgvz/imagefmt
これを使ってちょっとリファクタリングしたコードもはっておく。
import core.stdc.stdlib : exit; import std.complex; import std.conv : to, ConvException; import std.parallelism : parallel; import std.range : chunks; import std.stdio; import std.typecons : Nullable, Tuple, tuple; import imagefmt : write_image; Nullable!uint escapeTime(uint limit)(Complex!double c) @nogc nothrow pure @safe { auto z = complex(0.0, 0.0); foreach (i; 0 .. limit) { z = z * z + c; if ((z.re * z.re + z.im * z.im) > 4.0) return typeof(return)(i); } return (typeof(return)).init; } alias Pair(T) = Nullable!(Tuple!(T, T)); Pair!T parsePair(T, char separator)(string s) pure @safe if (__traits(isArithmetic, T)) { import std.algorithm : countUntil; immutable index = s.countUntil(separator); if (index == -1) return (typeof(return)).init; try { T l = s[0 .. index].to!T; T r = s[index + 1 .. $].to!T; return typeof(return)(tuple(l, r)); } catch (ConvException) return (typeof(return)).init; } @safe unittest { assert(parsePair!(int, ',')("").isNull); assert(parsePair!(int, ',')("10,").isNull); assert(parsePair!(int, ',')("10,20").get() == tuple(10, 20)); assert(parsePair!(int, ',')("10,20xy").isNull); assert(parsePair!(double, 'x')("0.5x").isNull); assert(parsePair!(double, 'x')("0.5x1.5").get() == tuple(0.5, 1.5)); } Nullable!(Complex!double) parseComplex(string s) pure @safe { immutable pair = parsePair!(double, ',')(s); if (pair.isNull) return (typeof(return)).init; return typeof(return)(complex(pair[0], pair[1])); } @safe unittest { assert(parseComplex("1.25,-0.0625").get() == complex(1.25, -0.0625)); assert(parseComplex(",-0.0625").isNull); } Complex!double pixelToPoint(T)(Tuple!(T, T) bounds, Tuple!(T, T) pixel, Complex!double upperLeft, Complex!double lowerRight) @nogc pure nothrow @safe { immutable width = lowerRight.re - upperLeft.re; immutable height = upperLeft.im - lowerRight.im; return complex(upperLeft.re + pixel[0] * width / bounds[0], upperLeft.im - pixel[1] * height / bounds[1]); } @nogc @safe unittest { assert(pixelToPoint!int(tuple(100, 100), tuple(25, 75), complex(-1.0, 1.0), complex(1.0, -1.0)) == complex(-0.5, -0.5)); } void render(T)(ref ubyte[] pixels, Tuple!(T, T) bounds, Complex!double upperLeft, Complex!double lowerRight) @nogc @safe in { assert(pixels.length == bounds[0] * bounds[1]); } do { foreach (row; 0 .. bounds[1]) { foreach (column; 0 .. bounds[0]) { immutable point = pixelToPoint!T(bounds, tuple(column, row), upperLeft, lowerRight); immutable count = escapeTime!(255)(point); pixels[row * bounds[0] + column] = count.isNull ? 0 : cast(ubyte)(255 - count.get()); /* ensure count.get <= 255 */ } } } int writeImage(T : int)(string filename, const ubyte[] pixels, Tuple!(T, T) bounds) nothrow @nogc { return write_image(filename, bounds[0].to!int, bounds[1].to!int, pixels); } version(unittest) { void main() {} } else { int main(string[] args) { if (args.length != 5) { stderr.writeln("Usage: mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT"); stderr.writefln("Example %s mandel.png 1000x750 -1.20,0.35 -1,0.20", args[0]); exit(1); } immutable bounds = parsePair!(int, 'x')(args[2]).get(); immutable upperLeft = parseComplex(args[3]).get(); immutable lowerRight = parseComplex(args[4]).get(); auto pixels = new ubyte[bounds[0] * bounds[1]]; auto bands = pixels.chunks(bounds[0]); foreach (top, band; parallel(bands)) { auto bandBounds = tuple(cast() bounds[0], 1); immutable bandUpperLeft = pixelToPoint(bounds, tuple(0, top.to!int), upperLeft, lowerRight); immutable bandLowerRight = pixelToPoint(bounds, tuple(cast() bounds[0], top.to!int + 1), upperLeft, lowerRight); render(band, bandBounds, bandUpperLeft, bandLowerRight); } return writeImage(args[1], pixels, bounds); } }