Skip to content

Commit

Permalink
Add Semigroup and Monoid combinators reverse and intercalate (#3279)
Browse files Browse the repository at this point in the history
* Add Semigroup and Monoid combinators reverse and intercalate

* format

* fix combineN on Monoid.reverse and Duration tests

* respond to review comments

* remove optimization of commutative semigroup intercalate combineN which fails for almost commutative BigDecimal
  • Loading branch information
johnynek authored Feb 25, 2020
1 parent b2fd66f commit 8501d0b
Show file tree
Hide file tree
Showing 14 changed files with 140 additions and 17 deletions.
5 changes: 4 additions & 1 deletion core/src/main/scala/cats/Foldable.scala
Original file line number Diff line number Diff line change
Expand Up @@ -757,7 +757,10 @@ import Foldable.sentinel
* }}}
*/
def intercalate[A](fa: F[A], a: A)(implicit A: Monoid[A]): A =
A.combineAll(intersperseList(toList(fa), a))
combineAllOption(fa)(A.intercalate(a)) match {
case None => A.empty
case Some(a) => a
}

protected def intersperseList[A](xs: List[A], x: A): List[A] = {
val bld = List.newBuilder[A]
Expand Down
6 changes: 1 addition & 5 deletions core/src/main/scala/cats/Reducible.scala
Original file line number Diff line number Diff line change
Expand Up @@ -238,11 +238,7 @@ import simulacrum.{noop, typeclass}
* }}}
*/
def nonEmptyIntercalate[A](fa: F[A], a: A)(implicit A: Semigroup[A]): A =
toNonEmptyList(fa) match {
case NonEmptyList(hd, Nil) => hd
case NonEmptyList(hd, tl) =>
Reducible[NonEmptyList].reduce(NonEmptyList(hd, a :: intersperseList(tl, a)))
}
reduce(fa)(A.intercalate(a))

/**
* Partition this Reducible by a separating function `A => Either[B, C]`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,41 @@ trait SemigroupLaws[A] {
def combineAllOption(xs: Vector[A]): IsEq[Option[A]] =
S.combineAllOption(xs) <-> xs.reduceOption(S.combine)

def reverseReverses(a: A, b: A): IsEq[A] =
S.combine(a, b) <-> S.reverse.combine(b, a)

def reverseRepeat1(a: A): IsEq[A] = {
val rev = S.reverse
rev.combineN(a, 1) <-> a
}

def reverseRepeat2(a: A): IsEq[A] = {
val rev = S.reverse
rev.combineN(a, 2) <-> rev.combine(a, a)
}

def reverseCombineAllOption(xs: Vector[A]): IsEq[Option[A]] = {
val rev = S.reverse
rev.combineAllOption(xs) <-> xs.reduceOption(rev.combine)
}

def intercalateIntercalates(a: A, m: A, b: A): IsEq[A] =
S.combine(a, S.combine(m, b)) <-> S.intercalate(m).combine(a, b)

def intercalateRepeat1(m: A, a: A): IsEq[A] = {
val withMiddle = S.intercalate(m)
withMiddle.combineN(a, 1) <-> a
}

def intercalateRepeat2(m: A, a: A): IsEq[A] = {
val withMiddle = S.intercalate(m)
withMiddle.combineN(a, 2) <-> withMiddle.combine(a, a)
}

def intercalateCombineAllOption(m: A, xs: Vector[A]): IsEq[Option[A]] = {
val withMiddle = S.intercalate(m)
withMiddle.combineAllOption(xs) <-> xs.reduceOption(withMiddle.combine)
}
}

object SemigroupLaws {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,15 @@ trait SemigroupTests[A] extends Laws {
"associative" -> forAll(laws.semigroupAssociative _),
"repeat1" -> forAll(laws.repeat1 _),
"repeat2" -> forAll(laws.repeat2 _),
"combineAllOption" -> forAll(laws.combineAllOption _)
"combineAllOption" -> forAll(laws.combineAllOption _),
"reverseReverses" -> forAll(laws.reverseReverses _),
"reverseRepeat1" -> forAll(laws.reverseRepeat1 _),
"reverseRepeat2" -> forAll(laws.reverseRepeat2 _),
"reverseCombineAllOption" -> forAll(laws.reverseCombineAllOption _),
"intercalateIntercalates" -> forAll(laws.intercalateIntercalates _),
"intercalateRepeat1" -> forAll(laws.intercalateRepeat1 _),
"intercalateRepeat2" -> forAll(laws.intercalateRepeat2 _),
"intercalateCombineAllOption" -> forAll(laws.intercalateCombineAllOption _)
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ object KernelCheck {
implicit val arbitraryDuration: Arbitrary[Duration] = {
// max range is +/- 292 years, but we give ourselves some extra headroom
// to ensure that we can add these things up. they crash on overflow.
val n = (292L * 365) / 50
val n = (292L * 365) / 500
Arbitrary(
Gen.oneOf(
Gen.choose(-n, n).map(Duration(_, DAYS)),
Expand All @@ -51,7 +51,7 @@ object KernelCheck {
implicit val arbitraryFiniteDuration: Arbitrary[FiniteDuration] = {
// max range is +/- 292 years, but we give ourselves some extra headroom
// to ensure that we can add these things up. they crash on overflow.
val n = (292L * 365) / 50
val n = (292L * 365) / 500
Arbitrary(
Gen.oneOf(
Gen.choose(-n, n).map(FiniteDuration(_, DAYS)),
Expand Down
6 changes: 5 additions & 1 deletion kernel/src/main/scala/cats/kernel/Band.scala
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,11 @@ import scala.{specialized => sp}
* Bands are semigroups whose operation
* (i.e. combine) is also idempotent.
*/
trait Band[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A]
trait Band[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A] {
override def combineN(a: A, n: Int): A =
if (n <= 0) throw new IllegalArgumentException("Repeated combining for semigroups must have n > 0")
else a // combine(a, a) == a
}

object Band extends SemigroupFunctions[Band] {

Expand Down
7 changes: 6 additions & 1 deletion kernel/src/main/scala/cats/kernel/BoundedSemilattice.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,12 @@ package cats.kernel

import scala.{specialized => sp}

trait BoundedSemilattice[@sp(Int, Long, Float, Double) A] extends Any with Semilattice[A] with CommutativeMonoid[A]
trait BoundedSemilattice[@sp(Int, Long, Float, Double) A] extends Any with Semilattice[A] with CommutativeMonoid[A] {
override def combineN(a: A, n: Int): A =
if (n < 0) throw new IllegalArgumentException("Repeated combining for monoids must have n >= 0")
else if (n == 0) empty
else a // combine(a, a) == a for a semilattice
}

object BoundedSemilattice extends SemilatticeFunctions[BoundedSemilattice] {

Expand Down
5 changes: 4 additions & 1 deletion kernel/src/main/scala/cats/kernel/CommutativeMonoid.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,10 @@ import scala.{specialized => sp}
*
* A monoid is commutative if for all x and y, x |+| y === y |+| x.
*/
trait CommutativeMonoid[@sp(Int, Long, Float, Double) A] extends Any with Monoid[A] with CommutativeSemigroup[A]
trait CommutativeMonoid[@sp(Int, Long, Float, Double) A] extends Any with Monoid[A] with CommutativeSemigroup[A] {
self =>
override def reverse: CommutativeMonoid[A] = self
}

object CommutativeMonoid extends MonoidFunctions[CommutativeMonoid] {

Expand Down
9 changes: 8 additions & 1 deletion kernel/src/main/scala/cats/kernel/CommutativeSemigroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@ import scala.{specialized => sp}
*
* A semigroup is commutative if for all x and y, x |+| y === y |+| x.
*/
trait CommutativeSemigroup[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A]
trait CommutativeSemigroup[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A] { self =>
override def reverse: CommutativeSemigroup[A] = self
override def intercalate(middle: A): CommutativeSemigroup[A] =
new CommutativeSemigroup[A] {
def combine(a: A, b: A): A =
self.combine(a, self.combine(middle, b))
}
}

object CommutativeSemigroup extends SemigroupFunctions[CommutativeSemigroup] {

Expand Down
11 changes: 10 additions & 1 deletion kernel/src/main/scala/cats/kernel/Monoid.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import compat.scalaVersionSpecific._
* `combine(x, empty) == combine(empty, x) == x`. For example, if we have `Monoid[String]`,
* with `combine` as string concatenation, then `empty = ""`.
*/
trait Monoid[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A] {
trait Monoid[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A] { self =>

/**
* Return the identity element for this monoid.
Expand Down Expand Up @@ -83,6 +83,15 @@ trait Monoid[@sp(Int, Long, Float, Double) A] extends Any with Semigroup[A] {

override def combineAllOption(as: IterableOnce[A]): Option[A] =
if (as.iterator.isEmpty) None else Some(combineAll(as))

override def reverse: Monoid[A] =
new Monoid[A] {
def empty = self.empty
def combine(a: A, b: A) = self.combine(b, a)
// a + a + a + ... is the same when reversed
override def combineN(a: A, n: Int): A = self.combineN(a, n)
override def reverse = self
}
}

@suppressUnusedImportWarningForScalaVersionSpecific
Expand Down
4 changes: 4 additions & 0 deletions kernel/src/main/scala/cats/kernel/Order.scala
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,10 @@ object Order extends OrderFunctions[Order] with OrderToOrderingConversion {
new Monoid[Order[A]] with Band[Order[A]] {
val empty: Order[A] = allEqual[A]
def combine(x: Order[A], y: Order[A]): Order[A] = Order.whenEqual(x, y)
override def combineN(a: Order[A], n: Int): Order[A] =
if (n < 0) throw new IllegalArgumentException("Repeated combining for monoids must have n >= 0")
else if (n == 0) empty
else a // combine(a, a) == a for a band
}

def fromOrdering[A](implicit ev: Ordering[A]): Order[A] =
Expand Down
24 changes: 23 additions & 1 deletion kernel/src/main/scala/cats/kernel/Semigroup.scala
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import compat.scalaVersionSpecific._
/**
* A semigroup is any set `A` with an associative operation (`combine`).
*/
trait Semigroup[@sp(Int, Long, Float, Double) A] extends Any with Serializable {
trait Semigroup[@sp(Int, Long, Float, Double) A] extends Any with Serializable { self =>

/**
* Associative operation which combines two values.
Expand Down Expand Up @@ -77,6 +77,28 @@ trait Semigroup[@sp(Int, Long, Float, Double) A] extends Any with Serializable {
*/
def combineAllOption(as: IterableOnce[A]): Option[A] =
as.reduceOption(combine)

/**
* return a semigroup that reverses the order
* so combine(a, b) == reverse.combine(b, a)
*/
def reverse: Semigroup[A] =
new Semigroup[A] {
def combine(a: A, b: A): A = self.combine(b, a)
// a + a + a + ... is the same when reversed
override def combineN(a: A, n: Int): A = self.combineN(a, n)
override def reverse = self
}

/**
* Between each pair of elements insert middle
* This name matches the term used in Foldable and Reducible and a similar Haskell function.
*/
def intercalate(middle: A): Semigroup[A] =
new Semigroup[A] {
def combine(a: A, b: A): A =
self.combine(a, self.combine(middle, b))
}
}

abstract class SemigroupFunctions[S[T] <: Semigroup[T]] {
Expand Down
15 changes: 14 additions & 1 deletion kernel/src/main/scala/cats/kernel/instances/ListInstances.scala
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ class ListEq[A](implicit ev: Eq[A]) extends Eq[List[A]] {
}
}

class ListMonoid[A] extends Monoid[List[A]] {
class ListMonoid[A] extends Monoid[List[A]] { self =>
def empty: List[A] = Nil
def combine(x: List[A], y: List[A]): List[A] = x ::: y

Expand All @@ -92,4 +92,17 @@ class ListMonoid[A] extends Monoid[List[A]] {

override def combineAll(xs: IterableOnce[List[A]]): List[A] =
StaticMethods.combineAllIterable(List.newBuilder[A], xs)

override def reverse: Monoid[List[A]] =
new Monoid[List[A]] {
def empty: List[A] = Nil
def combine(x: List[A], y: List[A]) = y ::: x

override def combineAll(xs: IterableOnce[List[A]]): List[A] =
xs.iterator.foldLeft(empty) { (acc, item) =>
item ::: acc
}

override def reverse = self
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class StringOrder extends Order[String] with Hash[String] with StringLowerBounde
override val partialOrder: PartialOrder[String] = self
}

class StringMonoid extends Monoid[String] {
class StringMonoid extends Monoid[String] { self =>
def empty: String = ""
def combine(x: String, y: String): String = x + y

Expand All @@ -32,4 +32,18 @@ class StringMonoid extends Monoid[String] {
xs.iterator.foreach(sb.append)
sb.toString
}

override def reverse: Monoid[String] =
new Monoid[String] {
def empty = self.empty
def combine(x: String, y: String) = y + x
override def combineAll(xs: IterableOnce[String]): String = {
val revStrings = xs.iterator.foldLeft(List.empty[String]) { (acc, s) =>
s :: acc
}
self.combineAll(revStrings)
}

override def reverse = self
}
}

0 comments on commit 8501d0b

Please sign in to comment.