Skip to content

Commit

Permalink
Add and fix contracts for inline functions
Browse files Browse the repository at this point in the history
  • Loading branch information
kyay10 committed Nov 15, 2024
1 parent 3279876 commit 8562c9b
Show file tree
Hide file tree
Showing 9 changed files with 345 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,10 @@ public sealed class Either<out A, out B> {
* <!--- TEST lines.isEmpty() -->
*/
public inline fun isLeft(predicate: (A) -> Boolean): Boolean {
contract { returns(true) implies (this@Either is Left<A>) }
contract {
returns(true) implies (this@Either is Left<A>)
callsInPlace(predicate, InvocationKind.AT_MOST_ONCE)
}
return this@Either is Left<A> && predicate(value)
}

Expand All @@ -554,7 +557,10 @@ public sealed class Either<out A, out B> {
* <!--- TEST lines.isEmpty() -->
*/
public inline fun isRight(predicate: (B) -> Boolean): Boolean {
contract { returns(true) implies (this@Either is Right<B>) }
contract {
returns(true) implies (this@Either is Right<B>)
callsInPlace(predicate, InvocationKind.AT_MOST_ONCE)
}
return this@Either is Right<B> && predicate(value)
}

Expand Down Expand Up @@ -799,12 +805,16 @@ public sealed class Either<out A, out B> {

public companion object {
@JvmStatic
public inline fun <R> catch(f: () -> R): Either<Throwable, R> =
arrow.core.raise.catch({ f().right() }) { it.left() }
public inline fun <R> catch(f: () -> R): Either<Throwable, R> {
contract { callsInPlace(f, InvocationKind.AT_MOST_ONCE) }
return arrow.core.raise.catch({ f().right() }) { it.left() }
}

@JvmStatic
public inline fun <reified T : Throwable, R> catchOrThrow(f: () -> R): Either<T, R> =
arrow.core.raise.catch<T, Either<T, R>>({ f().right() }) { it.left() }
public inline fun <reified T : Throwable, R> catchOrThrow(f: () -> R): Either<T, R> {
contract { callsInPlace(f, InvocationKind.AT_MOST_ONCE) }
return arrow.core.raise.catch<T, Either<T, R>>({ f().right() }) { it.left() }
}

public inline fun <E, A, B, Z> zipOrAccumulate(
combine: (E, E) -> E,
Expand Down Expand Up @@ -1369,8 +1379,12 @@ public operator fun <A : Comparable<A>, B : Comparable<B>> Either<A, B>.compareT
* If both are [Right] then combine both [B] values using [combineRight] or if both are [Left] then combine both [A] values using [combineLeft],
* otherwise return the sole [Left] value (either `this` or [other]).
*/
public fun <A, B> Either<A, B>.combine(other: Either<A, B>, combineLeft: (A, A) -> A, combineRight: (B, B) -> B): Either<A, B> =
when (val one = this) {
public inline fun <A, B> Either<A, B>.combine(other: Either<A, B>, combineLeft: (A, A) -> A, combineRight: (B, B) -> B): Either<A, B> {
contract {
callsInPlace(combineLeft, InvocationKind.AT_MOST_ONCE)
callsInPlace(combineRight, InvocationKind.AT_MOST_ONCE)
}
return when (val one = this) {
is Left -> when (other) {
is Left -> Left(combineLeft(one.value, other.value))
is Right -> one
Expand All @@ -1381,6 +1395,7 @@ public fun <A, B> Either<A, B>.combine(other: Either<A, B>, combineLeft: (A, A)
is Right -> Right(combineRight(one.value, other.value))
}
}
}

public const val NicheAPI: String =
"This API is niche and will be removed in the future. If this method is crucial for you, please let us know on the Arrow Github. Thanks!\n https://github.com/arrow-kt/arrow/issues\n"
Expand Down
37 changes: 29 additions & 8 deletions arrow-libs/core/arrow-core/src/commonMain/kotlin/arrow/core/Ior.kt
Original file line number Diff line number Diff line change
Expand Up @@ -339,6 +339,7 @@ public sealed class Ior<out A, out B> {
contract {
returns(true) implies (this@Ior is Left<A>)
returns(false) implies (this@Ior is Right<B> || this@Ior is Both<A, B>)
callsInPlace(predicate, InvocationKind.AT_MOST_ONCE)
}
return this@Ior is Left<A> && predicate(value)
}
Expand All @@ -364,6 +365,7 @@ public sealed class Ior<out A, out B> {
contract {
returns(true) implies (this@Ior is Right<B>)
returns(false) implies (this@Ior is Left<A> || this@Ior is Both<A, B>)
callsInPlace(predicate, InvocationKind.AT_MOST_ONCE)
}
return this@Ior is Right<B> && predicate(value)
}
Expand All @@ -390,6 +392,8 @@ public sealed class Ior<out A, out B> {
contract {
returns(true) implies (this@Ior is Both<A, B>)
returns(false) implies (this@Ior is Left<A> || this@Ior is Right<B>)
callsInPlace(leftPredicate, InvocationKind.AT_MOST_ONCE)
callsInPlace(rightPredicate, InvocationKind.AT_MOST_ONCE)
}
return this@Ior is Both<A, B> && leftPredicate(leftValue) && rightPredicate(rightValue)
}
Expand All @@ -400,8 +404,12 @@ public sealed class Ior<out A, out B> {
*
* @param f The function to bind across [Ior.Right].
*/
public inline fun <A, B, D> Ior<A, B>.flatMap(combine: (A, A) -> A, f: (B) -> Ior<A, D>): Ior<A, D> =
when (this) {
public inline fun <A, B, D> Ior<A, B>.flatMap(combine: (A, A) -> A, f: (B) -> Ior<A, D>): Ior<A, D> {
contract {
callsInPlace(combine, InvocationKind.AT_MOST_ONCE)
callsInPlace(f, InvocationKind.AT_MOST_ONCE)
}
return when (this) {
is Left -> this
is Right -> f(value)
is Both -> when (val r = f(rightValue)) {
Expand All @@ -410,14 +418,19 @@ public inline fun <A, B, D> Ior<A, B>.flatMap(combine: (A, A) -> A, f: (B) -> Io
is Both -> Both(combine(this.leftValue, r.leftValue), r.rightValue)
}
}
}

/**
* Binds the given function across [Ior.Left].
*
* @param f The function to bind across [Ior.Left].
*/
public inline fun <A, B, D> Ior<A, B>.handleErrorWith(combine: (B, B) -> B, f: (A) -> Ior<D, B>): Ior<D, B> =
when (this) {
public inline fun <A, B, D> Ior<A, B>.handleErrorWith(combine: (B, B) -> B, f: (A) -> Ior<D, B>): Ior<D, B> {
contract {
callsInPlace(combine, InvocationKind.AT_MOST_ONCE)
callsInPlace(f, InvocationKind.AT_MOST_ONCE)
}
return when (this) {
is Left -> f(value)
is Right -> this
is Both -> when (val l = f(leftValue)) {
Expand All @@ -426,6 +439,7 @@ public inline fun <A, B, D> Ior<A, B>.handleErrorWith(combine: (B, B) -> B, f: (
is Both -> Both(l.leftValue, combine(this.rightValue, l.rightValue))
}
}
}

public inline fun <A, B> Ior<A, B>.getOrElse(default: (A) -> B): B {
contract { callsInPlace(default, InvocationKind.AT_MOST_ONCE) }
Expand All @@ -443,8 +457,12 @@ public fun <A> A.leftIor(): Ior<A, Nothing> = Ior.Left(this)

public fun <A> A.rightIor(): Ior<Nothing, A> = Ior.Right(this)

public fun <A, B> Ior<A, B>.combine(other: Ior<A, B>, combineA: (A, A) -> A, combineB: (B, B) -> B): Ior<A, B> =
when (this) {
public inline fun <A, B> Ior<A, B>.combine(other: Ior<A, B>, combineA: (A, A) -> A, combineB: (B, B) -> B): Ior<A, B> {
contract {
callsInPlace(combineA, InvocationKind.AT_MOST_ONCE)
callsInPlace(combineB, InvocationKind.AT_MOST_ONCE)
}
return when (this) {
is Ior.Left -> when (other) {
is Ior.Left -> Ior.Left(combineA(value, other.value))
is Ior.Right -> Ior.Both(value, other.value)
Expand All @@ -463,9 +481,12 @@ public fun <A, B> Ior<A, B>.combine(other: Ior<A, B>, combineA: (A, A) -> A, com
is Ior.Both -> Ior.Both(combineA(leftValue, other.leftValue), combineB(rightValue, other.rightValue))
}
}
}

public inline fun <A, B> Ior<A, Ior<A, B>>.flatten(combine: (A, A) -> A): Ior<A, B> =
flatMap(combine, ::identity)
public inline fun <A, B> Ior<A, Ior<A, B>>.flatten(combine: (A, A) -> A): Ior<A, B> {
contract { callsInPlace(combine, InvocationKind.AT_MOST_ONCE) }
return flatMap(combine, ::identity)
}

/**
* Given an [Ior] with an error type [A], returns an [IorNel] with the same
Expand Down
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
@file:OptIn(ExperimentalTypeInference::class)
@file:OptIn(ExperimentalTypeInference::class, ExperimentalContracts::class)

package arrow.core

import arrow.core.raise.RaiseAccumulate
import arrow.core.raise.either
import arrow.core.raise.withError
import arrow.core.raise.mapOrAccumulate as raiseMapOrAccumulate
import kotlin.contracts.ExperimentalContracts
import kotlin.contracts.InvocationKind
import kotlin.contracts.contract
import kotlin.collections.unzip as stdlibUnzip
import kotlin.experimental.ExperimentalTypeInference
import kotlin.jvm.JvmInline
Expand Down Expand Up @@ -209,17 +215,23 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
public override operator fun plus(element: @UnsafeVariance A): NonEmptyList<A> =
NonEmptyList(all + element)

public inline fun <B> foldLeft(b: B, f: (B, A) -> B): B =
all.fold(b, f)
@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B> foldLeft(b: B, f: (B, A) -> B): B {
contract { callsInPlace(f, InvocationKind.AT_LEAST_ONCE) }
return all.fold(b, f)
}

public fun <B> coflatMap(f: (NonEmptyList<A>) -> B): NonEmptyList<B> =
buildList {
@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B> coflatMap(f: (NonEmptyList<A>) -> B): NonEmptyList<B> {
contract { callsInPlace(f, InvocationKind.AT_LEAST_ONCE) }
return buildList {
var current = all
while (current.isNotEmpty()) {
add(f(NonEmptyList(current)))
current = current.drop(1)
}
}.let(::NonEmptyList)
}

public fun extract(): A =
this.head
Expand All @@ -233,8 +245,11 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
public fun <B> padZip(other: NonEmptyList<B>): NonEmptyList<Pair<A?, B?>> =
padZip(other, { it to null }, { null to it }, { a, b -> a to b })

public inline fun <B, C> padZip(other: NonEmptyList<B>, left: (A) -> C, right: (B) -> C, both: (A, B) -> C): NonEmptyList<C> =
NonEmptyList(both(head, other.head), tail.padZip(other.tail, left, right, both))
@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C> padZip(other: NonEmptyList<B>, left: (A) -> C, right: (B) -> C, both: (A, B) -> C): NonEmptyList<C> {
contract { callsInPlace(both, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.padZip(other, left, right, both))
}

public companion object {
@PublishedApi
Expand All @@ -245,46 +260,62 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
public fun <B> zip(fb: NonEmptyList<B>): NonEmptyList<Pair<A, B>> =
zip(fb, ::Pair)

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, Z> zip(
b: NonEmptyList<B>,
map: (A, B) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
map: (A, B, C) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
d: NonEmptyList<D>,
map: (A, B, C, D) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
d: NonEmptyList<D>,
e: NonEmptyList<E>,
map: (A, B, C, D, E) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, F, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
d: NonEmptyList<D>,
e: NonEmptyList<E>,
f: NonEmptyList<F>,
map: (A, B, C, D, E, F) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, F, G, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
Expand All @@ -293,9 +324,12 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
f: NonEmptyList<F>,
g: NonEmptyList<G>,
map: (A, B, C, D, E, F, G) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, F, G, H, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
Expand All @@ -305,9 +339,12 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
g: NonEmptyList<G>,
h: NonEmptyList<H>,
map: (A, B, C, D, E, F, G, H) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, F, G, H, I, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
Expand All @@ -318,9 +355,12 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
h: NonEmptyList<H>,
i: NonEmptyList<I>,
map: (A, B, C, D, E, F, G, H, I) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, i.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, i.all, map))
}

@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <B, C, D, E, F, G, H, I, J, Z> zip(
b: NonEmptyList<B>,
c: NonEmptyList<C>,
Expand All @@ -332,8 +372,10 @@ public value class NonEmptyList<out A> @PublishedApi internal constructor(
i: NonEmptyList<I>,
j: NonEmptyList<J>,
map: (A, B, C, D, E, F, G, H, I, J) -> Z
): NonEmptyList<Z> =
NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, i.all, j.all, map))
): NonEmptyList<Z> {
contract { callsInPlace(map, InvocationKind.AT_LEAST_ONCE) }
return NonEmptyList(all.zip(b.all, c.all, d.all, e.all, f.all, g.all, h.all, i.all, j.all, map))
}
}

@JvmName("nonEmptyListOf")
Expand Down Expand Up @@ -368,10 +410,13 @@ public inline fun <T : Comparable<T>> NonEmptyList<T>.max(): T =
public fun <A, B> NonEmptyList<Pair<A, B>>.unzip(): Pair<NonEmptyList<A>, NonEmptyList<B>> =
this.unzip(::identity)

public fun <A, B, C> NonEmptyList<C>.unzip(f: (C) -> Pair<A, B>): Pair<NonEmptyList<A>, NonEmptyList<B>> =
map(f).stdlibUnzip().let { (l1, l2) ->
@Suppress("LEAKED_IN_PLACE_LAMBDA", "WRONG_INVOCATION_KIND")
public inline fun <A, B, C> NonEmptyList<C>.unzip(f: (C) -> Pair<A, B>): Pair<NonEmptyList<A>, NonEmptyList<B>> {
contract { callsInPlace(f, InvocationKind.AT_LEAST_ONCE) }
return map(f).stdlibUnzip().let { (l1, l2) ->
l1.toNonEmptyListOrNull()!! to l2.toNonEmptyListOrNull()!!
}
}

public inline fun <E, A, B> NonEmptyList<A>.mapOrAccumulate(
combine: (E, E) -> E,
Expand Down
Loading

0 comments on commit 8562c9b

Please sign in to comment.