Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some MonadRec instances. #1041

Merged
merged 2 commits into from
May 31, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions core/src/main/scala/cats/FlatMapRec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package cats

import simulacrum.typeclass

import cats.data.Xor

/**
* Version of [[cats.FlatMap]] capable of stack-safe recursive `flatMap`s.
*
* Based on Phil Freeman's
* [[http://functorial.com/stack-safety-for-free/index.pdf Stack Safety for Free]].
*/
@typeclass trait FlatMapRec[F[_]] extends FlatMap[F] {

/**
* Keeps calling `f` until a `[[cats.data.Xor.Right Right]][B]` is returned.
*
* Implementations of this method must use constant stack space.
*
* `f` must use constant stack space. (It is OK to use a constant number of
* `map`s and `flatMap`s inside `f`.)
*/
def tailRecM[A, B](a: A)(f: A => F[A Xor B]): F[B]
}
5 changes: 5 additions & 0 deletions core/src/main/scala/cats/MonadRec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
package cats

import simulacrum.typeclass

@typeclass trait MonadRec[F[_]] extends Monad[F] with FlatMapRec[F]
42 changes: 31 additions & 11 deletions core/src/main/scala/cats/data/OptionT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ object OptionT extends OptionTInstances {
def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): OptionT[F, A] = OptionT(F.map(fa)(Some(_)))
}

private[data] sealed trait OptionTInstances1 {
private[data] sealed trait OptionTInstances2 {
implicit def catsDataFunctorForOptionT[F[_]:Functor]: Functor[OptionT[F, ?]] =
new Functor[OptionT[F, ?]] {
override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
Expand All @@ -148,22 +148,42 @@ private[data] sealed trait OptionTInstances1 {
}
}

private[data] sealed trait OptionTInstances extends OptionTInstances1 {

implicit def catsDataMonadForOptionT[F[_]](implicit F: Monad[F]): Monad[OptionT[F, ?]] =
new Monad[OptionT[F, ?]] {
def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)
private[data] sealed trait OptionTInstances1 extends OptionTInstances2 {

def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
fa.flatMap(f)
implicit def catsDataMonadForOptionT[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] =
new OptionTMonad[F] { implicit val F = F0 }
}

override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
fa.map(f)
}
private[data] sealed trait OptionTInstances extends OptionTInstances1 {
implicit def catsDataMonadRecForOptionT[F[_]](implicit F0: MonadRec[F]): MonadRec[OptionT[F, ?]] =
new OptionTMonadRec[F] { implicit val F = F0 }

implicit def catsDataEqForOptionT[F[_], A](implicit FA: Eq[F[Option[A]]]): Eq[OptionT[F, A]] =
FA.on(_.value)

implicit def catsDataShowForOptionT[F[_], A](implicit F: Show[F[Option[A]]]): Show[OptionT[F, A]] =
functor.Contravariant[Show].contramap(F)(_.value)
}

private[data] trait OptionTMonad[F[_]] extends Monad[OptionT[F, ?]] {
implicit val F: Monad[F]

def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)

def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
fa.flatMap(f)

override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
fa.map(f)
}

private[data] trait OptionTMonadRec[F[_]] extends MonadRec[OptionT[F, ?]] with OptionTMonad[F] {
implicit val F: MonadRec[F]

def tailRecM[A, B](a: A)(f: A => OptionT[F, A Xor B]): OptionT[F, B] =
OptionT(F.tailRecM(a)(a0 => F.map(f(a0).value){
case None => Xor.Right(None)
case Some(Xor.Left(a1)) => Xor.Left(a1)
case Some(Xor.Right(b)) => Xor.Right(Some(b))
}))
}
11 changes: 9 additions & 2 deletions core/src/main/scala/cats/data/Xor.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package cats
package data

import scala.annotation.tailrec
import scala.reflect.ClassTag
import scala.util.{Failure, Success, Try}

Expand Down Expand Up @@ -233,13 +234,19 @@ private[data] sealed abstract class XorInstances extends XorInstances1 {
}
}

implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadError[Xor[A, ?], A] =
new Traverse[A Xor ?] with MonadError[Xor[A, ?], A] {
implicit def catsDataInstancesForXor[A]: Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] =
new Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] {
def traverse[F[_]: Applicative, B, C](fa: A Xor B)(f: B => F[C]): F[A Xor C] = fa.traverse(f)
def foldLeft[B, C](fa: A Xor B, c: C)(f: (C, B) => C): C = fa.foldLeft(c)(f)
def foldRight[B, C](fa: A Xor B, lc: Eval[C])(f: (B, Eval[C]) => Eval[C]): Eval[C] = fa.foldRight(lc)(f)
def flatMap[B, C](fa: A Xor B)(f: B => A Xor C): A Xor C = fa.flatMap(f)
def pure[B](b: B): A Xor B = Xor.right(b)
@tailrec def tailRecM[B, C](b: B)(f: B => A Xor (B Xor C)): A Xor C =
f(b) match {
case Xor.Left(a) => Xor.Left(a)
case Xor.Right(Xor.Left(b1)) => tailRecM(b1)(f)
case Xor.Right(Xor.Right(c)) => Xor.Right(c)
}
def handleErrorWith[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] =
fea match {
case Xor.Left(e) => f(e)
Expand Down
22 changes: 20 additions & 2 deletions core/src/main/scala/cats/data/XorT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,11 @@ private[data] abstract class XorTInstances1 extends XorTInstances2 {
}

private[data] abstract class XorTInstances2 extends XorTInstances3 {
implicit def catsDataMonadRecForXorT[F[_], L](implicit F0: MonadRec[F]): MonadRec[XorT[F, L, ?]] =
new XorTMonadRec[F, L] { implicit val F = F0 }
}

private[data] abstract class XorTInstances3 extends XorTInstances4 {
implicit def catsDataMonadErrorForXorT[F[_], L](implicit F: Monad[F]): MonadError[XorT[F, L, ?], L] = {
implicit val F0 = F
new XorTMonadError[F, L] { implicit val F = F0 }
Expand All @@ -299,7 +304,7 @@ private[data] abstract class XorTInstances2 extends XorTInstances3 {
}
}

private[data] abstract class XorTInstances3 {
private[data] abstract class XorTInstances4 {
implicit def catsDataFunctorForXorT[F[_], L](implicit F: Functor[F]): Functor[XorT[F, L, ?]] = {
implicit val F0 = F
new XorTFunctor[F, L] { implicit val F = F0 }
Expand All @@ -311,10 +316,13 @@ private[data] trait XorTFunctor[F[_], L] extends Functor[XorT[F, L, ?]] {
override def map[A, B](fa: XorT[F, L, A])(f: A => B): XorT[F, L, B] = fa map f
}

private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTFunctor[F, L] {
private[data] trait XorTMonad[F[_], L] extends Monad[XorT[F, L, ?]] with XorTFunctor[F, L] {
implicit val F: Monad[F]
def pure[A](a: A): XorT[F, L, A] = XorT.pure[F, L, A](a)
def flatMap[A, B](fa: XorT[F, L, A])(f: A => XorT[F, L, B]): XorT[F, L, B] = fa flatMap f
}

private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTMonad[F, L] {
def handleErrorWith[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] =
XorT(F.flatMap(fea.value) {
case Xor.Left(e) => f(e).value
Expand All @@ -333,6 +341,16 @@ private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L]
fla.recoverWith(pf)
}

private[data] trait XorTMonadRec[F[_], L] extends MonadRec[XorT[F, L, ?]] with XorTMonad[F, L] {
implicit val F: MonadRec[F]
def tailRecM[A, B](a: A)(f: A => XorT[F, L, A Xor B]): XorT[F, L, B] =
XorT(F.tailRecM(a)(a0 => F.map(f(a0).value){
case Xor.Left(l) => Xor.Right(Xor.Left(l))
case Xor.Right(Xor.Left(a1)) => Xor.Left(a1)
case Xor.Right(Xor.Right(b)) => Xor.Right(Xor.Right(b))
}))
}

private[data] trait XorTMonadFilter[F[_], L] extends MonadFilter[XorT[F, L, ?]] with XorTMonadError[F, L] {
implicit val F: Monad[F]
implicit val L: Monoid[L]
Expand Down
11 changes: 9 additions & 2 deletions core/src/main/scala/cats/package.scala
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import scala.annotation.tailrec
import cats.data.Xor

/**
* Symbolic aliases for various types are defined here.
*/
Expand Down Expand Up @@ -26,12 +29,16 @@ package object cats {
* encodes pure unary function application.
*/
type Id[A] = A
implicit val idInstances: Bimonad[Id] with Traverse[Id] =
new Bimonad[Id] with Traverse[Id] {
implicit val idInstances: Bimonad[Id] with MonadRec[Id] with Traverse[Id] =
new Bimonad[Id] with MonadRec[Id] with Traverse[Id] {
def pure[A](a: A): A = a
def extract[A](a: A): A = a
def flatMap[A, B](a: A)(f: A => B): B = f(a)
def coflatMap[A, B](a: A)(f: A => B): B = f(a)
@tailrec def tailRecM[A, B](a: A)(f: A => A Xor B): B = f(a) match {
case Xor.Left(a1) => tailRecM(a1)(f)
case Xor.Right(b) => b
}
override def map[A, B](fa: A)(f: A => B): B = f(fa)
override def ap[A, B](ff: A => B)(fa: A): B = ff(fa)
override def flatten[A](ffa: A): A = ffa
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/std/either.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
package cats
package std

import scala.annotation.tailrec
import cats.data.Xor

trait EitherInstances extends EitherInstances1 {
implicit val catsStdBitraverseForEither: Bitraverse[Either] =
new Bitraverse[Either] {
Expand All @@ -23,8 +26,8 @@ trait EitherInstances extends EitherInstances1 {
}
}

implicit def catsStdInstancesForEither[A]: Monad[Either[A, ?]] with Traverse[Either[A, ?]] =
new Monad[Either[A, ?]] with Traverse[Either[A, ?]] {
implicit def catsStdInstancesForEither[A]: MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] =
new MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] {
def pure[B](b: B): Either[A, B] = Right(b)

def flatMap[B, C](fa: Either[A, B])(f: B => Either[A, C]): Either[A, C] =
Expand All @@ -33,6 +36,14 @@ trait EitherInstances extends EitherInstances1 {
override def map[B, C](fa: Either[A, B])(f: B => C): Either[A, C] =
fa.right.map(f)

@tailrec
def tailRecM[B, C](b: B)(f: B => Either[A, B Xor C]): Either[A, C] =
f(b) match {
case Left(a) => Left(a)
case Right(Xor.Left(b1)) => tailRecM(b1)(f)
case Right(Xor.Right(c)) => Right(c)
}

override def map2Eval[B, C, Z](fb: Either[A, B], fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] =
fb match {
// This should be safe, but we are forced to use `asInstanceOf`,
Expand Down
20 changes: 18 additions & 2 deletions core/src/main/scala/cats/std/list.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import cats.syntax.show._
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

import cats.data.Xor

trait ListInstances extends cats.kernel.std.ListInstances {

implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with CoflatMap[List] =
new Traverse[List] with MonadCombine[List] with CoflatMap[List] {
implicit val catsStdInstancesForList: Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] =
new Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] {

def empty[A]: List[A] = Nil

Expand All @@ -26,6 +28,20 @@ trait ListInstances extends cats.kernel.std.ListInstances {
override def map2[A, B, Z](fa: List[A], fb: List[B])(f: (A, B) => Z): List[Z] =
fa.flatMap(a => fb.map(b => f(a, b)))

def tailRecM[A, B](a: A)(f: A => List[A Xor B]): List[B] = {
val buf = List.newBuilder[B]
@tailrec def go(lists: List[List[A Xor B]]): Unit = lists match {
case (ab :: abs) :: tail => ab match {
case Xor.Right(b) => buf += b; go(abs :: tail)
case Xor.Left(a) => go(f(a) :: abs :: tail)
}
case Nil :: tail => go(tail)
case Nil => ()
}
go(f(a) :: Nil)
buf.result
}

def coflatMap[A, B](fa: List[A])(f: List[A] => B): List[B] = {
@tailrec def loop(buf: ListBuffer[B], as: List[A]): List[B] =
as match {
Expand Down
15 changes: 13 additions & 2 deletions core/src/main/scala/cats/std/option.scala
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
package cats
package std

import scala.annotation.tailrec
import cats.data.Xor

trait OptionInstances extends cats.kernel.std.OptionInstances {

implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] =
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] {
implicit val catsStdInstancesForOption: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] =
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] {

def empty[A]: Option[A] = None

Expand All @@ -18,6 +21,14 @@ trait OptionInstances extends cats.kernel.std.OptionInstances {
def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
fa.flatMap(f)

@tailrec
def tailRecM[A, B](a: A)(f: A => Option[A Xor B]): Option[B] =
f(a) match {
case None => None
case Some(Xor.Left(a1)) => tailRecM(a1)(f)
case Some(Xor.Right(b)) => Some(b)
}

override def map2[A, B, Z](fa: Option[A], fb: Option[B])(f: (A, B) => Z): Option[Z] =
fa.flatMap(a => fb.map(b => f(a, b)))

Expand Down
9 changes: 7 additions & 2 deletions free/src/main/scala/cats/free/Free.scala
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,16 @@ object Free {
/**
* `Free[S, ?]` has a monad for any type constructor `S[_]`.
*/
implicit def freeMonad[S[_]]: Monad[Free[S, ?]] =
new Monad[Free[S, ?]] {
implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] =
new MonadRec[Free[S, ?]] {
def pure[A](a: A): Free[S, A] = Free.pure(a)
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] =
f(a).flatMap(_ match {
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy
case Xor.Right(b) => pure(b)
})
}
}

Expand Down
14 changes: 11 additions & 3 deletions free/src/test/scala/cats/free/FreeTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@ package free

import cats.tests.CatsSuite
import cats.arrow.NaturalTransformation
import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests}
import cats.data.Xor
import cats.laws.discipline.{CartesianTests, MonadRecTests, SerializableTests}
import cats.laws.discipline.arbitrary.function0Arbitrary

import org.scalacheck.{Arbitrary, Gen}
Expand All @@ -14,8 +15,8 @@ class FreeTests extends CatsSuite {

implicit val iso = CartesianTests.Isomorphisms.invariant[Free[Option, ?]]

checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int])
checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]]))
checkAll("Free[Option, ?]", MonadRecTests[Free[Option, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[Free[Option, ?]]", SerializableTests.serializable(MonadRec[Free[Option, ?]]))

test("mapSuspension id"){
forAll { x: Free[List, Int] =>
Expand Down Expand Up @@ -43,6 +44,13 @@ class FreeTests extends CatsSuite {
}
}

test("tailRecM is stack safe") {
val n = 50000
val fa = MonadRec[Free[Option, ?]].tailRecM(0)(i =>
Free.pure[Option, Int Xor Int](if(i < n) Xor.Left(i+1) else Xor.Right(i)))
fa should === (Free.pure[Option, Int](n))
}

ignore("foldMap is stack safe") {
trait FTestApi[A]
case class TB(i: Int) extends FTestApi[Int]
Expand Down
26 changes: 26 additions & 0 deletions laws/src/main/scala/cats/laws/FlatMapRecLaws.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
package cats
package laws

import cats.data.Xor
import cats.syntax.flatMap._
import cats.syntax.functor._

/**
* Laws that must be obeyed by any `FlatMapRec`.
*/
trait FlatMapRecLaws[F[_]] extends FlatMapLaws[F] {
implicit override def F: FlatMapRec[F]

def tailRecMConsistentFlatMap[A](a: A, f: A => F[A]): IsEq[F[A]] = {
val bounce = F.tailRecM[(A, Int), A]((a, 1)) { case (a0, i) =>
if(i > 0) f(a0).map(a1 => Xor.left((a1, i-1)))
else f(a0).map(Xor.right)
}
bounce <-> f(a).flatMap(f)
}
}

object FlatMapRecLaws {
def apply[F[_]](implicit ev: FlatMapRec[F]): FlatMapRecLaws[F] =
new FlatMapRecLaws[F] { def F: FlatMapRec[F] = ev }
}
Loading