Skip to content

Commit fc7b8b9

Browse files
johnynekLukaJCB
authored andcommitted
Add scala.util.control.TailCalls.TailRec instances (#3041)
* Add scala.util.control.TailCalls.TailRec instances * format * review comments * fix conflict * Format * Avoid val in trait for bincompat
1 parent 0aaa637 commit fc7b8b9

File tree

8 files changed

+69
-9
lines changed

8 files changed

+69
-9
lines changed

bench/src/main/scala/cats/bench/TrampolineBench.scala

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ import org.openjdk.jmh.annotations.{Benchmark, Scope, State}
55
import cats._
66
import cats.implicits._
77
import cats.free.Trampoline
8+
import scala.util.control.TailCalls
89

910
@State(Scope.Benchmark)
1011
class TrampolineBench {
@@ -30,14 +31,12 @@ class TrampolineBench {
3031
y <- Trampoline.defer(trampolineFib(n - 2))
3132
} yield x + y
3233

33-
// TailRec[A] only has .flatMap in 2.11.
34+
@Benchmark
35+
def stdlib(): Int = stdlibFib(N).result
3436

35-
// @Benchmark
36-
// def stdlib(): Int = stdlibFib(N).result
37-
//
38-
// def stdlibFib(n: Int): TailCalls.TailRec[Int] =
39-
// if (n < 2) TailCalls.done(n) else for {
40-
// x <- TailCalls.tailcall(stdlibFib(n - 1))
41-
// y <- TailCalls.tailcall(stdlibFib(n - 2))
42-
// } yield x + y
37+
def stdlibFib(n: Int): TailCalls.TailRec[Int] =
38+
if (n < 2) TailCalls.done(n) else for {
39+
x <- TailCalls.tailcall(stdlibFib(n - 1))
40+
y <- TailCalls.tailcall(stdlibFib(n - 2))
41+
} yield x + y
4342
}

core/src/main/scala-2.12/cats/instances/all.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ trait AllInstances
3838
with StreamInstances
3939
with StringInstances
4040
with SymbolInstances
41+
with TailRecInstances
4142
with TryInstances
4243
with TupleInstances
4344
with UUIDInstances

core/src/main/scala-2.12/cats/instances/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ package object instances {
3939
object sortedSet extends SortedSetInstances with SortedSetInstancesBinCompat0 with SortedSetInstancesBinCompat1
4040
object stream extends StreamInstances with StreamInstancesBinCompat0
4141
object string extends StringInstances
42+
object tailRec extends TailRecInstances
4243
object try_ extends TryInstances
4344
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
4445
object unit extends UnitInstances

core/src/main/scala-2.13+/cats/instances/all.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ trait AllInstances
3939
with StreamInstances
4040
with StringInstances
4141
with SymbolInstances
42+
with TailRecInstances
4243
with TryInstances
4344
with TupleInstances
4445
with UUIDInstances

core/src/main/scala-2.13+/cats/instances/package.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ package object instances {
4242
object stream extends StreamInstances with StreamInstancesBinCompat0
4343
object lazyList extends LazyListInstances
4444
object string extends StringInstances
45+
object tailRec extends TailRecInstances
4546
object try_ extends TryInstances
4647
object tuple extends TupleInstances with Tuple2InstancesBinCompat0
4748
object unit extends UnitInstances

core/src/main/scala/cats/Eval.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,7 @@ sealed abstract private[cats] class EvalInstances extends EvalInstances0 {
378378
def flatMap[A, B](fa: Eval[A])(f: A => Eval[B]): Eval[B] = fa.flatMap(f)
379379
def extract[A](la: Eval[A]): A = la.value
380380
def coflatMap[A, B](fa: Eval[A])(f: Eval[A] => B): Eval[B] = Later(f(fa))
381+
override def unit: Eval[Unit] = Eval.Unit
381382
}
382383

383384
implicit val catsDeferForEval: Defer[Eval] =
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
package cats
2+
package instances
3+
4+
import scala.util.control.TailCalls.{done, tailcall, TailRec}
5+
6+
trait TailRecInstances {
7+
implicit def catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
8+
TailRecInstances.catsInstancesForTailRec
9+
}
10+
11+
private object TailRecInstances {
12+
val catsInstancesForTailRec: StackSafeMonad[TailRec] with Defer[TailRec] =
13+
new StackSafeMonad[TailRec] with Defer[TailRec] {
14+
def defer[A](fa: => TailRec[A]): TailRec[A] = tailcall(fa)
15+
16+
def pure[A](a: A): TailRec[A] = done(a)
17+
18+
override def map[A, B](fa: TailRec[A])(f: A => B): TailRec[B] =
19+
fa.map(f)
20+
21+
def flatMap[A, B](fa: TailRec[A])(f: A => TailRec[B]): TailRec[B] =
22+
fa.flatMap(f)
23+
24+
override val unit: TailRec[Unit] = done(())
25+
}
26+
}
Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
package cats
2+
package tests
3+
4+
import scala.util.control.TailCalls.{done, tailcall, TailRec}
5+
import org.scalacheck.{Arbitrary, Cogen, Gen}
6+
7+
import Arbitrary.arbitrary
8+
9+
import cats.laws.discipline.{DeferTests, MonadTests, SerializableTests}
10+
11+
class TailRecSuite extends CatsSuite {
12+
13+
implicit def tailRecArb[A: Arbitrary: Cogen]: Arbitrary[TailRec[A]] =
14+
Arbitrary(
15+
Gen.frequency(
16+
(3, arbitrary[A].map(done(_))),
17+
(1, Gen.lzy(arbitrary[(A, A => TailRec[A])].map { case (a, fn) => tailcall(fn(a)) })),
18+
(1, Gen.lzy(arbitrary[(TailRec[A], A => TailRec[A])].map { case (a, fn) => a.flatMap(fn) }))
19+
)
20+
)
21+
22+
implicit def eqTailRec[A: Eq]: Eq[TailRec[A]] =
23+
Eq.by[TailRec[A], A](_.result)
24+
25+
checkAll("TailRec[Int]", MonadTests[TailRec].monad[Int, Int, Int])
26+
checkAll("Monad[TailRec]", SerializableTests.serializable(Monad[TailRec]))
27+
28+
checkAll("TailRec[Int]", DeferTests[TailRec].defer[Int])
29+
checkAll("Defer[TailRec]", SerializableTests.serializable(Defer[TailRec]))
30+
}

0 commit comments

Comments
 (0)