Skip to content

Commit

Permalink
Fix #10380 - Memoize should handle lambdas (#7507)
Browse files Browse the repository at this point in the history
  • Loading branch information
Biotronic authored Dec 29, 2024
1 parent d895504 commit ff25f6d
Showing 1 changed file with 157 additions and 60 deletions.
217 changes: 157 additions & 60 deletions std/functional.d
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,15 @@ alias pipe(fun...) = compose!(Reverse!(fun));
assert(compose!(`a + 0.5`, `to!(int)(a) + 1`, foo)(1) == 2.5);
}

private template getOverloads(alias fun)
{
import std.meta : AliasSeq;
static if (__traits(compiles, __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun), true)))
alias getOverloads = __traits(getOverloads, __traits(parent, fun), __traits(identifier, fun), true);
else
alias getOverloads = AliasSeq!fun;
}

/**
* $(LINK2 https://en.wikipedia.org/wiki/Memoization, Memoizes) a function so as
* to avoid repeated computation. The memoization structure is a hash table keyed by a
Expand Down Expand Up @@ -1324,87 +1333,131 @@ Note:
*/
template memoize(alias fun)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
// Generic implementation
alias memoize = impl;
}

ReturnType!fun memoize(Parameters!fun args)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
alias Args = Parameters!fun;
import std.typecons : Tuple;
import std.typecons : Tuple, tuple;
import std.traits : Unqual;

static Unqual!(ReturnType!fun)[Tuple!Args] memo;
auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
static if (args.length > 0)
{
static Unqual!(typeof(fun(args)))[Tuple!(typeof(args))] memo;

auto t = Tuple!Args(args);
if (auto p = t in memo)
return *p;
auto r = fun(args);
memo[t] = r;
return r;
}
else
{
static typeof(fun(args)) result;
result = fun(args);
return result;
}
}
}

/// ditto
template memoize(alias fun, uint maxSize)
{
import std.traits : ReturnType;
// https://issues.dlang.org/show_bug.cgi?id=13580
// alias Args = Parameters!fun;
ReturnType!fun memoize(Parameters!fun args)
import std.traits : Parameters;
import std.meta : anySatisfy;

// Specific overloads:
alias overloads = getOverloads!fun;
static foreach (fn; overloads)
static if (is(Parameters!fn))
alias memoize = impl!(Parameters!fn);

enum isTemplate(alias a) = __traits(isTemplate, a);
static if (anySatisfy!(isTemplate, overloads))
{
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
static struct Value { staticMap!(Unqual, Parameters!fun) args; Unqual!(ReturnType!fun) res; }
static Value[] memo;
static size_t[] initialized;
// Generic implementation
alias memoize = impl;
}

if (!memo.length)
auto impl(Args...)(Args args) if (is(typeof(fun(args))))
{
static if (args.length > 0)
{
import core.memory : GC;
import std.meta : staticMap;
import std.traits : hasIndirections, Unqual;
import std.typecons : tuple;
alias returnType = typeof(fun(args));
static struct Value { staticMap!(Unqual, Args) args; Unqual!returnType res; }
static Value[] memo;
static size_t[] initialized;

// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));
if (!memo.length)
{
import core.memory : GC;

enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}
// Ensure no allocation overflows
static assert(maxSize < size_t.max / Value.sizeof);
static assert(maxSize < size_t.max - (8 * size_t.sizeof - 1));

import core.bitop : bt, bts;
import core.lifetime : emplace;
enum attr = GC.BlkAttr.NO_INTERIOR | (hasIndirections!Value ? 0 : GC.BlkAttr.NO_SCAN);
memo = (cast(Value*) GC.malloc(Value.sizeof * maxSize, attr))[0 .. maxSize];
enum nwords = (maxSize + 8 * size_t.sizeof - 1) / (8 * size_t.sizeof);
initialized = (cast(size_t*) GC.calloc(nwords * size_t.sizeof, attr | GC.BlkAttr.NO_SCAN))[0 .. nwords];
}

size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
import core.bitop : bt, bts;
import core.lifetime : emplace;

size_t hash;
foreach (ref arg; args)
hash = hashOf(arg, hash);
// cuckoo hashing
immutable idx1 = hash % maxSize;
if (!bt(initialized.ptr, idx1))
{
emplace(&memo[idx1], args, fun(args));
// only set to initialized after setting args and value
// https://issues.dlang.org/show_bug.cgi?id=14025
bts(initialized.ptr, idx1);
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

memo[idx1] = Value(args, fun(args));
return memo[idx1].res;
}
else if (memo[idx1].args == args)
return memo[idx1].res;
// FNV prime
immutable idx2 = (hash * 16_777_619) % maxSize;
if (!bt(initialized.ptr, idx2))
else
{
emplace(&memo[idx2], memo[idx1]);
bts(initialized.ptr, idx2);
static typeof(fun(args)) result;
result = fun(args);
return result;
}
else if (memo[idx2].args == args)
return memo[idx2].res;
else if (idx1 != idx2)
memo[idx2] = memo[idx1];

memo[idx1] = Value(args, fun(args));
return memo[idx1].res;
}
}

Expand Down Expand Up @@ -1464,6 +1517,37 @@ unittest
assert(fact(10) == 3628800);
}

// Issue 20099
@system unittest // not @safe due to memoize
{
int i = 3;
alias a = memoize!((n) => i + n);
alias b = memoize!((n) => i + n, 3);

assert(a(3) == 6);
assert(b(3) == 6);
}

@system unittest // not @safe due to memoize
{
static Object objNum(int a) { return new Object(); }
assert(memoize!objNum(0) is memoize!objNum(0U));
assert(memoize!(objNum, 3)(0) is memoize!(objNum, 3)(0U));
}

@system unittest // not @safe due to memoize
{
struct S
{
static int fun() { return 0; }
static int fun(int i) { return 1; }
}
assert(memoize!(S.fun)() == 0);
assert(memoize!(S.fun)(3) == 1);
assert(memoize!(S.fun, 3)() == 0);
assert(memoize!(S.fun, 3)(3) == 1);
}

@system unittest // not @safe due to memoize
{
import core.math : sqrt;
Expand Down Expand Up @@ -1626,6 +1710,19 @@ unittest
}}
}

// memoize should continue to work with functions that cannot be evaluated at compile time
@system unittest
{
__gshared string[string] glob;

static bool foo()
{
return (":-)" in glob) is null;
}

assert(memoize!foo);
}

private struct DelegateFaker(F)
{
import std.typecons : FuncInfo, MemberFunctionGenerator;
Expand Down

0 comments on commit ff25f6d

Please sign in to comment.