diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 169e7832d7e..fb10a718792 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -1,73 +1,72 @@ package cats package data +import Fun.-> + /** * `StateT[F, S, A]` is similar to `Kleisli[F, S, A]` in that it takes an `S` * argument and produces an `A` value wrapped in `F`. However, it also produces * an `S` value representing the updated state (which is wrapped in the `F` * context along with the `A` value. */ -final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable { - - def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => - fas(a).run(s) - } - }) - - def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] = - StateT(s => - F.flatMap(runF) { fsf => - F.flatMap(fsf(s)) { case (s, a) => - F.map(faf(a))((s, _)) - } +final class StateT[F[_], S, A](private val runF: S -> F[(S, A)]) extends Serializable { + + def flatMap[B](f: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] = + StateT(runF andThen { fsa => + F.flatMap(fsa) { case (s, a) => + f(a).runF(s) } - ) + }) - def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = + def flatMapF[B](f: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] = + StateT(runF andThen { fsa => + F.flatMap(fsa) { case (s, a) => + F.map(f(a))((s, _)) + } + }) + + def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } /** * Run with the provided initial state value */ - def run(initial: S)(implicit F: FlatMap[F]): F[(S, A)] = - F.flatMap(runF)(f => f(initial)) + def run(initial: S): F[(S, A)] = + runF(initial) /** * Run with the provided initial state value and return the final state * (discarding the final value). */ - def runS(s: S)(implicit F: FlatMap[F]): F[S] = F.map(run(s))(_._1) + def runS(s: S)(implicit F: Functor[F]): F[S] = F.map(run(s))(_._1) /** * Run with the provided initial state value and return the final value * (discarding the final state). */ - def runA(s: S)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2) + def runA(s: S)(implicit F: Functor[F]): F[A] = F.map(run(s))(_._2) /** * Run with `S`'s empty monoid value as the initial state. */ - def runEmpty(implicit S: Monoid[S], F: FlatMap[F]): F[(S, A)] = run(S.empty) + def runEmpty(implicit S: Monoid[S]): F[(S, A)] = run(S.empty) /** * Run with `S`'s empty monoid value as the initial state and return the final * state (discarding the final value). */ - def runEmptyS(implicit S: Monoid[S], F: FlatMap[F]): F[S] = runS(S.empty) + def runEmptyS(implicit S: Monoid[S], F: Functor[F]): F[S] = runS(S.empty) /** * Run with `S`'s empty monoid value as the initial state and return the final * value (discarding the final state). */ - def runEmptyA(implicit S: Monoid[S], F: FlatMap[F]): F[A] = runA(S.empty) + def runEmptyA(implicit S: Monoid[S], F: Functor[F]): F[A] = runA(S.empty) /** * Like [[map]], but also allows the state (`S`) value to be modified. */ - def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] = + def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] = transformF { fsa => F.map(fsa){ case (s, a) => f(s, a) } } @@ -75,8 +74,8 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable /** * Like [[transform]], but allows the context to change from `F` to `G`. */ - def transformF[G[_], B](f: F[(S, A)] => G[(S, B)])(implicit F: FlatMap[F], G: Applicative[G]): StateT[G, S, B] = - StateT(s => f(run(s))) + def transformF[G[_], B](f: F[(S, A)] => G[(S, B)]): StateT[G, S, B] = + StateT(runF andThen f) /** * Transform the state used. @@ -96,40 +95,36 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable * res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0)) * }}} */ - def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] = + def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] = StateT { r => - F.flatMap(runF) { ff => - val s = f(r) - val nextState = ff(s) - F.map(nextState) { case (s, a) => (g(r, s), a) } - } + F.map(runF(f(r))) { case (s, a) => (g(r, s), a) } } /** * Modify the state (`S`) component. */ - def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] = + def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] = transform((s, a) => (f(s), a)) /** * Inspect a value from the input state, without modifying the state. */ - def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] = + def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] = transform((s, _) => (s, f(s))) /** * Get the input state, without modifying the state. */ - def get(implicit F: Monad[F]): StateT[F, S, S] = + def get(implicit F: Functor[F]): StateT[F, S, S] = inspect(identity) } object StateT extends StateTInstances { - def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Applicative[F]): StateT[F, S, A] = - new StateT(F.pure(f)) + private[StateT] def apply[F[_], S, A](f: S -> F[(S, A)]): StateT[F, S, A] = + new StateT(f) - def applyF[F[_], S, A](runF: F[S => F[(S, A)]]): StateT[F, S, A] = - new StateT(runF) + def apply[F[_], S, A](f: S => F[(S, A)]): StateT[F, S, A] = + StateT(Fun.Wrap(f)) def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): StateT[F, S, A] = StateT(s => F.pure((s, a))) @@ -158,7 +153,7 @@ private[data] sealed abstract class StateTInstances1 { private[data] abstract class StateFunctions { def apply[S, A](f: S => (S, A)): State[S, A] = - StateT.applyF(Now((s: S) => Now(f(s)))) + StateT((s: S) => Now(f(s))) /** * Return `a` and maintain the input state. @@ -213,3 +208,40 @@ private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S, case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) } }) } + +/** Function with stack-safe composition. */ +private[data] sealed abstract class Fun[A, B] { + import Fun._ + + def apply(a: A): B = Evaluator(a, this).eval + def andThen[C](g: Fun[B, C]): Fun[A, C] = AndThen(this, g) + def andThen[C](g: B => C): Fun[A, C] = andThen(Wrap(g)) +} + +private[data] object Fun { + type ->[A, B] = Fun[A, B] + + final case class Wrap[A, B](f: A => B) extends Fun[A, B] + final case class AndThen[A, B, C](f: Fun[A, B], g: Fun[B, C]) extends Fun[A, C] + + private abstract class Evaluator[B] { + type A + val a: A + val fun: Fun[A, B] + + @annotation.tailrec + final def eval: B = fun match { + case Wrap(f) => f(a) + case AndThen(Wrap(f), g) => Evaluator(f(a), g).eval + case AndThen(AndThen(f, g), h) => Evaluator(a, f andThen (g andThen h)).eval + } + } + + private object Evaluator { + def apply[A0, B](a0: A0, f: Fun[A0, B]): Evaluator[B] = new Evaluator[B] { + type A = A0 + val a = a0 + val fun = f + } + } +} diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 1f7b3aee752..439e9b2b4fd 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -21,6 +21,11 @@ class StateTTests extends CatsSuite { x.runS(0).value should === (100001) } + test("10000 maps is stack-safe"){ + val x = (0 until 10000).foldLeft(StateT.pure[Id, Int, Int](0))((s, i) => s.map(_ + 1)) + x.runA(0) should === (10000) + } + test("State.pure and StateT.pure are consistent"){ forAll { (s: String, i: Int) => val state: State[String, Int] = State.pure(i)