Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

MonadRec instances for Eval and StateT. #1076

Merged
merged 1 commit into from
May 31, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
MonadRec instances for Eval and StateT.
  • Loading branch information
TomasMikula committed May 31, 2016
commit de5d911e02f2fb70d07da05ce1623b4d8b9b8d3e
10 changes: 8 additions & 2 deletions core/src/main/scala/cats/Eval.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cats

import scala.annotation.tailrec
import cats.data.Xor
import cats.syntax.all._

/**
Expand Down Expand Up @@ -294,14 +295,19 @@ object Eval extends EvalInstances {

private[cats] trait EvalInstances extends EvalInstances0 {

implicit val evalBimonad: Bimonad[Eval] =
new Bimonad[Eval] {
implicit val evalBimonad: Bimonad[Eval] with MonadRec[Eval] =
new Bimonad[Eval] with MonadRec[Eval] {
override def map[A, B](fa: Eval[A])(f: A => B): Eval[B] = fa.map(f)
def pure[A](a: A): Eval[A] = Now(a)
override def pureEval[A](la: Eval[A]): Eval[A] = la
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
def extract[A](la: Eval[A]): A = la.value
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
def tailRecM[A, B](a: A)(f: A => Eval[A Xor B]): Eval[B] =
f(a).flatMap(_ match {
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since flatMap is lazy
case Xor.Right(b) => Eval.now(b)
})
}

implicit def evalOrder[A: Order]: Order[Eval[A]] =
Expand Down
52 changes: 36 additions & 16 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,22 +135,9 @@ object StateT extends StateTInstances {
StateT(s => F.pure((s, a)))
}

private[data] sealed abstract class StateTInstances {
implicit def catsDataMonadStateForStateT[F[_], S](implicit F: Monad[F]): MonadState[StateT[F, S, ?], S] =
new MonadState[StateT[F, S, ?], S] {
def pure[A](a: A): StateT[F, S, A] =
StateT.pure(a)

def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

val get: StateT[F, S, S] = StateT(a => F.pure((a, a)))

def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ())))

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] =
fa.map(f)
}
private[data] sealed abstract class StateTInstances extends StateTInstances1 {
implicit def catsDataMonadStateForStateT[F[_], S](implicit F0: Monad[F]): MonadState[StateT[F, S, ?], S] =
new StateTMonadState[F, S] { implicit def F = F0 }

implicit def catsDataLiftForStateT[S]: TransLift.Aux[StateT[?[_], S, ?], Applicative] =
new TransLift[StateT[?[_], S, ?]] {
Expand All @@ -161,6 +148,11 @@ private[data] sealed abstract class StateTInstances {

}

private[data] sealed abstract class StateTInstances1 {
implicit def catsDataMonadRecForStateT[F[_], S](implicit F0: MonadRec[F]): MonadRec[StateT[F, S, ?]] =
new StateTMonadRec[F, S] { implicit def F = F0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {
Expand Down Expand Up @@ -193,3 +185,31 @@ private[data] abstract class StateFunctions {
*/
def set[S](s: S): State[S, Unit] = State(_ => (s, ()))
}

private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
implicit def F: Monad[F]

def pure[A](a: A): StateT[F, S, A] =
StateT.pure(a)

def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] =
fa.map(f)
}

private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F, S, ?], S] with StateTMonad[F, S] {
val get: StateT[F, S, S] = StateT(s => F.pure((s, s)))

def set(s: S): StateT[F, S, Unit] = StateT(_ => F.pure((s, ())))
}

private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S, ?]] with StateTMonad[F, S] {
override implicit def F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => StateT[F, S, A Xor B]): StateT[F, S, B] =
StateT[F, S, B](s => F.tailRecM[(S, A), (S, B)]((s, a)) {
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}
4 changes: 3 additions & 1 deletion tests/src/test/scala/cats/tests/EvalTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ package tests

import scala.math.min
import cats.laws.ComonadLaws
import cats.laws.discipline.{CartesianTests, BimonadTests, SerializableTests}
import cats.laws.discipline.{BimonadTests, CartesianTests, MonadRecTests, SerializableTests}
import cats.laws.discipline.arbitrary._
import cats.kernel.laws.{GroupLaws, OrderLaws}

Expand Down Expand Up @@ -93,8 +93,10 @@ class EvalTests extends CatsSuite {
{
implicit val iso = CartesianTests.Isomorphisms.invariant[Eval]
checkAll("Eval[Int]", BimonadTests[Eval].bimonad[Int, Int, Int])
checkAll("Eval[Int]", MonadRecTests[Eval].monadRec[Int, Int, Int])
}
checkAll("Bimonad[Eval]", SerializableTests.serializable(Bimonad[Eval]))
checkAll("MonadRec[Eval]", SerializableTests.serializable(MonadRec[Eval]))

checkAll("Eval[Int]", GroupLaws[Eval[Int]].group)

Expand Down
11 changes: 10 additions & 1 deletion tests/src/test/scala/cats/tests/MonadRecInstancesTests.scala
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package cats
package tests

import cats.data.{OptionT, Xor, XorT}
import cats.data.{OptionT, StateT, Xor, XorT}

class MonadRecInstancesTests extends CatsSuite {
def tailRecMStackSafety[M[_]](implicit M: MonadRec[M], Eq: Eq[M[Int]]): Unit = {
Expand Down Expand Up @@ -38,4 +38,13 @@ class MonadRecInstancesTests extends CatsSuite {
tailRecMStackSafety[List]
}

test("tailRecM stack-safety for Eval") {
tailRecMStackSafety[Eval]
}

test("tailRecM stack-safety for StateT") {
import StateTTests._ // import implicit Eq[StateT[...]]
tailRecMStackSafety[StateT[Option, Int, ?]]
}

}
12 changes: 10 additions & 2 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ package cats
package tests

import cats.kernel.std.tuple._
import cats.laws.discipline.{CartesianTests, MonadStateTests, SerializableTests}
import cats.laws.discipline.{CartesianTests, MonadRecTests, MonadStateTests, SerializableTests}
import cats.data.{State, StateT}
import cats.laws.discipline.eq._
import cats.laws.discipline.arbitrary._
Expand Down Expand Up @@ -116,14 +116,22 @@ class StateTTests extends CatsSuite {

{
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]]

checkAll("StateT[Option, Int, Int]", MonadStateTests[StateT[Option, Int, ?], Int].monadState[Int, Int, Int])
checkAll("MonadState[StateT[Option, ?, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int]))
checkAll("MonadState[StateT[Option, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int]))

checkAll("StateT[Option, Int, Int]", MonadRecTests[StateT[Option, Int, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[StateT[Option, Int, ?]]", SerializableTests.serializable(MonadRec[StateT[Option, Int, ?]]))
}

{
implicit val iso = CartesianTests.Isomorphisms.invariant[State[Long, ?]]

checkAll("State[Long, ?]", MonadStateTests[State[Long, ?], Long].monadState[Int, Int, Int])
checkAll("MonadState[State[Long, ?], Long]", SerializableTests.serializable(MonadState[State[Long, ?], Long]))

checkAll("State[Long, ?]", MonadRecTests[State[Long, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[State[Long, ?]]", SerializableTests.serializable(MonadRec[State[Long, ?]]))
}
}

Expand Down