From bce5de8e37c4cd0d23515781f73547d3a2d2e888 Mon Sep 17 00:00:00 2001 From: "P. Oscar Boykin" Date: Sat, 2 Sep 2017 13:57:56 -1000 Subject: [PATCH] Improve Eval.memoize --- core/src/main/scala/cats/Eval.scala | 140 ++++++++++++++++++---------- 1 file changed, 93 insertions(+), 47 deletions(-) diff --git a/core/src/main/scala/cats/Eval.scala b/core/src/main/scala/cats/Eval.scala index 4361e9903b..8641e1ba5c 100644 --- a/core/src/main/scala/cats/Eval.scala +++ b/core/src/main/scala/cats/Eval.scala @@ -253,34 +253,32 @@ object Eval extends EvalInstances { * they will be automatically created when needed. */ sealed abstract class Call[A](val thunk: () => Eval[A]) extends Eval[A] { - def memoize: Eval[A] = new Later(() => value) - def value: A = Call.loop(this).value + def memoize: Eval[A] = Memoize(this) + def value: A = evaluate(this) } - object Call { - - /** - * Collapse the call stack for eager evaluations. - */ - @tailrec private def loop[A](fa: Eval[A]): Eval[A] = fa match { - case call: Eval.Call[A] => - loop(call.thunk()) - case compute: Eval.Compute[A] => - new Eval.Compute[A] { - type Start = compute.Start - val start: () => Eval[Start] = () => compute.start() - val run: Start => Eval[A] = s => loop1(compute.run(s)) - } - case other => other - } - - /** - * Alias for loop that can be called in a non-tail position - * from an otherwise tailrec-optimized loop. - */ - private def loop1[A](fa: Eval[A]): Eval[A] = loop(fa) + /** + * Collapse the call stack for eager evaluations. + * returns a non Call Eval node + */ + @tailrec private def doCall[A](fa: Eval[A]): Eval[A] = fa match { + case call: Eval.Call[A] => + doCall(call.thunk()) + case compute: Eval.Compute[A] => + new Eval.Compute[A] { + type Start = compute.Start + val start: () => Eval[Start] = () => compute.start() + val run: Start => Eval[A] = s => doCall1(compute.run(s)) + } + case other => other } + /** + * Alias for doCall that can be called in a non-tail position + * from an otherwise tailrec-optimized doCall. + */ + private def doCall1[A](fa: Eval[A]): Eval[A] = doCall(fa) + /** * Compute is a type of Eval[A] that is used to chain computations * involving .map and .flatMap. Along with Eval#flatMap it @@ -299,30 +297,78 @@ object Eval extends EvalInstances { val start: () => Eval[Start] val run: Start => Eval[A] - def memoize: Eval[A] = Later(value) - - def value: A = { - type L = Eval[Any] - type C = Any => Eval[Any] - @tailrec def loop(curr: L, fs: List[C]): Any = - curr match { - case c: Compute[_] => - c.start() match { - case cc: Compute[_] => - loop( - cc.start().asInstanceOf[L], - cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) - case xx => - loop(c.run(xx.value), fs) - } - case x => - fs match { - case f :: fs => loop(f(x.value), fs) - case Nil => x.value - } - } - loop(this.asInstanceOf[L], Nil).asInstanceOf[A] + def memoize: Eval[A] = Memoize(this) + + def value: A = evaluate(this) + } + + private case class Memoize[A](eval: Eval[A]) extends Eval[A] { + var result: Option[A] = None + def memoize: Eval[A] = this + def value: A = + result match { + case Some(a) => a + case None => + val a = evaluate(this) + result = Some(a) + a + } + } + + private def evaluate[A](e: Eval[A]): A = { + type L = Eval[Any] + type C = Any => Eval[Any] + + /** + * Here we implement memoization in a way + * that only callers to memoize pay a price. + * We insert a side effecting fn into the pipeline + * the updates the memo when it is not yet set + */ + val memos = collection.mutable.Map.empty[L, Any] + def addToMemo(m: Memoize[Any]): C = { a: Any => + memos.put(m.eval, a) + m.result = Some(a) + Now(a) } + @tailrec def loop(curr: L, fs: List[C]): Any = + curr match { + case c: Compute[_] => + c.start() match { + case cc: Compute[_] => + loop( + cc.start().asInstanceOf[L], + cc.run.asInstanceOf[C] :: c.run.asInstanceOf[C] :: fs) + case xx => + loop(c.run(xx.value), fs) + } + case call: Call[_] => + loop(doCall(call), fs) + case m@Memoize(eval) => + m.result match { + case Some(a) => + fs match { + case f :: fs => loop(f(a), fs) + case Nil => a + } + case None => + memos.get(eval) match { + case Some(a) => + fs match { + case f :: fs => loop(f(a), fs) + case Nil => a + } + case None => + loop(eval, addToMemo(m) :: fs) + } + } + case x => + fs match { + case f :: fs => loop(f(x.value), fs) + case Nil => x.value + } + } + loop(e.asInstanceOf[L], Nil).asInstanceOf[A] } }