Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Branchless ternary, min and max methods #4976

Merged
merged 20 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .changeset/spotty-falcons-explain.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'openzeppelin-solidity': patch
---

Improved Math.sol sqrt, log2, min and max methods
217 changes: 84 additions & 133 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -76,15 +76,23 @@ library Math {
/**
* @dev Returns the largest of two numbers.
*/
function max(uint256 a, uint256 b) internal pure returns (uint256) {
return a > b ? a : b;
function max(uint256 a, uint256 b) internal pure returns (uint256 result) {
assembly ("memory-safe") {
// Gas-efficient branchless max:
//   max(a, b) = a ^ ((a ^ b) * (a < b))
// `lt(a, b)` evaluates to 1 when a < b and to 0 otherwise, so the product is either
// `a ^ b` (giving a ^ (a ^ b) = b) or 0 (giving a ^ 0 = a).
result := xor(a, mul(xor(a, b), lt(a, b)))
}
}

/**
* @dev Returns the smallest of two numbers.
*/
function min(uint256 a, uint256 b) internal pure returns (uint256) {
return a < b ? a : b;
function min(uint256 a, uint256 b) internal pure returns (uint256 result) {
    // The block only reads and writes stack variables, so it is safe to mark it
    // memory-safe (consistent with `max`); this keeps the optimizer free to move
    // variables between stack and memory around the assembly block.
    assembly ("memory-safe") {
        // Gas-efficient branchless min:
        //   min(a, b) = b ^ ((a ^ b) * (a < b))
        // `lt(a, b)` evaluates to 1 when a < b and to 0 otherwise, so the product is either
        // `a ^ b` (giving b ^ (a ^ b) = a) or 0 (giving b ^ 0 = b).
        result := xor(b, mul(xor(a, b), lt(a, b)))
    }
}

/**
Expand Down Expand Up @@ -388,109 +396,56 @@ library Math {
* This method is based on Newton's method for computing square roots; the algorithm is restricted to only
* using integer operations.
*/
function sqrt(uint256 a) internal pure returns (uint256) {
unchecked {
// Take care of easy edge cases when a == 0 or a == 1
if (a <= 1) {
return a;
}

// In this function, we use Newton's method to get a root of `f(x) := x² - a`. It involves building a
// sequence x_n that converges toward sqrt(a). For each iteration x_n, we also define the error between
// the current value as `ε_n = | x_n - sqrt(a) |`.
//
// For our first estimation, we consider `e` the smallest power of 2 which is bigger than the square root
// of the target. (i.e. `2**(e-1) ≤ sqrt(a) < 2**e`). We know that `e ≤ 128` because `(2¹²⁸)² = 2²⁵⁶` is
// bigger than any uint256.
//
// By noticing that
// `2**(e-1) ≤ sqrt(a) < 2**e → (2**(e-1))² ≤ a < (2**e)² → 2**(2*e-2) ≤ a < 2**(2*e)`
// we can deduce that `e - 1` is `log2(a) / 2`. We can thus compute `x_n = 2**(e-1)` using a method similar
// to the msb function.
uint256 aa = a;
uint256 xn = 1;

if (aa >= (1 << 128)) {
aa >>= 128;
xn <<= 64;
}
if (aa >= (1 << 64)) {
aa >>= 64;
xn <<= 32;
}
if (aa >= (1 << 32)) {
aa >>= 32;
xn <<= 16;
}
if (aa >= (1 << 16)) {
aa >>= 16;
xn <<= 8;
}
if (aa >= (1 << 8)) {
aa >>= 8;
xn <<= 4;
}
if (aa >= (1 << 4)) {
aa >>= 4;
xn <<= 2;
}
if (aa >= (1 << 2)) {
xn <<= 1;
}

// We now have x_n such that `x_n = 2**(e-1) ≤ sqrt(a) < 2**e = 2 * x_n`. This implies ε_n ≤ 2**(e-1).
//
// We can refine our estimation by noticing that the middle of that interval minimizes the error.
// If we move x_n to equal 2**(e-1) + 2**(e-2), then we reduce the error to ε_n ≤ 2**(e-2).
// This is going to be our x_0 (and ε_0)
xn = (3 * xn) >> 1; // ε_0 := | x_0 - sqrt(a) | ≤ 2**(e-2)

// From here, Newton's method gives us:
// x_{n+1} = (x_n + a / x_n) / 2
//
// One should note that:
// x_{n+1}² - a = ((x_n + a / x_n) / 2)² - a
// = ((x_n² + a) / (2 * x_n))² - a
// = (x_n⁴ + 2 * a * x_n² + a²) / (4 * x_n²) - a
// = (x_n⁴ + 2 * a * x_n² + a² - 4 * a * x_n²) / (4 * x_n²)
// = (x_n⁴ - 2 * a * x_n² + a²) / (4 * x_n²)
// = (x_n² - a)² / (2 * x_n)²
// = ((x_n² - a) / (2 * x_n))²
// ≥ 0
// Which proves that for all n ≥ 1, sqrt(a) ≤ x_n
//
// This gives us the proof of quadratic convergence of the sequence:
// ε_{n+1} = | x_{n+1} - sqrt(a) |
// = | (x_n + a / x_n) / 2 - sqrt(a) |
// = | (x_n² + a - 2*x_n*sqrt(a)) / (2 * x_n) |
// = | (x_n - sqrt(a))² / (2 * x_n) |
// = | ε_n² / (2 * x_n) |
// = ε_n² / | (2 * x_n) |
//
// For the first iteration, we have a special case where x_0 is known:
// ε_1 = ε_0² / | (2 * x_0) |
// ≤ (2**(e-2))² / (2 * (2**(e-1) + 2**(e-2)))
// ≤ 2**(2*e-4) / (3 * 2**(e-1))
// ≤ 2**(e-3) / 3
// ≤ 2**(e-3-log2(3))
// ≤ 2**(e-4.5)
function sqrt(uint256 a) internal pure returns (uint256 xn) {
assembly {
// First we approximate the square root by calculating xn = 2 ** (log2(a) / 2),
// so that fewer iterations of Newton's method are needed to find the result.
//
// For the following iterations, we use the fact that, 2**(e-1) ≤ sqrt(a) ≤ x_n:
// ε_{n+1} = ε_n² / | (2 * x_n) |
// ≤ (2**(e-k))² / (2 * 2**(e-1))
// ≤ 2**(2*e-2*k) / 2**e
// ≤ 2**(e-2*k)
xn = (xn + a / xn) >> 1; // ε_1 := | x_1 - sqrt(a) | ≤ 2**(e-4.5) -- special case, see above
xn = (xn + a / xn) >> 1; // ε_2 := | x_2 - sqrt(a) | ≤ 2**(e-9) -- general case with k = 4.5
xn = (xn + a / xn) >> 1; // ε_3 := | x_3 - sqrt(a) | ≤ 2**(e-18) -- general case with k = 9
xn = (xn + a / xn) >> 1; // ε_4 := | x_4 - sqrt(a) | ≤ 2**(e-36) -- general case with k = 18
xn = (xn + a / xn) >> 1; // ε_5 := | x_5 - sqrt(a) | ≤ 2**(e-72) -- general case with k = 36
xn = (xn + a / xn) >> 1; // ε_6 := | x_6 - sqrt(a) | ≤ 2**(e-144) -- general case with k = 72

// Because e ≤ 128 (as discussed during the first estimation phase), we now have reached a precision
// ε_6 ≤ 2**(e-144) < 1. Given we're operating on integers, we can ensure that xn is now either
// sqrt(a) or sqrt(a) + 1.
return xn - SafeCast.toUint(xn > a / xn);
// For that we use an optimized log2 computation that skips the final refinement step,
// since halving the result (`log2(a) / 2`) discards the least significant bit anyway.
xn := shl(7, gt(a, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
let remainder := shr(xn, a)

let shift := shl(6, gt(remainder, 0xFFFFFFFFFFFFFFFF))
remainder := shr(shift, remainder)
xn := or(xn, shift)

shift := shl(5, gt(remainder, 0xFFFFFFFF))
remainder := shr(shift, remainder)
xn := or(xn, shift)

shift := shl(4, gt(remainder, 0xFFFF))
remainder := shr(shift, remainder)
xn := or(xn, shift)

shift := shl(3, gt(remainder, 0xFF))
remainder := shr(shift, remainder)
xn := or(xn, shift)

shift := shl(2, gt(remainder, 0x0F))
remainder := shr(shift, remainder)
xn := or(xn, shift)

shift := shl(1, gt(remainder, 0x03))
xn := or(xn, shift)

// now xn = log2(a) with its least significant bit cleared, so we compute: xn = 2 ** (xn / 2)
// slither-disable-next-line incorrect-shift
xn := shl(shr(1, xn), 1)

// Newton's method
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))
xn := shr(1, add(xn, div(a, xn)))

// Once we round towards zero, we want the minimum between xn and a/xn
// we can safely assume that |xn - a/xn| is either 0 or 1, so we can easily compute the minimum as
// xn = xn - (xn > a/xn)
xn := sub(xn, gt(xn, div(a, xn)))
}
}

Expand All @@ -508,41 +463,37 @@ library Math {
* @dev Return the log in base 2 of a positive value rounded towards zero.
* Returns 0 if given 0.
*/
function log2(uint256 value) internal pure returns (uint256) {
uint256 result = 0;
uint256 exp;
unchecked {
exp = 128 * SafeCast.toUint(value > (1 << 128) - 1);
value >>= exp;
result += exp;
function log2(uint256 value) internal pure returns (uint256 result) {
assembly {
result := shl(7, gt(value, 0xFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFFF))
value := shr(result, value)

exp = 64 * SafeCast.toUint(value > (1 << 64) - 1);
value >>= exp;
result += exp;
let shift := shl(6, gt(value, 0xFFFFFFFFFFFFFFFF))
value := shr(shift, value)
result := or(result, shift)

exp = 32 * SafeCast.toUint(value > (1 << 32) - 1);
value >>= exp;
result += exp;
shift := shl(5, gt(value, 0xFFFFFFFF))
value := shr(shift, value)
result := or(result, shift)

exp = 16 * SafeCast.toUint(value > (1 << 16) - 1);
value >>= exp;
result += exp;
shift := shl(4, gt(value, 0xFFFF))
value := shr(shift, value)
result := or(result, shift)

exp = 8 * SafeCast.toUint(value > (1 << 8) - 1);
value >>= exp;
result += exp;
shift := shl(3, gt(value, 0xFF))
value := shr(shift, value)
result := or(result, shift)

exp = 4 * SafeCast.toUint(value > (1 << 4) - 1);
value >>= exp;
result += exp;
shift := shl(2, gt(value, 0x0F))
value := shr(shift, value)
result := or(result, shift)

exp = 2 * SafeCast.toUint(value > (1 << 2) - 1);
value >>= exp;
result += exp;
shift := shl(1, gt(value, 0x03))
value := shr(shift, value)
result := or(result, shift)

result += SafeCast.toUint(value > 1);
result := or(result, gt(value, 1))
}
return result;
}

/**
Expand Down
Loading