Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add laws to check Short-Circuiting behaviour #3375

Merged
Merged
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,13 @@ package alleycats.tests

import cats.Traverse
import cats.instances.all._
import cats.laws.discipline.{SerializableTests, TraverseFilterTests}
import cats.laws.discipline.arbitrary._
import cats.laws.discipline.{SerializableTests, ShortCircuitingTests, TraverseFilterTests}

class MapSuite extends AlleycatsSuite {
checkAll("Traverse[Map[Int, *]]", SerializableTests.serializable(Traverse[Map[Int, *]]))

checkAll("TraverseFilter[Map[Int, *]]", TraverseFilterTests[Map[Int, *]].traverseFilter[Int, Int, Int])

checkAll("Map[Int, *]", ShortCircuitingTests[Map[Int, *]].traverseFilter[Int])
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@ import alleycats.std.all._
import cats.Foldable
import cats.instances.all._
import cats.kernel.laws.discipline.SerializableTests
import cats.laws.discipline.TraverseFilterTests
import cats.laws.discipline.arbitrary._
import cats.laws.discipline.{ShortCircuitingTests, TraverseFilterTests}

class SetSuite extends AlleycatsSuite {
checkAll("FlatMapRec[Set]", FlatMapRecTests[Set].tailRecM[Int])

checkAll("Foldable[Set]", SerializableTests.serializable(Foldable[Set]))

checkAll("TraverseFilter[Set]", TraverseFilterTests[Set].traverseFilter[Int, Int, Int])

checkAll("Set[Int]", ShortCircuitingTests[Set].traverseFilter[Int])
}
10 changes: 7 additions & 3 deletions core/src/main/scala-2.12/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -199,12 +199,16 @@ private[instances] trait StreamInstancesBinCompat0 {
override def flattenOption[A](fa: Stream[Option[A]]): Stream[A] = fa.flatten

def traverseFilter[G[_], A, B](fa: Stream[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Stream[B]] =
fa.foldRight(Eval.now(G.pure(Stream.empty[B])))((x, xse) => G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o)))
traverse
.foldRight(fa, Eval.now(G.pure(Stream.empty[B])))((x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ #:: o))
)
.value

override def filterA[G[_], A](fa: Stream[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Stream[A]] =
fa.foldRight(Eval.now(G.pure(Stream.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, as) => if (b) x +: as else as)
traverse
.foldRight(fa, Eval.now(G.pure(Stream.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, stream) => if (b) x #:: stream else stream)
)
.value

Expand Down
3 changes: 2 additions & 1 deletion core/src/main/scala-2.13+/cats/instances/lazyList.scala
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,8 @@ trait LazyListInstances extends cats.kernel.instances.LazyListInstances {
.value

override def filterA[G[_], A](fa: LazyList[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[LazyList[A]] =
fa.foldRight(Eval.now(G.pure(LazyList.empty[A])))((x, xse) =>
traverse
.foldRight(fa, Eval.now(G.pure(LazyList.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, as) => if (b) x +: as else as)
)
.value
Expand Down
10 changes: 7 additions & 3 deletions core/src/main/scala-2.13+/cats/instances/stream.scala
Original file line number Diff line number Diff line change
Expand Up @@ -201,12 +201,16 @@ private[instances] trait StreamInstancesBinCompat0 {
override def flattenOption[A](fa: Stream[Option[A]]): Stream[A] = fa.flatten

def traverseFilter[G[_], A, B](fa: Stream[A])(f: (A) => G[Option[B]])(implicit G: Applicative[G]): G[Stream[B]] =
fa.foldRight(Eval.now(G.pure(Stream.empty[B])))((x, xse) => G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ +: o)))
traverse
.foldRight(fa, Eval.now(G.pure(Stream.empty[B])))((x, xse) =>
G.map2Eval(f(x), xse)((i, o) => i.fold(o)(_ #:: o))
)
.value

override def filterA[G[_], A](fa: Stream[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Stream[A]] =
fa.foldRight(Eval.now(G.pure(Stream.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, as) => if (b) x +: as else as)
traverse
.foldRight(fa, Eval.now(G.pure(Stream.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, stream) => if (b) x #:: stream else stream)
)
.value

Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/cats/instances/queue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,9 @@ private object QueueInstances {
.value

override def filterA[G[_], A](fa: Queue[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Queue[A]] =
fa.foldRight(Eval.now(G.pure(Queue.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, vec) => if (b) x +: vec else vec)
traverse
.foldRight(fa, Eval.now(G.pure(Queue.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, queue) => if (b) x +: queue else queue)
)
.value
}
Expand Down
5 changes: 3 additions & 2 deletions core/src/main/scala/cats/instances/vector.scala
Original file line number Diff line number Diff line change
Expand Up @@ -179,8 +179,9 @@ private[instances] trait VectorInstancesBinCompat0 {
.value

override def filterA[G[_], A](fa: Vector[A])(f: (A) => G[Boolean])(implicit G: Applicative[G]): G[Vector[A]] =
fa.foldRight(Eval.now(G.pure(Vector.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, vec) => if (b) x +: vec else vec)
traverse
.foldRight(fa, Eval.now(G.pure(Vector.empty[A])))((x, xse) =>
G.map2Eval(f(x), xse)((b, vector) => if (b) x +: vector else vector)
)
.value
}
Expand Down
93 changes: 93 additions & 0 deletions laws/src/main/scala/cats/laws/ShortCircuitingLaws.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
package cats.laws

import java.util.concurrent.atomic.AtomicLong

import cats.instances.option._
import cats.syntax.foldable._
import cats.syntax.traverse._
import cats.syntax.traverseFilter._
import cats.{Applicative, Traverse, TraverseFilter}

trait ShortCircuitingLaws[F[_]] {

def traverseShortCircuits[A](fa: F[A])(implicit F: Traverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.traverse(f)
f.invocations.get <-> (maxInvocationsAllowed + 1).min(size)
}

def traverseWontShortCircuit[A](fa: F[A])(implicit F: Traverse[F]): IsEq[Long] = {
val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Some(i), maxInvocationsAllowed, None)

fa.traverse(f)(nonShortCircuitingApplicative)
f.invocations.get <-> size
}

def traverseFilterShortCircuits[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Option(Option(i)), maxInvocationsAllowed, None)

fa.traverseFilter(f)
f.invocations.get <-> (maxInvocationsAllowed + 1).min(size)
}

def traverseFilterWontShortCircuit[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((i: A) => Option(Option(i)), maxInvocationsAllowed, None)

fa.traverseFilter(f)(nonShortCircuitingApplicative)
f.invocations.get <-> size
}

def filterAShortCircuits[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((_: A) => Some(true), maxInvocationsAllowed, None)

fa.filterA(f)
f.invocations.get <-> (maxInvocationsAllowed + 1).min(size)
}

def filterAWontShortCircuit[A](fa: F[A])(implicit TF: TraverseFilter[F]): IsEq[Long] = {
implicit val F: Traverse[F] = TF.traverse

val size = fa.size
val maxInvocationsAllowed = size / 2
val f = new RestrictedFunction((_: A) => Some(true), maxInvocationsAllowed, None)

fa.filterA(f)(nonShortCircuitingApplicative)
f.invocations.get <-> size
}

private[this] class RestrictedFunction[-A, +B](f: A => B, maxInvocationsAllowed: Long, empty: => B) extends (A => B) {
val invocations = new AtomicLong(0)

override def apply(v1: A): B =
if (invocations.getAndIncrement < maxInvocationsAllowed)
f(v1)
else
empty
}

private[this] val nonShortCircuitingApplicative: Applicative[Option] = new Applicative[Option] {
override def pure[A](a: A): Option[A] = Some(a)
override def ap[A, B](ff: Option[A => B])(fa: Option[A]): Option[B] = ff.flatMap(f => fa.map(f))
}
}

object ShortCircuitingLaws {
def apply[F[_]]: ShortCircuitingLaws[F] = new ShortCircuitingLaws[F] {}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
package cats.laws.discipline

import cats.laws.ShortCircuitingLaws
import cats.{Eq, Traverse, TraverseFilter}
import org.scalacheck.Arbitrary
import org.scalacheck.Prop.forAll
import org.typelevel.discipline.Laws

trait ShortCircuitingTests[F[_]] extends Laws {
def laws: ShortCircuitingLaws[F]

def traverse[A: Arbitrary](implicit F: Traverse[F], ArbFA: Arbitrary[F[A]], lEq: Eq[Long]): RuleSet =
new DefaultRuleSet(
name = "traverseShortCircuiting",
parent = None,
"traverse short-circuits if Applicative[G].map2Eval shorts" -> forAll(laws.traverseShortCircuits[A] _),
"traverse won't short-circuit if Applicative[G].map2Eval won't" -> forAll(laws.traverseWontShortCircuit[A] _)
)

def traverseFilter[A: Arbitrary](
implicit TF: TraverseFilter[F],
ArbFA: Arbitrary[F[A]],
lEq: Eq[Long]
): RuleSet = {
implicit val T: Traverse[F] = TF.traverse
new DefaultRuleSet(
name = "traverseFilterShortCircuiting",
parent = Some(traverse[A]),
"traverseFilter short-circuits if Applicative[G].map2Eval shorts" ->
forAll(laws.traverseFilterShortCircuits[A] _),
"traverseFilter short-circuits if Applicative[G].map2Eval won't" ->
forAll(laws.traverseFilterWontShortCircuit[A] _),
"filterA short-circuits if Applicative[G].map2Eval shorts" -> forAll(laws.filterAShortCircuits[A] _),
"filterA short-circuits if Applicative[G].map2Eval won't" -> forAll(laws.filterAWontShortCircuit[A] _)
)
}
}

object ShortCircuitingTests {
def apply[F[_]]: ShortCircuitingTests[F] = new ShortCircuitingTests[F] {
override def laws: ShortCircuitingLaws[F] = ShortCircuitingLaws[F]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class NonEmptyStreamSuite extends CatsSuite {
checkAll("NonEmptyStream[Int]", SemigroupTests[NonEmptyStream[Int]].semigroup)
checkAll("Semigroup[NonEmptyStream[Int]]", SerializableTests.serializable(Semigroup[NonEmptyStream[Int]]))

checkAll("NonEmptyStream[Int]", ShortCircuitingTests[NonEmptyStream].traverse[Int])

{
// Test functor and subclasses don't have implicit conflicts
implicitly[Functor[NonEmptyStream]]
Expand Down
3 changes: 3 additions & 0 deletions tests/src/test/scala-2.13+/cats/tests/LazyListSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import cats.laws.discipline.{
MonadTests,
SemigroupalTests,
SerializableTests,
ShortCircuitingTests,
TraverseFilterTests,
TraverseTests
}
Expand Down Expand Up @@ -39,6 +40,8 @@ class LazyListSuite extends CatsSuite {
checkAll("LazyList[Int]", AlignTests[LazyList].align[Int, Int, Int, Int])
checkAll("Align[LazyList]", SerializableTests.serializable(Align[LazyList]))

checkAll("LazyList[Int]", ShortCircuitingTests[LazyList].traverseFilter[Int])

// Can't test applicative laws as they don't terminate
checkAll("ZipLazyList[Int]", CommutativeApplyTests[ZipLazyList].apply[Int, Int, Int])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ import cats.{Align, Bimonad, SemigroupK, Show, Traverse}
import cats.data.{NonEmptyLazyList, NonEmptyLazyListOps}
import cats.kernel.{Eq, Hash, Order, PartialOrder, Semigroup}
import cats.kernel.laws.discipline.{EqTests, HashTests, OrderTests, PartialOrderTests, SemigroupTests}
import cats.laws.discipline.{AlignTests, BimonadTests, NonEmptyTraverseTests, SemigroupKTests, SerializableTests}
import cats.laws.discipline.{
AlignTests,
BimonadTests,
NonEmptyTraverseTests,
SemigroupKTests,
SerializableTests,
ShortCircuitingTests
}
import cats.laws.discipline.arbitrary._
import cats.syntax.either._
import cats.syntax.foldable._
Expand Down Expand Up @@ -36,6 +43,8 @@ class NonEmptyLazyListSuite extends NonEmptyCollectionSuite[LazyList, NonEmptyLa
checkAll("NonEmptyLazyList[Int]", AlignTests[NonEmptyLazyList].align[Int, Int, Int, Int])
checkAll("Align[NonEmptyLazyList]", SerializableTests.serializable(Align[NonEmptyLazyList]))

checkAll("NonEmptyLazyList[Int]", ShortCircuitingTests[NonEmptyLazyList].traverse[Int])

test("show") {
Show[NonEmptyLazyList[Int]].show(NonEmptyLazyList(1, 2, 3)) should ===("NonEmptyLazyList(1, ?)")
}
Expand Down
3 changes: 3 additions & 0 deletions tests/src/test/scala/cats/tests/ChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import cats.laws.discipline.{
CoflatMapTests,
MonadTests,
SerializableTests,
ShortCircuitingTests,
TraverseFilterTests,
TraverseTests
}
Expand Down Expand Up @@ -44,6 +45,8 @@ class ChainSuite extends CatsSuite {
checkAll("Chain[Int]", TraverseFilterTests[Chain].traverseFilter[Int, Int, Int])
checkAll("TraverseFilter[Chain]", SerializableTests.serializable(TraverseFilter[Chain]))

checkAll("Chain[Int]", ShortCircuitingTests[Chain].traverseFilter[Int])

{
implicit val partialOrder: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
checkAll("Chain[ListWrapper[Int]]", PartialOrderTests[Chain[ListWrapper[Int]]].partialOrder)
Expand Down
3 changes: 3 additions & 0 deletions tests/src/test/scala/cats/tests/ListSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import cats.laws.discipline.{
MonadTests,
SemigroupalTests,
SerializableTests,
ShortCircuitingTests,
TraverseFilterTests,
TraverseTests
}
Expand Down Expand Up @@ -41,6 +42,8 @@ class ListSuite extends CatsSuite {
checkAll("List[Int]", AlignTests[List].align[Int, Int, Int, Int])
checkAll("Align[List]", SerializableTests.serializable(Align[List]))

checkAll("List[Int]", ShortCircuitingTests[List].traverseFilter[Int])

checkAll("ZipList[Int]", CommutativeApplyTests[ZipList].commutativeApply[Int, Int, Int])

test("nel => list => nel returns original nel")(
Expand Down
11 changes: 10 additions & 1 deletion tests/src/test/scala/cats/tests/NonEmptyChainSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,14 @@ import cats.{Align, Bimonad, SemigroupK, Show, Traverse}
import cats.data.{Chain, NonEmptyChain, NonEmptyChainOps}
import cats.kernel.{Eq, Order, PartialOrder, Semigroup}
import cats.kernel.laws.discipline.{EqTests, OrderTests, PartialOrderTests, SemigroupTests}
import cats.laws.discipline.{AlignTests, BimonadTests, NonEmptyTraverseTests, SemigroupKTests, SerializableTests}
import cats.laws.discipline.{
AlignTests,
BimonadTests,
NonEmptyTraverseTests,
SemigroupKTests,
SerializableTests,
ShortCircuitingTests
}
import cats.laws.discipline.arbitrary._
import cats.syntax.either._
import cats.syntax.foldable._
Expand Down Expand Up @@ -33,6 +40,8 @@ class NonEmptyChainSuite extends NonEmptyCollectionSuite[Chain, NonEmptyChain, N
checkAll("NonEmptyChain[Int]", AlignTests[NonEmptyChain].align[Int, Int, Int, Int])
checkAll("Align[NonEmptyChain]", SerializableTests.serializable(Align[NonEmptyChain]))

checkAll("NonEmptyChain[Int]", ShortCircuitingTests[NonEmptyChain].traverse[Int])

{
implicit val partialOrder: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
checkAll("NonEmptyChain[ListWrapper[Int]]", PartialOrderTests[NonEmptyChain[ListWrapper[Int]]].partialOrder)
Expand Down
5 changes: 4 additions & 1 deletion tests/src/test/scala/cats/tests/NonEmptyListSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ import cats.laws.discipline.{
NonEmptyTraverseTests,
ReducibleTests,
SemigroupKTests,
SerializableTests
SerializableTests,
ShortCircuitingTests
}
import cats.syntax.foldable._
import cats.syntax.reducible._
Expand Down Expand Up @@ -55,6 +56,8 @@ class NonEmptyListSuite extends NonEmptyCollectionSuite[List, NonEmptyList, NonE

checkAll("ZipNonEmptyList[Int]", CommutativeApplyTests[ZipNonEmptyList].commutativeApply[Int, Int, Int])

checkAll("NonEmptyList[Int]", ShortCircuitingTests[NonEmptyList].traverse[Int])

{
implicit val A: PartialOrder[ListWrapper[Int]] = ListWrapper.partialOrder[Int]
checkAll("NonEmptyList[ListWrapper[Int]]", PartialOrderTests[NonEmptyList[ListWrapper[Int]]].partialOrder)
Expand Down
6 changes: 5 additions & 1 deletion tests/src/test/scala/cats/tests/NonEmptyVectorSuite.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,15 @@ import cats.laws.discipline.{
NonEmptyTraverseTests,
ReducibleTests,
SemigroupKTests,
SerializableTests
SerializableTests,
ShortCircuitingTests
}
import cats.laws.discipline.arbitrary._
import cats.platform.Platform
import cats.syntax.foldable._
import cats.syntax.reducible._
import cats.syntax.show._

import scala.util.Properties

class NonEmptyVectorSuite extends NonEmptyCollectionSuite[Vector, NonEmptyVector, NonEmptyVector] {
Expand Down Expand Up @@ -81,6 +83,8 @@ class NonEmptyVectorSuite extends NonEmptyCollectionSuite[Vector, NonEmptyVector
checkAll("NonEmptyVector[Int]", BimonadTests[NonEmptyVector].bimonad[Int, Int, Int])
checkAll("Bimonad[NonEmptyVector]", SerializableTests.serializable(Bimonad[NonEmptyVector]))

checkAll("NonEmptyVector[Int]", ShortCircuitingTests[NonEmptyVector].traverse[Int])

test("size is consistent with toList.size") {
forAll { (nonEmptyVector: NonEmptyVector[Int]) =>
nonEmptyVector.size should ===(nonEmptyVector.toList.size.toLong)
Expand Down
Loading