Skip to content

Commit

Permalink
Stack-safe FreeApplicative
Browse files Browse the repository at this point in the history
  • Loading branch information
edmundnoble committed Jul 21, 2017
1 parent 16ea2ed commit 92b6edc
Show file tree
Hide file tree
Showing 4 changed files with 247 additions and 53 deletions.
19 changes: 19 additions & 0 deletions docs/src/main/tut/datatypes/freeapplicative.md
Original file line number Diff line number Diff line change
Expand Up @@ -164,5 +164,24 @@ val prodCompiler: FunctionK[ValidationOp, ValidateAndLog] = parCompiler and logC
val prodValidation = prog.foldMap[ValidateAndLog](prodCompiler)
```

### The way FreeApplicative#foldMap works
Despite being an imperative loop, there is a functional intuition behind `FreeApplicative#foldMap`.

The new `FreeAp`'s `foldMap` is a sort of mutually-recursive function that operates on an argument stack and a
function stack, where the argument stack has type `List[FreeAp[F, _]]` and the functions have type `List[Fn[G, _, _]]`.
`Fn[G[_, _]]` contains a function to be `Ap`'d that has already been translated to the target `Applicative`,
as well as the number of functions that were `Ap`'d immediately subsequently to it.

#### Main re-association loop
Pull an argument out of the stack, eagerly remove right-associated `Ap` nodes, by looping on the right and
adding the `Ap` nodes' arguments on the left to the argument stack; at the end, pushes a single function to the
function stack of the applied functions, the rest of which will be pushed in this loop in later iterations.
Once all of the Ap nodes on the right are removed, the loop resets to deal with the ones on the left.

#### Function application loop
Then it has a loop which pulls functions from the stack until it reaches a curried function,
in which case it applies a single argument, pushes its continuation on to the function stack,
and returns to the main loop.

## References
Deeper explanations can be found in this paper [Free Applicative Functors by Paolo Capriotti](http://www.paolocapriotti.com/assets/applicative.pdf)
172 changes: 147 additions & 25 deletions free/src/main/scala/cats/free/FreeApplicative.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,49 +4,149 @@ package free
import cats.arrow.FunctionK
import cats.data.Const

/** Applicative Functor for Free */
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable { self =>
import scala.annotation.tailrec

/**
* Applicative Functor for Free,
* implementation inspired by https://github.com/safareli/free/pull/31/
*/
sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable {
self =>
// ap => apply alias needed so we can refer to both
// FreeApplicative.ap and FreeApplicative#ap
import FreeApplicative.{FA, Pure, Ap, ap => apply, lift}
import FreeApplicative.{FA, Pure, Ap, lift}

final def ap[B](b: FA[F, A => B]): FA[F, B] =
final def ap[B](b: FA[F, A => B]): FA[F, B] = {
b match {
case Pure(f) =>
this.map(f)
case Ap(pivot, fn) =>
apply(pivot)(self.ap(fn.map(fx => a => p => fx(p)(a))))
case _ =>
Ap(b, this)
}
}

final def map[B](f: A => B): FA[F, B] =
final def map[B](f: A => B): FA[F, B] = {
this match {
case Pure(a) => Pure(f(a))
case Ap(pivot, fn) => apply(pivot)(fn.map(f compose _))
case _ => Ap(Pure(f), this)
}
}

/** Interprets/Runs the sequence of operations using the semantics of Applicative G
* Tail recursive only if G provides tail recursive interpretation (ie G is FreeMonad)
*/
final def foldMap[G[_]](f: FunctionK[F, G])(implicit G: Applicative[G]): G[A] =
this match {
case Pure(a) => G.pure(a)
case Ap(pivot, fn) => G.map2(f(pivot), fn.foldMap(f))((a, g) => g(a))
/** Interprets/Runs the sequence of operations using the semantics of `Applicative` G[_].
* Tail recursive.
*/
// scalastyle:off method.length
final def foldMap[G[_]](f: F ~> G)(implicit G: Applicative[G]): G[A] = {
import FreeApplicative._
// the remaining arguments to G[A => B]'s
var argsF: List[FA[F, Any]] = this.asInstanceOf[FA[F, Any]] :: Nil
var argsFLength: Int = 1
// the remaining stack of G[A => B]'s to be applied to the arguments
var fns: List[Fn[G, Any, Any]] = Nil
var fnsLength: Int = 0

@tailrec
def loop(): G[Any] = {
var argF: FA[F, Any] = argsF.head
argsF = argsF.tail
argsFLength -= 1

// rip off every `Ap` in `argF`, peeling off left-associated prefixes
if (argF.isInstanceOf[Ap[F, _, _]]) {
val lengthInitial = argsFLength
// reassociate the functions into a single fn,
// and move the arguments into argsF
do {
val ap = argF.asInstanceOf[Ap[F, Any, Any]]
argsF ::= ap.fp
argsFLength += 1
argF = ap.fn.asInstanceOf[FA[F, Any]]
} while (argF.isInstanceOf[Ap[F, _, _]])
// consecutive `ap` calls have been queued as operations;
// argF is no longer an `Ap` node, so the entire topmost left-associated
// function application branch has been looped through and we've
// moved (`argsFLength` - `lengthInitial`) arguments to the stack, through
// (`argsFLength` - `lengthInitial`) `Ap` nodes, so the function on the right
// which consumes them all must have (`argsFLength` - `lengthInitial`) arguments
val argc = argsFLength - lengthInitial
fns ::= Fn[G, Any, Any](foldArg(argF.asInstanceOf[FA[F, Any => Any]], f), argc)
fnsLength += 1
loop()
} else {
val argT: G[Any] = foldArg(argF, f)
if (fns ne Nil) {
// single right-associated function application
var fn = fns.head
fns = fns.tail
fnsLength -= 1
var res = G.ap(fn.gab)(argT)
if (fn.argc > 1) {
// this function has more than 1 argument,
// bail out of nested right-associated function application
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.argc - 1)
fnsLength += 1
loop()
} else {
if (fnsLength > 0) {
// we've got a nested right-associated `Ap` tree,
// so apply as many functions as possible
@tailrec
def innerLoop(): Unit = {
fn = fns.head
fns = fns.tail
fnsLength -= 1
res = G.ap(fn.gab)(res)
if (fn.argc > 1) {
fns ::= Fn(res.asInstanceOf[G[Any => Any]], fn.argc - 1)
fnsLength += 1
}
// we have to bail out if fn has more than one argument,
// because it means we may have more left-associated trees
// deeper to the right in the application tree
if (fn.argc == 1 && fnsLength > 0) innerLoop()
}

innerLoop()
}
if (fnsLength == 0) res
else loop()
}
} else argT
}
}

/** Interpret/run the operations using the semantics of `Applicative[F]`.
* Tail recursive only if `F` provides tail recursive interpretation.
*/
loop().asInstanceOf[G[A]]
}
// scalastyle:on method.length


/**
* Interpret/run the operations using the semantics of `Applicative[F]`.
* Stack-safe.
*/
final def fold(implicit F: Applicative[F]): F[A] =
foldMap(FunctionK.id[F])

/** Interpret this algebra into another FreeApplicative */
final def compile[G[_]](f: FunctionK[F, G]): FA[G, A] =
/**
* Interpret this algebra into another algebra.
* Stack-safe.
*/
final def compile[G[_]](f: F ~> G): FA[G, A] =
foldMap[FA[G, ?]] {
λ[FunctionK[F, FA[G, ?]]](fa => lift(f(fa)))
}

/** Interpret this algebra into a Monoid */
final def analyze[M:Monoid](f: FunctionK[F, λ[α => M]]): M =
/**
* Interpret this algebra into a FreeApplicative over another algebra.
* Stack-safe.
*/
def flatCompile[G[_]](f: F ~> FA[G, ?]): FA[G, A] =
foldMap(f)

/**
* Interpret this algebra into a Monoid
*/
final def analyze[M: Monoid](f: FunctionK[F, λ[α => M]]): M =
foldMap[Const[M, ?]](
λ[FunctionK[F, Const[M, ?]]](x => Const(f(x)))
).getConst
Expand All @@ -63,23 +163,45 @@ sealed abstract class FreeApplicative[F[_], A] extends Product with Serializable
object FreeApplicative {
type FA[F[_], A] = FreeApplicative[F, A]

// Internal helper function for foldMap, it folds only Pure and Lift nodes
private[free] def foldArg[F[_], G[_], A](node: FA[F, A], f: F ~> G)(implicit G: Applicative[G]): G[A] =
if (node.isInstanceOf[Pure[F, A]]) {
val Pure(x) = node
G.pure(x)
} else {
val Lift(fa) = node
f(fa)
}

/** Represents a curried function `F[A => B => C => ...]`
* that has been constructed with chained `ap` calls.
* Fn#argc denotes the amount of curried params remaining.
*/
private final case class Fn[G[_], A, B](gab: G[A => B], argc: Int)

private final case class Pure[F[_], A](a: A) extends FA[F, A]

private final case class Ap[F[_], P, A](pivot: F[P], fn: FA[F, P => A]) extends FA[F, A]
private final case class Lift[F[_], A](fa: F[A]) extends FA[F, A]

private final case class Ap[F[_], P, A](fn: FA[F, P => A], fp: FA[F, P]) extends FA[F, A]

final def pure[F[_], A](a: A): FA[F, A] =
Pure(a)

final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] = Ap(fp, f)
final def ap[F[_], P, A](fp: F[P])(f: FA[F, P => A]): FA[F, A] =
Ap(f, Lift(fp))

final def lift[F[_], A](fa: F[A]): FA[F, A] =
ap(fa)(Pure(a => a))
Lift(fa)

implicit final def freeApplicative[S[_]]: Applicative[FA[S, ?]] = {
new Applicative[FA[S, ?]] {
override def product[A, B](fa: FA[S, A], fb: FA[S, B]): FA[S, (A, B)] = ap(fa.map((a: A) => (b: B) => (a, b)))(fb)

override def map[A, B](fa: FA[S, A])(f: A => B): FA[S, B] = fa.map(f)

override def ap[A, B](f: FA[S, A => B])(fa: FA[S, A]): FA[S, B] = fa.ap(f)

def pure[A](a: A): FA[S, A] = Pure(a)
}
}
Expand Down
90 changes: 63 additions & 27 deletions free/src/test/scala/cats/free/FreeApplicativeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,13 @@ package free
import cats.tests.CatsSuite
import cats.arrow.FunctionK
import cats.laws.discipline.{CartesianTests, ApplicativeTests, SerializableTests}
import cats.laws.discipline.arbitrary._
import cats.data.State

import org.scalacheck.{Arbitrary, Gen}

class FreeApplicativeTests extends CatsSuite {
implicit def freeApplicativeArbitrary[F[_], A](implicit F: Arbitrary[F[A]], A: Arbitrary[A]): Arbitrary[FreeApplicative[F, A]] =
Arbitrary(
Gen.oneOf(
A.arbitrary.map(FreeApplicative.pure[F, A]),
F.arbitrary.map(FreeApplicative.lift[F, A])))

implicit def freeApplicativeEq[S[_]: Applicative, A](implicit SA: Eq[S[A]]): Eq[FreeApplicative[S, A]] =
new Eq[FreeApplicative[S, A]] {
def eqv(a: FreeApplicative[S, A], b: FreeApplicative[S, A]): Boolean = {
val nt = FunctionK.id[S]
SA.eqv(a.foldMap(nt), b.foldMap(nt))
}
}
import FreeApplicativeTests._

implicit val iso = CartesianTests.Isomorphisms.invariant[FreeApplicative[Option, ?]]

Expand All @@ -34,6 +23,14 @@ class FreeApplicativeTests extends CatsSuite {
rr.toString.length should be > 0
}

test("fold/map is stack-safe") {
val r = FreeApplicative.lift[List, Int](List(333))
val rr = (1 to 70000).foldLeft(r)((r, _) => r.ap(FreeApplicative.lift[List, Int => Int](List((_: Int) + 1))))
rr.fold should be (List(333 + 70000))
val rx = (1 to 70000).foldRight(r)((_, r) => r.ap(FreeApplicative.lift[List, Int => Int](List((_: Int) + 1))))
rx.fold should be (List(333 + 70000))
}

test("FreeApplicative#fold") {
val n = 2
val o1 = Option(1)
Expand All @@ -47,23 +44,30 @@ class FreeApplicativeTests extends CatsSuite {
}

test("FreeApplicative#compile") {
val x = FreeApplicative.lift[Id, Int](1)
val y = FreeApplicative.pure[Id, Int](2)
val f = x.map(i => (j: Int) => i + j)
val nt = FunctionK.id[Id]
val r1 = y.ap(f)
val r2 = r1.compile(nt)
r1.foldMap(nt) should === (r2.foldMap(nt))
forAll { (x: FreeApplicative[List, Int], y: FreeApplicative[List, Int], nt: List ~> List) =>
x.compile(nt).fold should ===(x.foldMap(nt))
}
}

test("FreeApplicative#flatCompile") {
forAll { (x: FreeApplicative[Option, Int]) =>
val nt: Option ~> FreeApplicative[Option, ?] = new FunctionK[Option, FreeApplicative[Option, ?]] {
def apply[A](a: Option[A]): FreeApplicative[Option, A] = FreeApplicative.lift(a)
}
x.foldMap[FreeApplicative[Option, ?]](nt).fold should === (x.flatCompile[Option](nt).fold)
}
}

test("FreeApplicative#monad") {
val x = FreeApplicative.lift[Id, Int](1)
val y = FreeApplicative.pure[Id, Int](2)
val f = x.map(i => (j: Int) => i + j)
val r1 = y.ap(f)
val r2 = r1.monad
val nt = FunctionK.id[Id]
r1.foldMap(nt) should === (r2.foldMap(nt))
forAll { (x: FreeApplicative[List, Int]) =>
x.monad.foldMap(FunctionK.id) should === (x.fold)
}
}

test("FreeApplicative#ap") {
val x = FreeApplicative.ap[Id, Int, Int](1)(FreeApplicative.pure((_: Int) + 1))
val y = FreeApplicative.lift[Id, Int](1).ap(FreeApplicative.pure((_: Int) + 1))
x should === (y)
}

// Ensure that syntax and implicit resolution work as expected.
Expand Down Expand Up @@ -126,3 +130,35 @@ class FreeApplicativeTests extends CatsSuite {
z.analyze(asString) should === ("xy")
}
}

object FreeApplicativeTests {
private def freeGen[F[_], A](maxDepth: Int)(implicit F: Arbitrary[F[A]], FF: Arbitrary[(A, A) => A], A: Arbitrary[A]): Gen[FreeApplicative[F, A]] = {
val noFlatMapped = Gen.oneOf(
A.arbitrary.map(FreeApplicative.pure[F, A]),
F.arbitrary.map(FreeApplicative.lift[F, A]))

val nextDepth = Gen.chooseNum(1, math.max(1, maxDepth - 1))

def withFlatMapped = for {
fDepth <- nextDepth
freeDepth <- nextDepth
ff <- FF.arbitrary
f <- freeGen[F, A](fDepth).map(_.map(l => (u: A) => ff(l, u)))
freeFA <- freeGen[F, A](freeDepth)
} yield freeFA.ap(f)

if (maxDepth <= 1) noFlatMapped
else Gen.oneOf(noFlatMapped, withFlatMapped)
}

implicit def freeArbitrary[F[_], A](implicit F: Arbitrary[F[A]], FF: Arbitrary[(A, A) => A], A: Arbitrary[A]): Arbitrary[FreeApplicative[F, A]] =
Arbitrary(freeGen[F, A](2))

implicit def freeApplicativeEq[S[_]: Applicative, A](implicit SA: Eq[S[A]]): Eq[FreeApplicative[S, A]] =
new Eq[FreeApplicative[S, A]] {
def eqv(a: FreeApplicative[S, A], b: FreeApplicative[S, A]): Boolean = {
SA.eqv(a.fold, b.fold)
}
}

}
Loading

0 comments on commit 92b6edc

Please sign in to comment.