Skip to content

Commit

Permalink
StateT representation using functions with stack-safe composition.
Browse files Browse the repository at this point in the history
  • Loading branch information
TomasMikula committed May 31, 2016
1 parent 2fb4d1d commit 5f125e6
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 42 deletions.
116 changes: 74 additions & 42 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
@@ -1,82 +1,81 @@
package cats
package data

import Fun.->

/**
* `StateT[F, S, A]` is similar to `Kleisli[F, S, A]` in that it takes an `S`
* argument and produces an `A` value wrapped in `F`. However, it also produces
* an `S` value representing the updated state (which is wrapped in the `F`
* context along with the `A` value.
*/
final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable {

def flatMap[B](fas: A => StateT[F, S, B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
fas(a).run(s)
}
})

def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] =
StateT(s =>
F.flatMap(runF) { fsf =>
F.flatMap(fsf(s)) { case (s, a) =>
F.map(faf(a))((s, _))
}
final class StateT[F[_], S, A](private val runF: S -> F[(S, A)]) extends Serializable {

def flatMap[B](f: A => StateT[F, S, B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT(runF andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
f(a).runF(s)
}
)
})

def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] =
def flatMapF[B](f: A => F[B])(implicit F: FlatMap[F]): StateT[F, S, B] =
StateT(runF andThen { fsa =>
F.flatMap(fsa) { case (s, a) =>
F.map(f(a))((s, _))
}
})

def map[B](f: A => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform { case (s, a) => (s, f(a)) }

/**
* Run with the provided initial state value
*/
def run(initial: S)(implicit F: FlatMap[F]): F[(S, A)] =
F.flatMap(runF)(f => f(initial))
def run(initial: S): F[(S, A)] =
runF(initial)

/**
* Run with the provided initial state value and return the final state
* (discarding the final value).
*/
def runS(s: S)(implicit F: FlatMap[F]): F[S] = F.map(run(s))(_._1)
def runS(s: S)(implicit F: Functor[F]): F[S] = F.map(run(s))(_._1)

/**
* Run with the provided initial state value and return the final value
* (discarding the final state).
*/
def runA(s: S)(implicit F: FlatMap[F]): F[A] = F.map(run(s))(_._2)
def runA(s: S)(implicit F: Functor[F]): F[A] = F.map(run(s))(_._2)

/**
* Run with `S`'s empty monoid value as the initial state.
*/
def runEmpty(implicit S: Monoid[S], F: FlatMap[F]): F[(S, A)] = run(S.empty)
def runEmpty(implicit S: Monoid[S]): F[(S, A)] = run(S.empty)

/**
* Run with `S`'s empty monoid value as the initial state and return the final
* state (discarding the final value).
*/
def runEmptyS(implicit S: Monoid[S], F: FlatMap[F]): F[S] = runS(S.empty)
def runEmptyS(implicit S: Monoid[S], F: Functor[F]): F[S] = runS(S.empty)

/**
* Run with `S`'s empty monoid value as the initial state and return the final
* value (discarding the final state).
*/
def runEmptyA(implicit S: Monoid[S], F: FlatMap[F]): F[A] = runA(S.empty)
def runEmptyA(implicit S: Monoid[S], F: Functor[F]): F[A] = runA(S.empty)

/**
* Like [[map]], but also allows the state (`S`) value to be modified.
*/
def transform[B](f: (S, A) => (S, B))(implicit F: Monad[F]): StateT[F, S, B] =
def transform[B](f: (S, A) => (S, B))(implicit F: Functor[F]): StateT[F, S, B] =
transformF { fsa =>
F.map(fsa){ case (s, a) => f(s, a) }
}

/**
* Like [[transform]], but allows the context to change from `F` to `G`.
*/
def transformF[G[_], B](f: F[(S, A)] => G[(S, B)])(implicit F: FlatMap[F], G: Applicative[G]): StateT[G, S, B] =
StateT(s => f(run(s)))
def transformF[G[_], B](f: F[(S, A)] => G[(S, B)]): StateT[G, S, B] =
StateT(runF andThen f)

/**
* Transform the state used.
Expand All @@ -96,40 +95,36 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable
* res1: Option[(GlobalEnv, Double)] = Some(((6,hello),5.0))
* }}}
*/
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Monad[F]): StateT[F, R, A] =
def transformS[R](f: R => S, g: (R, S) => R)(implicit F: Functor[F]): StateT[F, R, A] =
StateT { r =>
F.flatMap(runF) { ff =>
val s = f(r)
val nextState = ff(s)
F.map(nextState) { case (s, a) => (g(r, s), a) }
}
F.map(runF(f(r))) { case (s, a) => (g(r, s), a) }
}

/**
* Modify the state (`S`) component.
*/
def modify(f: S => S)(implicit F: Monad[F]): StateT[F, S, A] =
def modify(f: S => S)(implicit F: Functor[F]): StateT[F, S, A] =
transform((s, a) => (f(s), a))

/**
* Inspect a value from the input state, without modifying the state.
*/
def inspect[B](f: S => B)(implicit F: Monad[F]): StateT[F, S, B] =
def inspect[B](f: S => B)(implicit F: Functor[F]): StateT[F, S, B] =
transform((s, _) => (s, f(s)))

/**
* Get the input state, without modifying the state.
*/
def get(implicit F: Monad[F]): StateT[F, S, S] =
def get(implicit F: Functor[F]): StateT[F, S, S] =
inspect(identity)
}

object StateT extends StateTInstances {
def apply[F[_], S, A](f: S => F[(S, A)])(implicit F: Applicative[F]): StateT[F, S, A] =
new StateT(F.pure(f))
private[StateT] def apply[F[_], S, A](f: S -> F[(S, A)]): StateT[F, S, A] =
new StateT(f)

def applyF[F[_], S, A](runF: F[S => F[(S, A)]]): StateT[F, S, A] =
new StateT(runF)
def apply[F[_], S, A](f: S => F[(S, A)]): StateT[F, S, A] =
StateT(Fun.Wrap(f))

def pure[F[_], S, A](a: A)(implicit F: Applicative[F]): StateT[F, S, A] =
StateT(s => F.pure((s, a)))
Expand Down Expand Up @@ -158,7 +153,7 @@ private[data] sealed abstract class StateTInstances1 {
private[data] abstract class StateFunctions {

def apply[S, A](f: S => (S, A)): State[S, A] =
StateT.applyF(Now((s: S) => Now(f(s))))
StateT((s: S) => Now(f(s)))

/**
* Return `a` and maintain the input state.
Expand Down Expand Up @@ -213,3 +208,40 @@ private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S,
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}

/** Function with stack-safe composition. */
private[data] sealed abstract class Fun[A, B] {
import Fun._

def apply(a: A): B = Evaluator(a, this).eval
def andThen[C](g: Fun[B, C]): Fun[A, C] = AndThen(this, g)
def andThen[C](g: B => C): Fun[A, C] = andThen(Wrap(g))
}

private[data] object Fun {
type ->[A, B] = Fun[A, B]

final case class Wrap[A, B](f: A => B) extends Fun[A, B]
final case class AndThen[A, B, C](f: Fun[A, B], g: Fun[B, C]) extends Fun[A, C]

private abstract class Evaluator[B] {
type A
val a: A
val fun: Fun[A, B]

@annotation.tailrec
final def eval: B = fun match {
case Wrap(f) => f(a)
case AndThen(Wrap(f), g) => Evaluator(f(a), g).eval
case AndThen(AndThen(f, g), h) => Evaluator(a, f andThen (g andThen h)).eval
}
}

private object Evaluator {
def apply[A0, B](a0: A0, f: Fun[A0, B]): Evaluator[B] = new Evaluator[B] {
type A = A0
val a = a0
val fun = f
}
}
}
5 changes: 5 additions & 0 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,11 @@ class StateTTests extends CatsSuite {
x.runS(0).value should === (100001)
}

test("10000 maps is stack-safe"){
val x = (0 until 10000).foldLeft(StateT.pure[Id, Int, Int](0))((s, i) => s.map(_ + 1))
x.runA(0) should === (10000)
}

test("State.pure and StateT.pure are consistent"){
forAll { (s: String, i: Int) =>
val state: State[String, Int] = State.pure(i)
Expand Down

0 comments on commit 5f125e6

Please sign in to comment.