Skip to content

Commit 8e50b78

Browse files
backported #3041, added tailrec instance for StacksafeMonad and Defer (#3332)
1 parent 8227961 commit 8e50b78

File tree

6 files changed

+67
-10
lines changed

6 files changed

+67
-10
lines changed

bench/src/main/scala/cats/bench/TrampolineBench.scala

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
package cats.bench
22

33
import org.openjdk.jmh.annotations.{Benchmark, Scope, State}
4-
54
import cats._
65
import cats.implicits._
76
import cats.free.Trampoline
87

8+
import scala.util.control.TailCalls
9+
910
@State(Scope.Benchmark)
1011
class TrampolineBench {
1112

@@ -30,14 +31,12 @@ class TrampolineBench {
3031
y <- Trampoline.defer(trampolineFib(n - 2))
3132
} yield x + y
3233

33-
// TailRec[A] only has .flatMap in 2.11.
34+
@Benchmark
35+
def stdlib(): Int = stdlibFib(N).result
3436

35-
// @Benchmark
36-
// def stdlib(): Int = stdlibFib(N).result
37-
//
38-
// def stdlibFib(n: Int): TailCalls.TailRec[Int] =
39-
// if (n < 2) TailCalls.done(n) else for {
40-
// x <- TailCalls.tailcall(stdlibFib(n - 1))
41-
// y <- TailCalls.tailcall(stdlibFib(n - 2))
42-
// } yield x + y
37+
def stdlibFib(n: Int): TailCalls.TailRec[Int] =
38+
if (n < 2) TailCalls.done(n) else for {
39+
x <- TailCalls.tailcall(stdlibFib(n - 1))
40+
y <- TailCalls.tailcall(stdlibFib(n - 2))
41+
} yield x + y
4342
}

core/src/main/scala/cats/Eval.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 {
378378
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
379379
def extract[A](la: Eval[A]): A = la.value
380380
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
381+
override def unit: Eval[Unit] = Eval.Unit
381382
}
382383

383384
implicit val catsDeferForEval: Defer[Eval] =

core/src/main/scala/cats/instances/all.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,4 +69,5 @@ trait AllInstancesBinCompat7
6969
with VectorInstancesBinCompat1
7070
with EitherInstancesBinCompat0
7171
with StreamInstancesBinCompat1
72+
with TailRecInstances
7273
with SortedSetInstancesBinCompat2

core/src/main/scala/cats/instances/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ package object instances {
4343
with SortedSetInstancesBinCompat2
4444
object stream extends StreamInstances with StreamInstancesBinCompat0 with StreamInstancesBinCompat1
4545
object string extends StringInstances
46+
object tailRec extends TailRecInstances
4647
object try_ extends TryInstances
4748
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
4849
object unit extends UnitInstances
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package cats.instances
2+
3+
import cats.{Defer, StackSafeMonad}
4+
import scala.util.control.TailCalls.{done, tailcall, TailRec}
5+
6+
trait TailRecInstances {
7+
implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
8+
TailRecInstances.catsInstancesForTailRec
9+
}
10+
11+
private object TailRecInstances {
12+
val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
13+
new StackSafeMonad[TailRec] with Defer[TailRec] {
14+
def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)
15+
16+
def pure[A](a: A): TailRec[A] = done(a)
17+
18+
override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] =
19+
fa.map(f)
20+
21+
def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] =
22+
fa.flatMap(f)
23+
24+
override val unit: TailRec[Unit] = done(())
25+
}
26+
}
Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
package cats.tests
2+
3+
import cats.{Defer, Eq, Monad}
4+
import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests}
5+
import org.scalacheck.Arbitrary.arbitrary
6+
import org.scalacheck.{Arbitrary, Cogen, Gen}
7+
8+
import scala.util.control.TailCalls.{done, tailcall, TailRec}
9+
10+
class TailRecSuite extends CatsSuite {
11+
12+
implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] =
13+
Arbitrary(
14+
Gen.frequency(
15+
(3, arbitrary[A].map(done)),
16+
(1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })),
17+
(1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) }))
18+
)
19+
)
20+
21+
implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] =
22+
Eq.by[TailRec[A], A](_.result)
23+
24+
checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int])
25+
checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec]))
26+
27+
checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int])
28+
checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec]))
29+
}

0 commit comments

Comments
 (0)