Skip to content

Commit e86bb45

Browse files
Amxxernestognw
andauthored
Add a Math.inv function that inverse a number in Z/nZ (#4839)
Co-authored-by: ernestognw <ernestognw@gmail.com>
1 parent e5f02bc commit e86bb45

File tree

4 files changed

+139
-4
lines changed

4 files changed

+139
-4
lines changed

.changeset/cool-mangos-compare.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
---
2+
'openzeppelin-solidity': minor
3+
---
4+
5+
`Math`: add an `invMod` function to get the modular multiplicative inverse of a number in Z/nZ.

contracts/utils/math/Math.sol

Lines changed: 61 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,10 @@ library Math {
121121
}
122122

123123
/**
124-
* @notice Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
124+
* @dev Calculates floor(x * y / denominator) with full precision. Throws if result overflows a uint256 or
125125
* denominator == 0.
126-
* @dev Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
126+
*
127+
* Original credit to Remco Bloemen under MIT license (https://xn--2-umb.com/21/muldiv) with further edits by
127128
* Uniswap Labs also under MIT license.
128129
*/
129130
function mulDiv(uint256 x, uint256 y, uint256 denominator) internal pure returns (uint256 result) {
@@ -208,7 +209,7 @@ library Math {
208209
}
209210

210211
/**
211-
* @notice Calculates x * y / denominator with full precision, following the selected rounding direction.
212+
* @dev Calculates x * y / denominator with full precision, following the selected rounding direction.
212213
*/
213214
function mulDiv(uint256 x, uint256 y, uint256 denominator, Rounding rounding) internal pure returns (uint256) {
214215
uint256 result = mulDiv(x, y, denominator);
@@ -218,6 +219,62 @@ library Math {
218219
return result;
219220
}
220221

222+
/**
223+
* @dev Calculate the modular multiplicative inverse of a number in Z/nZ.
224+
*
225+
* If n is a prime, then Z/nZ is a field. In that case all elements are inversible, expect 0.
226+
* If n is not a prime, then Z/nZ is not a field, and some elements might not be inversible.
227+
*
228+
* If the input value is not inversible, 0 is returned.
229+
*/
230+
function invMod(uint256 a, uint256 n) internal pure returns (uint256) {
231+
unchecked {
232+
if (n == 0) return 0;
233+
234+
// The inverse modulo is calculated using the Extended Euclidean Algorithm (iterative version)
235+
// Used to compute integers x and y such that: ax + ny = gcd(a, n).
236+
// When the gcd is 1, then the inverse of a modulo n exists and it's x.
237+
// ax + ny = 1
238+
// ax = 1 + (-y)n
239+
// ax ≡ 1 (mod n) # x is the inverse of a modulo n
240+
241+
// If the remainder is 0 the gcd is n right away.
242+
uint256 remainder = a % n;
243+
uint256 gcd = n;
244+
245+
// Therefore the initial coefficients are:
246+
// ax + ny = gcd(a, n) = n
247+
// 0a + 1n = n
248+
int256 x = 0;
249+
int256 y = 1;
250+
251+
while (remainder != 0) {
252+
uint256 quotient = gcd / remainder;
253+
254+
(gcd, remainder) = (
255+
// The old remainder is the next gcd to try.
256+
remainder,
257+
// Compute the next remainder.
258+
// Can't overflow given that (a % gcd) * (gcd // (a % gcd)) <= gcd
259+
// where gcd is at most n (capped to type(uint256).max)
260+
gcd - remainder * quotient
261+
);
262+
263+
(x, y) = (
264+
// Increment the coefficient of a.
265+
y,
266+
// Decrement the coefficient of n.
267+
// Can overflow, but the result is casted to uint256 so that the
268+
// next value of y is "wrapped around" to a value between 0 and n - 1.
269+
x - y * int256(quotient)
270+
);
271+
}
272+
273+
if (gcd != 1) return 0; // No inverse exists.
274+
return x < 0 ? (n - uint256(-x)) : uint256(x); // Wrap the result if it's negative.
275+
}
276+
}
277+
221278
/**
222279
* @dev Returns the square root of a number. If the number is not a perfect square, the value is rounded
223280
* towards zero.
@@ -258,7 +315,7 @@ library Math {
258315
}
259316

260317
/**
261-
* @notice Calculates sqrt(a), following the selected rounding direction.
318+
* @dev Calculates sqrt(a), following the selected rounding direction.
262319
*/
263320
function sqrt(uint256 a, Rounding rounding) internal pure returns (uint256) {
264321
unchecked {

test/utils/math/Math.t.sol

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,41 @@ contract MathTest is Test {
5555
return value * value < ref;
5656
}
5757

58+
// INV
59+
function testInvMod(uint256 value, uint256 p) public {
60+
_testInvMod(value, p, true);
61+
}
62+
63+
function testInvMod2(uint256 seed) public {
64+
uint256 p = 2; // prime
65+
_testInvMod(bound(seed, 1, p - 1), p, false);
66+
}
67+
68+
function testInvMod17(uint256 seed) public {
69+
uint256 p = 17; // prime
70+
_testInvMod(bound(seed, 1, p - 1), p, false);
71+
}
72+
73+
function testInvMod65537(uint256 seed) public {
74+
uint256 p = 65537; // prime
75+
_testInvMod(bound(seed, 1, p - 1), p, false);
76+
}
77+
78+
function testInvModP256(uint256 seed) public {
79+
uint256 p = 0xffffffff00000001000000000000000000000000ffffffffffffffffffffffff; // prime
80+
_testInvMod(bound(seed, 1, p - 1), p, false);
81+
}
82+
83+
function _testInvMod(uint256 value, uint256 p, bool allowZero) private {
84+
uint256 inverse = Math.invMod(value, p);
85+
if (inverse != 0) {
86+
assertEq(mulmod(value, inverse, p), 1);
87+
assertLt(inverse, p);
88+
} else {
89+
assertTrue(allowZero);
90+
}
91+
}
92+
5893
// LOG2
5994
function testLog2(uint256 input, uint8 r) public {
6095
Math.Rounding rounding = _asRounding(r);

test/utils/math/Math.test.js

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ const { PANIC_CODES } = require('@nomicfoundation/hardhat-chai-matchers/panic');
55

66
const { Rounding } = require('../../helpers/enums');
77
const { min, max } = require('../../helpers/math');
8+
const { randomArray, generators } = require('../../helpers/random');
89

910
const RoundingDown = [Rounding.Floor, Rounding.Trunc];
1011
const RoundingUp = [Rounding.Ceil, Rounding.Expand];
@@ -298,6 +299,43 @@ describe('Math', function () {
298299
});
299300
});
300301

302+
describe('invMod', function () {
303+
for (const factors of [
304+
[0n],
305+
[1n],
306+
[2n],
307+
[17n],
308+
[65537n],
309+
[0xffffffff00000001000000000000000000000000ffffffffffffffffffffffffn],
310+
[3n, 5n],
311+
[3n, 7n],
312+
[47n, 53n],
313+
]) {
314+
const p = factors.reduce((acc, f) => acc * f, 1n);
315+
316+
describe(`using p=${p} which is ${p > 1 && factors.length > 1 ? 'not ' : ''}a prime`, function () {
317+
it('trying to inverse 0 returns 0', async function () {
318+
expect(await this.mock.$invMod(0, p)).to.equal(0n);
319+
expect(await this.mock.$invMod(p, p)).to.equal(0n); // p is 0 mod p
320+
});
321+
322+
if (p != 0) {
323+
for (const value of randomArray(generators.uint256, 16)) {
324+
const isInversible = factors.every(f => value % f);
325+
it(`trying to inverse ${value}`, async function () {
326+
const result = await this.mock.$invMod(value, p);
327+
if (isInversible) {
328+
expect((value * result) % p).to.equal(1n);
329+
} else {
330+
expect(result).to.equal(0n);
331+
}
332+
});
333+
}
334+
}
335+
});
336+
}
337+
});
338+
301339
describe('sqrt', function () {
302340
it('rounds down', async function () {
303341
for (const rounding of RoundingDown) {

0 commit comments

Comments
 (0)