Skip to content

Commit

Permalink
Add sqrt for math (#3242)
Browse files Browse the repository at this point in the history
  • Loading branch information
jjz authored Jun 7, 2022
1 parent 3aa7ff7 commit 3ac4add
Show file tree
Hide file tree
Showing 4 changed files with 113 additions and 0 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
* `ERC20FlashMint`: Add customizable flash fee receiver. ([#3327](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3327))
* `ERC20TokenizedVault`: add an extension of `ERC20` that implements the ERC4626 Tokenized Vault Standard. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171))
* `Math`: add a `mulDiv` function that can round the result either up or down. ([#3171](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3171))
* `Math`: Add a `sqrt` function to compute square roots of integers, rounding either up or down. ([#3242](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3242))
* `Strings`: add a new overloaded function `toHexString` that converts an `address` with fixed length of 20 bytes to its not checksummed ASCII `string` hexadecimal representation. ([#3403](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3403))
* `EnumerableMap`: add new `UintToUintMap` map type. ([#3338](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3338))
* `EnumerableMap`: add new `Bytes32ToUintMap` map type. ([#3416](https://github.com/OpenZeppelin/openzeppelin-contracts/pull/3416))
Expand Down
4 changes: 4 additions & 0 deletions contracts/mocks/MathMock.sol
Original file line number Diff line number Diff line change
Expand Up @@ -29,4 +29,8 @@ contract MathMock {
) public pure returns (uint256) {
return Math.mulDiv(a, b, denominator, direction);
}

function sqrt(uint256 a, Math.Rounding direction) public pure returns (uint256) {
return Math.sqrt(a, direction);
}
}
74 changes: 74 additions & 0 deletions contracts/utils/math/Math.sol
Original file line number Diff line number Diff line change
Expand Up @@ -149,4 +149,78 @@ library Math {
}
return result;
}

/**
* @dev Returns the square root of a number. It the number is not a perfect square, the value is rounded down.
*
* Inspired by Henry S. Warren, Jr.'s "Hacker's Delight" (Chapter 11).
*/
function sqrt(uint256 a) internal pure returns (uint256) {
if (a == 0) {
return 0;
}

// For our first guess, we get the biggest power of 2 which is smaller than the square root of the target.
// We know that the "msb" (most significant bit) of our target number `a` is a power of 2 such that we have
// `msb(a) <= a < 2*msb(a)`.
// We also know that `k`, the position of the most significant bit, is such that `msb(a) = 2**k`.
// This gives `2**k < a <= 2**(k+1)` → `2**(k/2) <= sqrt(a) < 2 ** (k/2+1)`.
// Using an algorithm similar to the msb conmputation, we are able to compute `result = 2**(k/2)` which is a
// good first aproximation of `sqrt(a)` with at least 1 correct bit.
uint256 result = 1;
uint256 x = a;
if (x >> 128 > 0) {
x >>= 128;
result <<= 64;
}
if (x >> 64 > 0) {
x >>= 64;
result <<= 32;
}
if (x >> 32 > 0) {
x >>= 32;
result <<= 16;
}
if (x >> 16 > 0) {
x >>= 16;
result <<= 8;
}
if (x >> 8 > 0) {
x >>= 8;
result <<= 4;
}
if (x >> 4 > 0) {
x >>= 4;
result <<= 2;
}
if (x >> 2 > 0) {
result <<= 1;
}

// At this point `result` is an estimation with one bit of precision. We know the true value is a uint128,
// since it is the square root of a uint256. Newton's method converges quadratically (precision doubles at
// every iteration). We thus need at most 7 iteration to turn our partial result with one bit of precision
// into the expected uint128 result.
unchecked {
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
result = (result + a / result) >> 1;
return min(result, a / result);
}
}

/**
* @notice Calculates sqrt(a), following the selected rounding direction.
*/
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
uint256 result = sqrt(a);
if (rounding == Rounding.Up && result * result < a) {
result += 1;
}
return result;
}
}
34 changes: 34 additions & 0 deletions test/utils/math/Math.test.js
Original file line number Diff line number Diff line change
Expand Up @@ -182,4 +182,38 @@ contract('Math', function (accounts) {
});
});
});

describe('sqrt', function () {
it('rounds down', async function () {
expect(await this.math.sqrt(new BN('0'), Rounding.Down)).to.be.bignumber.equal('0');
expect(await this.math.sqrt(new BN('1'), Rounding.Down)).to.be.bignumber.equal('1');
expect(await this.math.sqrt(new BN('2'), Rounding.Down)).to.be.bignumber.equal('1');
expect(await this.math.sqrt(new BN('3'), Rounding.Down)).to.be.bignumber.equal('1');
expect(await this.math.sqrt(new BN('4'), Rounding.Down)).to.be.bignumber.equal('2');
expect(await this.math.sqrt(new BN('144'), Rounding.Down)).to.be.bignumber.equal('12');
expect(await this.math.sqrt(new BN('999999'), Rounding.Down)).to.be.bignumber.equal('999');
expect(await this.math.sqrt(new BN('1000000'), Rounding.Down)).to.be.bignumber.equal('1000');
expect(await this.math.sqrt(new BN('1000001'), Rounding.Down)).to.be.bignumber.equal('1000');
expect(await this.math.sqrt(new BN('1002000'), Rounding.Down)).to.be.bignumber.equal('1000');
expect(await this.math.sqrt(new BN('1002001'), Rounding.Down)).to.be.bignumber.equal('1001');
expect(await this.math.sqrt(MAX_UINT256, Rounding.Down))
.to.be.bignumber.equal('340282366920938463463374607431768211455');
});

it('rounds up', async function () {
expect(await this.math.sqrt(new BN('0'), Rounding.Up)).to.be.bignumber.equal('0');
expect(await this.math.sqrt(new BN('1'), Rounding.Up)).to.be.bignumber.equal('1');
expect(await this.math.sqrt(new BN('2'), Rounding.Up)).to.be.bignumber.equal('2');
expect(await this.math.sqrt(new BN('3'), Rounding.Up)).to.be.bignumber.equal('2');
expect(await this.math.sqrt(new BN('4'), Rounding.Up)).to.be.bignumber.equal('2');
expect(await this.math.sqrt(new BN('144'), Rounding.Up)).to.be.bignumber.equal('12');
expect(await this.math.sqrt(new BN('999999'), Rounding.Up)).to.be.bignumber.equal('1000');
expect(await this.math.sqrt(new BN('1000000'), Rounding.Up)).to.be.bignumber.equal('1000');
expect(await this.math.sqrt(new BN('1000001'), Rounding.Up)).to.be.bignumber.equal('1001');
expect(await this.math.sqrt(new BN('1002000'), Rounding.Up)).to.be.bignumber.equal('1001');
expect(await this.math.sqrt(new BN('1002001'), Rounding.Up)).to.be.bignumber.equal('1001');
expect(await this.math.sqrt(MAX_UINT256, Rounding.Up))
.to.be.bignumber.equal('340282366920938463463374607431768211456');
});
});
});

0 comments on commit 3ac4add

Please sign in to comment.