Skip to content

Commit 2617222

Browse files
committed
MonadRec instances for Id, Option, OptionT, Either, Xor, XorT, Free, List.
1 parent 8a87d55 commit 2617222

File tree

20 files changed

+304
-37
lines changed

20 files changed

+304
-37
lines changed

core/src/main/scala/cats/data/OptionT.scala

Lines changed: 31 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -132,7 +132,7 @@ object OptionT extends OptionTInstances {
132132
def liftF[F[_], A](fa: F[A])(implicit F: Functor[F]): OptionT[F, A] = OptionT(F.map(fa)(Some(_)))
133133
}
134134

135-
private[data] sealed trait OptionTInstances1 {
135+
private[data] sealed trait OptionTInstances2 {
136136
implicit def optionTFunctor[F[_]:Functor]: Functor[OptionT[F, ?]] =
137137
new Functor[OptionT[F, ?]] {
138138
override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
@@ -148,22 +148,41 @@ private[data] sealed trait OptionTInstances1 {
148148
}
149149
}
150150

151-
private[data] sealed trait OptionTInstances extends OptionTInstances1 {
152-
153-
implicit def optionTMonad[F[_]](implicit F: Monad[F]): Monad[OptionT[F, ?]] =
154-
new Monad[OptionT[F, ?]] {
155-
def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)
156-
157-
def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
158-
fa.flatMap(f)
151+
private[data] sealed trait OptionTInstances1 extends OptionTInstances2 {
152+
implicit def optionTMonad[F[_]](implicit F0: Monad[F]): Monad[OptionT[F, ?]] =
153+
new OptionTMonad[F] { implicit val F = F0 }
154+
}
159155

160-
override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
161-
fa.map(f)
162-
}
156+
private[data] sealed trait OptionTInstances extends OptionTInstances1 {
157+
implicit def optionTMonadRec[F[_]](implicit F0: MonadRec[F]): MonadRec[OptionT[F, ?]] =
158+
new OptionTMonadRec[F] { implicit val F = F0 }
163159

164160
implicit def optionTEq[F[_], A](implicit FA: Eq[F[Option[A]]]): Eq[OptionT[F, A]] =
165161
FA.on(_.value)
166162

167163
implicit def optionTShow[F[_], A](implicit F: Show[F[Option[A]]]): Show[OptionT[F, A]] =
168164
functor.Contravariant[Show].contramap(F)(_.value)
169165
}
166+
167+
private[data] trait OptionTMonad[F[_]] extends Monad[OptionT[F, ?]] {
168+
implicit val F: Monad[F]
169+
170+
def pure[A](a: A): OptionT[F, A] = OptionT.pure(a)
171+
172+
def flatMap[A, B](fa: OptionT[F, A])(f: A => OptionT[F, B]): OptionT[F, B] =
173+
fa.flatMap(f)
174+
175+
override def map[A, B](fa: OptionT[F, A])(f: A => B): OptionT[F, B] =
176+
fa.map(f)
177+
}
178+
179+
private[data] trait OptionTMonadRec[F[_]] extends MonadRec[OptionT[F, ?]] with OptionTMonad[F] {
180+
implicit val F: MonadRec[F]
181+
182+
def tailRecM[A, B](a: A)(f: A => OptionT[F, A Xor B]): OptionT[F, B] =
183+
OptionT(F.tailRecM(a)(a0 => F.map(f(a0).value){
184+
case None => Xor.Right(None)
185+
case Some(Xor.Left(a1)) => Xor.Left(a1)
186+
case Some(Xor.Right(b)) => Xor.Right(Some(b))
187+
}))
188+
}

core/src/main/scala/cats/data/Xor.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package cats
22
package data
33

4+
import scala.annotation.tailrec
45
import scala.reflect.ClassTag
56
import scala.util.{Failure, Success, Try}
67

@@ -203,13 +204,19 @@ private[data] sealed abstract class XorInstances extends XorInstances1 {
203204
}
204205
}
205206

206-
implicit def xorInstances[A]: Traverse[A Xor ?] with MonadError[Xor[A, ?], A] =
207-
new Traverse[A Xor ?] with MonadError[Xor[A, ?], A] {
207+
implicit def xorInstances[A]: Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] =
208+
new Traverse[A Xor ?] with MonadRec[A Xor ?] with MonadError[Xor[A, ?], A] {
208209
def traverse[F[_]: Applicative, B, C](fa: A Xor B)(f: B => F[C]): F[A Xor C] = fa.traverse(f)
209210
def foldLeft[B, C](fa: A Xor B, c: C)(f: (C, B) => C): C = fa.foldLeft(c)(f)
210211
def foldRight[B, C](fa: A Xor B, lc: Eval[C])(f: (B, Eval[C]) => Eval[C]): Eval[C] = fa.foldRight(lc)(f)
211212
def flatMap[B, C](fa: A Xor B)(f: B => A Xor C): A Xor C = fa.flatMap(f)
212213
def pure[B](b: B): A Xor B = Xor.right(b)
214+
@tailrec def tailRecM[B, C](b: B)(f: B => A Xor (B Xor C)): A Xor C =
215+
f(b) match {
216+
case Xor.Left(a) => Xor.Left(a)
217+
case Xor.Right(Xor.Left(b1)) => tailRecM(b1)(f)
218+
case Xor.Right(Xor.Right(c)) => Xor.Right(c)
219+
}
213220
def handleErrorWith[B](fea: Xor[A, B])(f: A => Xor[A, B]): Xor[A, B] =
214221
fea match {
215222
case Xor.Left(e) => f(e)

core/src/main/scala/cats/data/XorT.scala

Lines changed: 20 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,6 +241,11 @@ private[data] abstract class XorTInstances1 extends XorTInstances2 {
241241
}
242242

243243
private[data] abstract class XorTInstances2 extends XorTInstances3 {
244+
implicit def xorTMonadRec[F[_], L](implicit F0: MonadRec[F]): MonadRec[XorT[F, L, ?]] =
245+
new XorTMonadRec[F, L] { implicit val F = F0 }
246+
}
247+
248+
private[data] abstract class XorTInstances3 extends XorTInstances4 {
244249
implicit def xorTMonadError[F[_], L](implicit F: Monad[F]): MonadError[XorT[F, L, ?], L] = {
245250
implicit val F0 = F
246251
new XorTMonadError[F, L] { implicit val F = F0 }
@@ -261,7 +266,7 @@ private[data] abstract class XorTInstances2 extends XorTInstances3 {
261266
}
262267
}
263268

264-
private[data] abstract class XorTInstances3 {
269+
private[data] abstract class XorTInstances4 {
265270
implicit def xorTFunctor[F[_], L](implicit F: Functor[F]): Functor[XorT[F, L, ?]] = {
266271
implicit val F0 = F
267272
new XorTFunctor[F, L] { implicit val F = F0 }
@@ -273,10 +278,13 @@ private[data] trait XorTFunctor[F[_], L] extends Functor[XorT[F, L, ?]] {
273278
override def map[A, B](fa: XorT[F, L, A])(f: A => B): XorT[F, L, B] = fa map f
274279
}
275280

276-
private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTFunctor[F, L] {
281+
private[data] trait XorTMonad[F[_], L] extends Monad[XorT[F, L, ?]] with XorTFunctor[F, L] {
277282
implicit val F: Monad[F]
278283
def pure[A](a: A): XorT[F, L, A] = XorT.pure[F, L, A](a)
279284
def flatMap[A, B](fa: XorT[F, L, A])(f: A => XorT[F, L, B]): XorT[F, L, B] = fa flatMap f
285+
}
286+
287+
private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L] with XorTMonad[F, L] {
280288
def handleErrorWith[A](fea: XorT[F, L, A])(f: L => XorT[F, L, A]): XorT[F, L, A] =
281289
XorT(F.flatMap(fea.value) {
282290
case Xor.Left(e) => f(e).value
@@ -295,6 +303,16 @@ private[data] trait XorTMonadError[F[_], L] extends MonadError[XorT[F, L, ?], L]
295303
fla.recoverWith(pf)
296304
}
297305

306+
private[data] trait XorTMonadRec[F[_], L] extends MonadRec[XorT[F, L, ?]] with XorTMonad[F, L] {
307+
implicit val F: MonadRec[F]
308+
def tailRecM[A, B](a: A)(f: A => XorT[F, L, A Xor B]): XorT[F, L, B] =
309+
XorT(F.tailRecM(a)(a0 => F.map(f(a0).value){
310+
case Xor.Left(l) => Xor.Right(Xor.Left(l))
311+
case Xor.Right(Xor.Left(a1)) => Xor.Left(a1)
312+
case Xor.Right(Xor.Right(b)) => Xor.Right(Xor.Right(b))
313+
}))
314+
}
315+
298316
private[data] trait XorTMonadFilter[F[_], L] extends MonadFilter[XorT[F, L, ?]] with XorTMonadError[F, L] {
299317
implicit val F: Monad[F]
300318
implicit val L: Monoid[L]

core/src/main/scala/cats/package.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
import scala.annotation.tailrec
2+
import cats.data.Xor
3+
14
/**
25
* Symbolic aliases for various types are defined here.
36
*/
@@ -26,12 +29,16 @@ package object cats {
2629
* encodes pure unary function application.
2730
*/
2831
type Id[A] = A
29-
implicit val idInstances: Bimonad[Id] with Traverse[Id] =
30-
new Bimonad[Id] with Traverse[Id] {
32+
implicit val idInstances: Bimonad[Id] with MonadRec[Id] with Traverse[Id] =
33+
new Bimonad[Id] with MonadRec[Id] with Traverse[Id] {
3134
def pure[A](a: A): A = a
3235
def extract[A](a: A): A = a
3336
def flatMap[A, B](a: A)(f: A => B): B = f(a)
3437
def coflatMap[A, B](a: A)(f: A => B): B = f(a)
38+
@tailrec def tailRecM[A, B](a: A)(f: A => A Xor B): B = f(a) match {
39+
case Xor.Left(a1) => tailRecM(a1)(f)
40+
case Xor.Right(b) => b
41+
}
3542
override def map[A, B](fa: A)(f: A => B): B = f(fa)
3643
override def ap[A, B](ff: A => B)(fa: A): B = ff(fa)
3744
override def flatten[A](ffa: A): A = ffa

core/src/main/scala/cats/std/either.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
11
package cats
22
package std
33

4+
import scala.annotation.tailrec
5+
import cats.data.Xor
6+
47
trait EitherInstances extends EitherInstances1 {
58
implicit val eitherBitraverse: Bitraverse[Either] =
69
new Bitraverse[Either] {
@@ -23,8 +26,8 @@ trait EitherInstances extends EitherInstances1 {
2326
}
2427
}
2528

26-
implicit def eitherInstances[A]: Monad[Either[A, ?]] with Traverse[Either[A, ?]] =
27-
new Monad[Either[A, ?]] with Traverse[Either[A, ?]] {
29+
implicit def eitherInstances[A]: MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] =
30+
new MonadRec[Either[A, ?]] with Traverse[Either[A, ?]] {
2831
def pure[B](b: B): Either[A, B] = Right(b)
2932

3033
def flatMap[B, C](fa: Either[A, B])(f: B => Either[A, C]): Either[A, C] =
@@ -33,6 +36,14 @@ trait EitherInstances extends EitherInstances1 {
3336
override def map[B, C](fa: Either[A, B])(f: B => C): Either[A, C] =
3437
fa.right.map(f)
3538

39+
@tailrec
40+
def tailRecM[B, C](b: B)(f: B => Either[A, B Xor C]): Either[A, C] =
41+
f(b) match {
42+
case Left(a) => Left(a)
43+
case Right(Xor.Left(b1)) => tailRecM(b1)(f)
44+
case Right(Xor.Right(c)) => Right(c)
45+
}
46+
3647
override def map2Eval[B, C, Z](fb: Either[A, B], fc: Eval[Either[A, C]])(f: (B, C) => Z): Eval[Either[A, Z]] =
3748
fb match {
3849
// This should be safe, but we are forced to use `asInstanceOf`,

core/src/main/scala/cats/std/list.scala

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import cats.syntax.show._
66
import scala.annotation.tailrec
77
import scala.collection.mutable.ListBuffer
88

9+
import cats.data.Xor
10+
911
trait ListInstances extends cats.kernel.std.ListInstances {
1012

11-
implicit val listInstance: Traverse[List] with MonadCombine[List] with CoflatMap[List] =
12-
new Traverse[List] with MonadCombine[List] with CoflatMap[List] {
13+
implicit val listInstance: Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] =
14+
new Traverse[List] with MonadCombine[List] with MonadRec[List] with CoflatMap[List] {
1315

1416
def empty[A]: List[A] = Nil
1517

@@ -26,6 +28,20 @@ trait ListInstances extends cats.kernel.std.ListInstances {
2628
override def map2[A, B, Z](fa: List[A], fb: List[B])(f: (A, B) => Z): List[Z] =
2729
fa.flatMap(a => fb.map(b => f(a, b)))
2830

31+
def tailRecM[A, B](a: A)(f: A => List[A Xor B]): List[B] = {
32+
val buf = List.newBuilder[B]
33+
@tailrec def go(lists: List[List[A Xor B]]): Unit = lists match {
34+
case (ab :: abs) :: tail => ab match {
35+
case Xor.Right(b) => buf += b; go(abs :: tail)
36+
case Xor.Left(a) => go(f(a) :: abs :: tail)
37+
}
38+
case Nil :: tail => go(tail)
39+
case Nil => ()
40+
}
41+
go(f(a) :: Nil)
42+
buf.result
43+
}
44+
2945
def coflatMap[A, B](fa: List[A])(f: List[A] => B): List[B] = {
3046
@tailrec def loop(buf: ListBuffer[B], as: List[A]): List[B] =
3147
as match {

core/src/main/scala/cats/std/option.scala

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
package cats
22
package std
33

4+
import scala.annotation.tailrec
5+
import cats.data.Xor
6+
47
trait OptionInstances extends cats.kernel.std.OptionInstances {
58

6-
implicit val optionInstance: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] =
7-
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with CoflatMap[Option] with Alternative[Option] {
9+
implicit val optionInstance: Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] =
10+
new Traverse[Option] with MonadError[Option, Unit] with MonadCombine[Option] with MonadRec[Option] with CoflatMap[Option] with Alternative[Option] {
811

912
def empty[A]: Option[A] = None
1013

@@ -18,6 +21,14 @@ trait OptionInstances extends cats.kernel.std.OptionInstances {
1821
def flatMap[A, B](fa: Option[A])(f: A => Option[B]): Option[B] =
1922
fa.flatMap(f)
2023

24+
@tailrec
25+
def tailRecM[A, B](a: A)(f: A => Option[A Xor B]): Option[B] =
26+
f(a) match {
27+
case None => None
28+
case Some(Xor.Left(a1)) => tailRecM(a1)(f)
29+
case Some(Xor.Right(b)) => Some(b)
30+
}
31+
2132
override def map2[A, B, Z](fa: Option[A], fb: Option[B])(f: (A, B) => Z): Option[Z] =
2233
fa.flatMap(a => fb.map(b => f(a, b)))
2334

free/src/main/scala/cats/free/Free.scala

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,11 +40,16 @@ object Free {
4040
/**
4141
* `Free[S, ?]` has a monad for any type constructor `S[_]`.
4242
*/
43-
implicit def freeMonad[S[_]]: Monad[Free[S, ?]] =
44-
new Monad[Free[S, ?]] {
43+
implicit def freeMonad[S[_]]: MonadRec[Free[S, ?]] =
44+
new MonadRec[Free[S, ?]] {
4545
def pure[A](a: A): Free[S, A] = Free.pure(a)
4646
override def map[A, B](fa: Free[S, A])(f: A => B): Free[S, B] = fa.map(f)
4747
def flatMap[A, B](a: Free[S, A])(f: A => Free[S, B]): Free[S, B] = a.flatMap(f)
48+
def tailRecM[A, B](a: A)(f: A => Free[S, A Xor B]): Free[S, B] =
49+
f(a).flatMap(_ match {
50+
case Xor.Left(a1) => tailRecM(a1)(f) // recursion OK here, since Free is lazy
51+
case Xor.Right(b) => pure(b)
52+
})
4853
}
4954
}
5055

free/src/test/scala/cats/free/FreeTests.scala

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@ package free
33

44
import cats.tests.CatsSuite
55
import cats.arrow.NaturalTransformation
6-
import cats.laws.discipline.{CartesianTests, MonadTests, SerializableTests}
6+
import cats.data.Xor
7+
import cats.laws.discipline.{CartesianTests, MonadRecTests, SerializableTests}
78
import cats.laws.discipline.arbitrary.function0Arbitrary
89

910
import org.scalacheck.{Arbitrary, Gen}
@@ -14,8 +15,8 @@ class FreeTests extends CatsSuite {
1415

1516
implicit val iso = CartesianTests.Isomorphisms.invariant[Free[Option, ?]]
1617

17-
checkAll("Free[Option, ?]", MonadTests[Free[Option, ?]].monad[Int, Int, Int])
18-
checkAll("Monad[Free[Option, ?]]", SerializableTests.serializable(Monad[Free[Option, ?]]))
18+
checkAll("Free[Option, ?]", MonadRecTests[Free[Option, ?]].monadRec[Int, Int, Int])
19+
checkAll("MonadRec[Free[Option, ?]]", SerializableTests.serializable(MonadRec[Free[Option, ?]]))
1920

2021
test("mapSuspension id"){
2122
forAll { x: Free[List, Int] =>
@@ -43,6 +44,13 @@ class FreeTests extends CatsSuite {
4344
}
4445
}
4546

47+
test("tailRecM is stack safe") {
48+
val n = 50000
49+
val fa = MonadRec[Free[Option, ?]].tailRecM(0)(i =>
50+
Free.pure[Option, Int Xor Int](if(i < n) Xor.Left(i+1) else Xor.Right(i)))
51+
fa should === (Free.pure[Option, Int](n))
52+
}
53+
4654
ignore("foldMap is stack safe") {
4755
trait FTestApi[A]
4856
case class TB(i: Int) extends FTestApi[Int]
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package cats
2+
package laws
3+
4+
import cats.data.Xor
5+
import cats.syntax.flatMap._
6+
import cats.syntax.functor._
7+
8+
/**
9+
* Laws that must be obeyed by any `FlatMapRec`.
10+
*/
11+
trait FlatMapRecLaws[F[_]] extends FlatMapLaws[F] {
12+
implicit override def F: FlatMapRec[F]
13+
14+
def tailRecMConsistentFlatMap[A](a: A, f: A => F[A]): IsEq[F[A]] = {
15+
val bounce = F.tailRecM[(A, Int), A]((a, 1)) { case (a0, i) =>
16+
if(i > 0) f(a0).map(a1 => Xor.left((a1, i-1)))
17+
else f(a0).map(Xor.right)
18+
}
19+
bounce <-> f(a).flatMap(f)
20+
}
21+
}
22+
23+
object FlatMapRecLaws {
24+
def apply[F[_]](implicit ev: FlatMapRec[F]): FlatMapRecLaws[F] =
25+
new FlatMapRecLaws[F] { def F: FlatMapRec[F] = ev }
26+
}

0 commit comments

Comments
 (0)