kubo39's blog

ただの雑記です。

いまさらマンデルブロー集合を書いた（なるべく型安全を意識して）

『プログラミングRust』をちらちら読んでいたら並列化の例として載っていたので、いまさらながらD言語でマンデルブロー集合を実装した。

元のコードがRustなので、D版もなるべく型安全になるように書いてみたつもり。

  • 他の言語でいうfind（要素の位置を探す操作）に相当するものはDではcountUntilという名前で、Dのfindは別の操作（最初に一致した位置から始まるrangeを返す）を行う関数だった。紛らわしい。
  • PNGファイルの出力にはimageformatsというライブラリを使った。ファイル名の拡張子から出力する画像形式を自動で判別してくれて、手軽に使えるので便利。
  • nogcというライブラリを使えば型変換を@nogcで行えるようだが、結局write_imageがGCを必要とするので使わなかった。
  • tupleの中の要素はinoutと組み合わせられないらしい。コンパイラのエラーメッセージいわく「stack based variables can be inout」（引数かスタック上の変数しかinoutにできない）とのこと。

ソースコードは以下。

import core.stdc.stdlib : exit;
import std.complex;
import std.conv : to, ConvException;
import std.parallelism : parallel;
import std.range : chunks;
import std.stdio;
import std.typecons : Nullable, Tuple, tuple;

import imageformats : write_image;

/**
 * Iterates z = z*z + c starting from z = 0 and reports when |z|^2 first
 * exceeds 4.
 *
 * Params:
 *     limit = maximum number of iterations to try (compile-time constant).
 *     c     = the point of the complex plane being tested.
 * Returns: the zero-based iteration index at which |z|^2 exceeded 4, or a
 *          null Nullable if that did not happen within `limit` iterations.
 */
Nullable!uint escapeTime(uint limit)(Complex!double c) @nogc nothrow pure @safe
{
    Complex!double z = complex(0.0, 0.0);
    for (uint step = 0; step < limit; ++step)
    {
        z = z * z + c;
        immutable normSquared = z.re * z.re + z.im * z.im;
        if (normSquared > 4.0)
            return Nullable!uint(step);
    }
    Nullable!uint neverEscaped;
    return neverEscaped;
}

/// A pair of two values of the same type, null when parsing failed.
alias Pair(T) = Nullable!(Tuple!(T, T));

/**
 * Parses a string of the form "<left><separator><right>" into a pair of `T`.
 *
 * Params:
 *     T         = arithmetic type of both halves.
 *     separator = character splitting the two halves (first occurrence wins).
 *     s         = input text.
 * Returns: the parsed pair, or a null `Pair!T` when `separator` is absent or
 *          either half cannot be converted to `T` (the `ConvException` thrown
 *          by `to!T` is caught).
 */
Pair!T parsePair(T, char separator)(string s) pure @safe if (__traits(isArithmetic, T))
{
    import std.string : indexOf;

    // indexOf returns a code-unit offset, which is what slicing `s` needs.
    // countUntil on a string auto-decodes and counts code points instead,
    // which mis-slices any input containing non-ASCII characters.
    immutable index = s.indexOf(separator);
    if (index < 0)
        return (typeof(return)).init;
    try
    {
        T l = s[0 .. index].to!T;
        T r = s[index + 1 .. $].to!T;
        return typeof(return)(tuple(l, r));
    }
    catch (ConvException) return (typeof(return)).init;
}

@safe unittest
{
    assert(parsePair!(int, ',')("").isNull);
    assert(parsePair!(int, ',')("10,").isNull);
    assert(parsePair!(int, ',')("10,20").get() == tuple(10, 20));
    assert(parsePair!(int, ',')("10,20xy").isNull);
    assert(parsePair!(int, ',')("-3,4").get() == tuple(-3, 4));
    assert(parsePair!(int, ',')("1é,2").isNull);
    assert(parsePair!(size_t, 'x')("-1x4").isNull);
    assert(parsePair!(double, 'x')("0.5x").isNull);
    assert(parsePair!(double, 'x')("0.5x1.5").get() == tuple(0.5, 1.5));
}

/**
 * Parses "<re>,<im>" into a complex number.
 *
 * Params:
 *     s = input text, two doubles separated by a comma.
 * Returns: the parsed complex value, or null if `s` is not in that form.
 */
Nullable!(Complex!double) parseComplex(string s) pure @safe
{
    immutable pair = parsePair!(double, ',')(s);
    if (pair.isNull)
        return (typeof(return)).init;
    // Unwrap explicitly: indexing a Nullable directly relies on its
    // deprecated implicit `alias get this` conversion.
    immutable parts = pair.get();
    return typeof(return)(complex(parts[0], parts[1]));
}

@safe unittest
{
    assert(parseComplex("1.25,-0.0625").get() == complex(1.25, -0.0625));
    assert(parseComplex(",-0.0625").isNull);
    assert(parseComplex("1.25").isNull);
}

/**
 * Maps a pixel coordinate of the output image onto the complex plane.
 *
 * Params:
 *     bounds     = image size as (width, height) in pixels.
 *     pixel      = pixel coordinate as (column, row).
 *     upperLeft  = complex point corresponding to pixel (0, 0).
 *     lowerRight = complex point corresponding to pixel (width, height).
 * Returns: the complex point for `pixel`, interpolated linearly; the
 *          imaginary part decreases as the row increases.
 */
Complex!double pixelToPoint(Tuple!(size_t, size_t) bounds,
                            Tuple!(size_t, size_t) pixel,
                            Complex!double upperLeft,
                            Complex!double lowerRight) @nogc pure nothrow @safe
{
    immutable spanRe = lowerRight.re - upperLeft.re;
    immutable spanIm = upperLeft.im - lowerRight.im;
    double re = upperLeft.re + pixel[0] * spanRe / bounds[0];
    double im = upperLeft.im - pixel[1] * spanIm / bounds[1];
    return complex(re, im);
}

@nogc @safe unittest
{
    alias SizePair = Tuple!(size_t, size_t);
    immutable mapped = pixelToPoint(SizePair(100, 100), SizePair(25, 75),
                                    complex(-1.0, 1.0), complex(1.0, -1.0));
    assert(mapped == complex(-0.5, -0.5));
}

/**
 * Renders a rectangle of the Mandelbrot set into a one-byte-per-pixel buffer.
 *
 * Params:
 *     pixels     = row-major output buffer of exactly width * height bytes.
 *     bounds     = (width, height) of the buffer in pixels.
 *     upperLeft  = complex point mapped to the top-left pixel.
 *     lowerRight = complex point mapped to the bottom-right corner.
 *
 * Points that never escape get 0; otherwise the byte is 255 minus the
 * escape iteration, so points that escape quickly are brighter.
 */
void render(ref ubyte[] pixels, Tuple!(size_t, size_t) bounds,
            Complex!double upperLeft, Complex!double lowerRight) @nogc @safe
in { assert(pixels.length == bounds[0] * bounds[1]); }
do
{
    foreach (y; 0 .. bounds[1])
    {
        foreach (x; 0 .. bounds[0])
        {
            immutable escaped = escapeTime!(255)(
                pixelToPoint(bounds, tuple(x, y), upperLeft, lowerRight));
            // escapeTime!255 reports at most 254, so 255 - n always fits a ubyte.
            ubyte shade = 0;
            if (!escaped.isNull)
                shade = cast(ubyte)(255 - escaped.get());
            pixels[y * bounds[0] + x] = shade;
        }
    }
}

/**
 * Writes a one-byte-per-pixel buffer to `filename` via imageformats'
 * `write_image`.
 *
 * Params:
 *     filename = output path; the format is chosen by imageformats
 *                (presumably from the extension — confirm).
 *     pixels   = row-major pixel data.
 *     bounds   = (width, height) of the image in pixels.
 */
void writeImage(string filename, const ubyte[] pixels,
                Tuple!(size_t, size_t) bounds)
{
    immutable width = bounds[0];
    immutable height = bounds[1];
    write_image(filename, width, height, pixels);
}

version(unittest) { void main() {} }
else
{
/**
 * Command-line entry point:
 *     mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT
 *
 * Parses the image size ("WxH") and the two corner points ("re,im"), renders
 * the image one row at a time in parallel, and writes it to FILE. Exits with
 * status 1 on wrong argument count or unparsable arguments.
 */
void main(string[] args)
{
    if (args.length != 5)
    {
        stderr.writeln("Usage: mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT");
        stderr.writefln("Example %s mandel.png 1000x750 -1.20,0.35 -1,0.20",
                        args[0]);
        exit(1);
    }

    // Validate before unwrapping: calling .get() on a null Nullable asserts
    // (and is undefined behaviour when asserts are compiled out).
    immutable parsedBounds = parsePair!(size_t, 'x')(args[2]);
    immutable parsedUpperLeft = parseComplex(args[3]);
    immutable parsedLowerRight = parseComplex(args[4]);
    // A zero dimension would also make chunks(0) below fail its precondition.
    if (parsedBounds.isNull || parsedBounds.get()[0] == 0 || parsedBounds.get()[1] == 0)
    {
        stderr.writeln("error parsing image dimensions");
        exit(1);
    }
    if (parsedUpperLeft.isNull)
    {
        stderr.writeln("error parsing upper left corner point");
        exit(1);
    }
    if (parsedLowerRight.isNull)
    {
        stderr.writeln("error parsing lower right corner point");
        exit(1);
    }

    immutable bounds = parsedBounds.get();
    immutable upperLeft = parsedUpperLeft.get();
    immutable lowerRight = parsedLowerRight.get();
    auto pixels = new ubyte[bounds[0] * bounds[1]];

    // Each band is one row of the image, rendered as a width x 1 image.
    auto bands = pixels.chunks(bounds[0]);
    foreach (top, band; parallel(bands))
    {
        // Build the tuples as Tuple!(size_t, size_t) explicitly: literals such
        // as 1UL are ulong and would not match on 32-bit targets.
        auto bandBounds = Tuple!(size_t, size_t)(bounds[0], 1);
        immutable bandUpperLeft = pixelToPoint(bounds,
                                               Tuple!(size_t, size_t)(0, top),
                                               upperLeft, lowerRight);
        immutable bandLowerRight = pixelToPoint(bounds,
                                                Tuple!(size_t, size_t)(bounds[0], top + 1),
                                                upperLeft, lowerRight);
        render(band, bandBounds, bandUpperLeft, bandLowerRight);
    }
    writeImage(args[1], pixels, bounds);
}
}

追記

imageformatsの@nogc版にあたるimagefmtというライブラリがあった。

https://github.com/lgvz/imagefmt

これを使って少しリファクタリングしたコードも貼っておく。

import core.stdc.stdlib : exit;
import std.complex;
import std.conv : to, ConvException;
import std.parallelism : parallel;
import std.range : chunks;
import std.stdio;
import std.typecons : Nullable, Tuple, tuple;

import imagefmt : write_image;

/**
 * Iterates z = z*z + c starting from z = 0 and reports when |z|^2 first
 * exceeds 4.
 *
 * Params:
 *     limit = maximum number of iterations to try (compile-time constant).
 *     c     = the point of the complex plane being tested.
 * Returns: the zero-based iteration index at which |z|^2 exceeded 4, or a
 *          null Nullable if that did not happen within `limit` iterations.
 */
Nullable!uint escapeTime(uint limit)(Complex!double c) @nogc nothrow pure @safe
{
    Complex!double z = complex(0.0, 0.0);
    for (uint step = 0; step < limit; ++step)
    {
        z = z * z + c;
        immutable normSquared = z.re * z.re + z.im * z.im;
        if (normSquared > 4.0)
            return Nullable!uint(step);
    }
    Nullable!uint neverEscaped;
    return neverEscaped;
}

/// A pair of two values of the same type, null when parsing failed.
alias Pair(T) = Nullable!(Tuple!(T, T));

/**
 * Parses a string of the form "<left><separator><right>" into a pair of `T`.
 *
 * Params:
 *     T         = arithmetic type of both halves.
 *     separator = character splitting the two halves (first occurrence wins).
 *     s         = input text.
 * Returns: the parsed pair, or a null `Pair!T` when `separator` is absent or
 *          either half cannot be converted to `T` (the `ConvException` thrown
 *          by `to!T` is caught).
 */
Pair!T parsePair(T, char separator)(string s) pure @safe if (__traits(isArithmetic, T))
{
    import std.string : indexOf;

    // indexOf returns a code-unit offset, which is what slicing `s` needs.
    // countUntil on a string auto-decodes and counts code points instead,
    // which mis-slices any input containing non-ASCII characters.
    immutable index = s.indexOf(separator);
    if (index < 0)
        return (typeof(return)).init;
    try
    {
        T l = s[0 .. index].to!T;
        T r = s[index + 1 .. $].to!T;
        return typeof(return)(tuple(l, r));
    }
    catch (ConvException) return (typeof(return)).init;
}

@safe unittest
{
    assert(parsePair!(int, ',')("").isNull);
    assert(parsePair!(int, ',')("10,").isNull);
    assert(parsePair!(int, ',')("10,20").get() == tuple(10, 20));
    assert(parsePair!(int, ',')("10,20xy").isNull);
    assert(parsePair!(int, ',')("-3,4").get() == tuple(-3, 4));
    assert(parsePair!(int, ',')("1é,2").isNull);
    assert(parsePair!(double, 'x')("0.5x").isNull);
    assert(parsePair!(double, 'x')("0.5x1.5").get() == tuple(0.5, 1.5));
}

/**
 * Parses "<re>,<im>" into a complex number.
 *
 * Params:
 *     s = input text, two doubles separated by a comma.
 * Returns: the parsed complex value, or null if `s` is not in that form.
 */
Nullable!(Complex!double) parseComplex(string s) pure @safe
{
    immutable pair = parsePair!(double, ',')(s);
    if (pair.isNull)
        return (typeof(return)).init;
    // Unwrap explicitly: indexing a Nullable directly relies on its
    // deprecated implicit `alias get this` conversion.
    immutable parts = pair.get();
    return typeof(return)(complex(parts[0], parts[1]));
}

@safe unittest
{
    assert(parseComplex("1.25,-0.0625").get() == complex(1.25, -0.0625));
    assert(parseComplex(",-0.0625").isNull);
    assert(parseComplex("1.25").isNull);
}

/**
 * Maps a pixel coordinate of the output image onto the complex plane.
 *
 * Params:
 *     bounds     = image size as (width, height) in pixels.
 *     pixel      = pixel coordinate as (column, row).
 *     upperLeft  = complex point corresponding to pixel (0, 0).
 *     lowerRight = complex point corresponding to pixel (width, height).
 * Returns: the complex point for `pixel`, interpolated linearly; the
 *          imaginary part decreases as the row increases.
 */
Complex!double pixelToPoint(T)(Tuple!(T, T) bounds, Tuple!(T, T) pixel,
                               Complex!double upperLeft,
                               Complex!double lowerRight) @nogc pure nothrow @safe
{
    immutable spanRe = lowerRight.re - upperLeft.re;
    immutable spanIm = upperLeft.im - lowerRight.im;
    double re = upperLeft.re + pixel[0] * spanRe / bounds[0];
    double im = upperLeft.im - pixel[1] * spanIm / bounds[1];
    return complex(re, im);
}

@nogc @safe unittest
{
    immutable mapped = pixelToPoint!int(tuple(100, 100), tuple(25, 75),
                                        complex(-1.0, 1.0), complex(1.0, -1.0));
    assert(mapped == complex(-0.5, -0.5));
}

/**
 * Renders a rectangle of the Mandelbrot set into a one-byte-per-pixel buffer.
 *
 * Params:
 *     pixels     = row-major output buffer of exactly width * height bytes.
 *     bounds     = (width, height) of the buffer in pixels.
 *     upperLeft  = complex point mapped to the top-left pixel.
 *     lowerRight = complex point mapped to the bottom-right corner.
 *
 * Points that never escape get 0; otherwise the byte is 255 minus the
 * escape iteration, so points that escape quickly are brighter.
 */
void render(T)(ref ubyte[] pixels, Tuple!(T, T) bounds,
               Complex!double upperLeft, Complex!double lowerRight) @nogc @safe
in { assert(pixels.length == bounds[0] * bounds[1]); }
do
{
    foreach (y; 0 .. bounds[1])
    {
        foreach (x; 0 .. bounds[0])
        {
            immutable escaped = escapeTime!(255)(
                pixelToPoint!T(bounds, tuple(x, y), upperLeft, lowerRight));
            // escapeTime!255 reports at most 254, so 255 - n always fits a ubyte.
            ubyte shade = 0;
            if (!escaped.isNull)
                shade = cast(ubyte)(255 - escaped.get());
            pixels[y * bounds[0] + x] = shade;
        }
    }
}

/**
 * Writes a one-byte-per-pixel buffer to `filename` via imagefmt's
 * `write_image`.
 *
 * Params:
 *     filename = output path.
 *     pixels   = row-major pixel data.
 *     bounds   = (width, height) of the image in pixels.
 * Returns: whatever status code imagefmt's `write_image` reports
 *          (presumably 0 on success — confirm against imagefmt docs).
 */
int writeImage(T : int)(string filename, const ubyte[] pixels, Tuple!(T, T) bounds) nothrow @nogc
{
    immutable int width = bounds[0].to!int;
    immutable int height = bounds[1].to!int;
    return write_image(filename, width, height, pixels);
}

version(unittest) { void main() {} }
else
{
/**
 * Command-line entry point:
 *     mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT
 *
 * Parses the image size ("WxH") and the two corner points ("re,im"), renders
 * the image one row at a time in parallel, and writes it to FILE.
 * Returns: 1 on wrong argument count or unparsable arguments, otherwise the
 *          status returned by writeImage.
 */
int main(string[] args)
{
    // Returning (instead of core.stdc exit) lets druntime shut down normally.
    if (args.length != 5)
    {
        stderr.writeln("Usage: mandelbrot FILE PIXELS UPPERLEFT LOWERRIGHT");
        stderr.writefln("Example %s mandel.png 1000x750 -1.20,0.35 -1,0.20",
                        args[0]);
        return 1;
    }

    // Validate before unwrapping: calling .get() on a null Nullable asserts
    // (and is undefined behaviour when asserts are compiled out).
    immutable parsedBounds = parsePair!(int, 'x')(args[2]);
    immutable parsedUpperLeft = parseComplex(args[3]);
    immutable parsedLowerRight = parseComplex(args[4]);
    // Non-positive sizes would yield a bogus allocation length or chunks(0).
    if (parsedBounds.isNull || parsedBounds.get()[0] <= 0 || parsedBounds.get()[1] <= 0)
    {
        stderr.writeln("error parsing image dimensions");
        return 1;
    }
    if (parsedUpperLeft.isNull)
    {
        stderr.writeln("error parsing upper left corner point");
        return 1;
    }
    if (parsedLowerRight.isNull)
    {
        stderr.writeln("error parsing lower right corner point");
        return 1;
    }

    immutable bounds = parsedBounds.get();
    immutable upperLeft = parsedUpperLeft.get();
    immutable lowerRight = parsedLowerRight.get();
    // Multiply in size_t: int * int silently overflows for large images.
    auto pixels = new ubyte[cast(size_t) bounds[0] * bounds[1]];

    // Each band is one row of the image, rendered as a width x 1 image.
    auto bands = pixels.chunks(bounds[0]);
    foreach (top, band; parallel(bands))
    {
        int row = top.to!int;
        auto bandBounds = tuple(cast() bounds[0], 1);
        immutable bandUpperLeft = pixelToPoint(bounds, tuple(0, row),
                                               upperLeft, lowerRight);
        immutable bandLowerRight = pixelToPoint(bounds,
                                                tuple(cast() bounds[0], row + 1),
                                                upperLeft, lowerRight);
        render(band, bandBounds, bandUpperLeft, bandLowerRight);
    }
    return writeImage(args[1], pixels, bounds);
}
}