Skip to content

[resubmit] Fix bug of FastReducer used in BigInteger.ModPow #55122

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Diagnostics;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Numerics
{
Expand All @@ -20,25 +22,12 @@ private readonly ref struct FastReducer
private readonly Span<uint> _q1;
private readonly Span<uint> _q2;

public FastReducer(ReadOnlySpan<uint> modulus, Span<uint> r, Span<uint> mu, Span<uint> q1, Span<uint> q2)
public FastReducer(FastReducerConstructorHelper helper)
{
Debug.Assert(!modulus.IsEmpty);
Debug.Assert(r.Length == modulus.Length * 2 + 1);
Debug.Assert(mu.Length == r.Length - modulus.Length + 1);
Debug.Assert(q1.Length == modulus.Length * 2 + 2);
Debug.Assert(q2.Length == modulus.Length * 2 + 2);

// Let r = 4^k, with 2^k > m
r[r.Length - 1] = 1;

// Let mu = 4^k / m
Divide(r, modulus, mu);
_modulus = modulus;

_q1 = q1;
_q2 = q2;

_mu = mu.Slice(0, ActualLength(mu));
_modulus = helper.Modulus;
_mu = helper.Mu;
_q1 = helper.Q1;
_q2 = helper.Q2;
}

public int Reduce(Span<uint> value)
Expand All @@ -49,16 +38,17 @@ public int Reduce(Span<uint> value)
if (value.Length < _modulus.Length)
return value.Length;

// Let q1 = v/2^(k-1) * mu
// Let q1 = v/2^(k-32) * mu
_q1.Clear();
int l1 = DivMul(value, _mu, _q1, _modulus.Length - 1);

// Let q2 = q1/2^(k+1) * m
// Let q2 = q1/2^(k+32) * m
_q2.Clear();
int l2 = DivMul(_q1.Slice(0, l1), _modulus, _q2, _modulus.Length + 1);

// Let v = (v - q2) % 2^(k+1) - i*m
var length = SubMod(value, _q2.Slice(0, l2), _modulus, _modulus.Length + 1);
// Let v = (v - q2) % 2^k
// while m <= v: Let v = v - m
var length = SubMod(value, _q2.Slice(0, l2), _modulus, _modulus.Length);
value = value.Slice(length);
value.Clear();

Expand All @@ -75,6 +65,10 @@ private static int DivMul(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Spa
// but skips the first k limbs of left, which is equivalent to
// preceding division by 2^(32*k). To spare memory allocations
// we write the result to an already allocated memory.
// Note that the k used here is on a different scale from the k used
// in the description of Barrett reduction.
// The former counts elements (32-bit limbs) of the array,
// while the latter counts bits.

if (left.Length > k)
{
Expand All @@ -101,17 +95,23 @@ private static int DivMul(ReadOnlySpan<uint> left, ReadOnlySpan<uint> right, Spa

private static int SubMod(Span<uint> left, ReadOnlySpan<uint> right, ReadOnlySpan<uint> modulus, int k)
{
Debug.Assert(left.Length >= k);

// Executes the subtraction algorithm for left and right,
// but considers only the first k limbs, which is equivalent to
// preceding reduction by 2^(32*k). Furthermore, if left is
// still greater than modulus, further subtractions are used.
// Note that the k used here is on a different scale from the k used
// in the description of Barrett reduction.
// The former counts elements (32-bit limbs) of the array,
// while the latter counts bits.

if (left.Length > k)
left = left.Slice(0, k);
if (right.Length > k)
right = right.Slice(0, k);

SubtractSelf(left, right);
OverflowableSubtractSelf(left, right);
left = left.Slice(0, ActualLength(left));

while (Compare(left, modulus) >= 0)
Expand All @@ -122,6 +122,79 @@ private static int SubMod(Span<uint> left, ReadOnlySpan<uint> right, ReadOnlySpa

return left.Length;
}

// Subtracts right from left in place, limb by limb (least significant
// limb first). A borrow that is still pending after the last limb of
// left is silently dropped, so the stored result is the difference
// modulo 2^(32 * left.Length); left is allowed to be smaller than right.
//
// left:  minuend, overwritten with the result; must have at least as
//        many limbs as right.
// right: subtrahend, read only.
private static void OverflowableSubtractSelf(Span<uint> left, ReadOnlySpan<uint> right)
{
    Debug.Assert(left.Length >= right.Length);

    int i = 0;
    long carry = 0L;

    // Switching to managed references helps eliminating
    // index bounds check...
    ref uint leftPtr = ref MemoryMarshal.GetReference(left);

    // Executes the "grammar-school" algorithm for computing z = a - b.
    // The result is written directly to a. carry is either 0 or -1
    // (an arithmetic shift of a possibly negative 64-bit digit).

    for (; i < right.Length; i++)
    {
        long digit = (Unsafe.Add(ref leftPtr, i) + carry) - right[i];
        Unsafe.Add(ref leftPtr, i) = unchecked((uint)digit);
        carry = digit >> 32;
    }

    // Out of b: keep propagating the pending borrow through the
    // remaining limbs of a, stopping as soon as it is absorbed.
    for (; carry != 0 && i < left.Length; i++)
    {
        long digit = left[i] + carry;

        // digit is -1 when left[i] is 0 and a borrow is pending, so the
        // narrowing cast must be explicitly unchecked (as in the loop
        // above) to keep wrapping even when compiled with overflow checks.
        left[i] = unchecked((uint)digit);
        carry = digit >> 32;
    }
}
}

// Staging structure used to build a FastReducer in two steps.
// The q1 and q2 scratch buffers are sized from the actual length of mu,
// which is only known once mu has been computed by this helper's
// constructor. FastReducer is a read-only structure and cannot receive
// the buffers afterwards, so they are attached here (see AddQs) and the
// completed helper is then passed to the FastReducer constructor.
private ref struct FastReducerConstructorHelper
{
    internal ReadOnlySpan<uint> Modulus;
    internal ReadOnlySpan<uint> Mu;
    internal Span<uint> Q1;
    internal Span<uint> Q2;

    // modulus: the non-empty modulus m.
    // r:       work buffer of 2 * modulus.Length + 1 limbs; its top limb is set to 1 here.
    // mu:      output buffer of r.Length - modulus.Length + 1 limbs for the quotient.
    public FastReducerConstructorHelper(ReadOnlySpan<uint> modulus, Span<uint> r, Span<uint> mu)
    {
        Debug.Assert(!modulus.IsEmpty);
        Debug.Assert(r.Length == modulus.Length * 2 + 1);
        Debug.Assert(mu.Length == r.Length - modulus.Length + 1);

        Modulus = modulus;

        // The scratch buffers are supplied later through AddQs.
        Q1 = default;
        Q2 = default;

        // Let r = 2^(2k), with 2^k > m and k % 32 = 0
        r[^1] = 1;

        // Let mu = r / m, then keep only its significant limbs
        Divide(r, modulus, mu);
        int significantLimbs = ActualLength(mu);
        Mu = mu.Slice(0, significantLimbs);
    }

    // Number of limbs kept in Mu; callers use it to size q1 and q2.
    public int GetMuLength() => Mu.Length;

    // Attaches the scratch buffers:
    // q1 must hold Mu.Length + Modulus.Length + 1 limbs,
    // q2 must hold Mu.Length + Modulus.Length limbs.
    public void AddQs(Span<uint> q1, Span<uint> q2)
    {
        int baseLength = Mu.Length + Modulus.Length;

        Debug.Assert(q1.Length == baseLength + 1);
        Debug.Assert(q2.Length == baseLength);

        Q1 = q1;
        Q2 = q2;
    }
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -323,23 +323,30 @@ stackalloc uint[StackAllocThreshold]
: muFromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
mu.Clear();

size = modulus.Length * 2 + 2;
FastReducerConstructorHelper helper = new FastReducerConstructorHelper(modulus, r, mu);

if (rFromPool != null)
ArrayPool<uint>.Shared.Return(rFromPool);

int muLength = helper.GetMuLength();

size = muLength + modulus.Length + 1;
uint[]? q1FromPool = null;
Span<uint> q1 = ((uint)size <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: q1FromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
q1.Clear();

size = muLength + modulus.Length;
uint[]? q2FromPool = null;
Span<uint> q2 = ((uint)size <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: q2FromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
q2.Clear();

FastReducer reducer = new FastReducer(modulus, r, mu, q1, q2);
helper.AddQs(q1, q2);

if (rFromPool != null)
ArrayPool<uint>.Shared.Return(rFromPool);
FastReducer reducer = new FastReducer(helper);

PowCore(value, valueLength, power, reducer, bits, 1, temp).CopyTo(bits);

Expand Down Expand Up @@ -379,23 +386,30 @@ stackalloc uint[StackAllocThreshold]
: muFromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
mu.Clear();

size = modulus.Length * 2 + 2;
FastReducerConstructorHelper helper = new FastReducerConstructorHelper(modulus, r, mu);

if (rFromPool != null)
ArrayPool<uint>.Shared.Return(rFromPool);

int muLength = helper.GetMuLength();

size = muLength + modulus.Length + 1;
uint[]? q1FromPool = null;
Span<uint> q1 = ((uint)size <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: q1FromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
q1.Clear();

size = muLength + modulus.Length;
uint[]? q2FromPool = null;
Span<uint> q2 = ((uint)size <= StackAllocThreshold ?
stackalloc uint[StackAllocThreshold]
: q2FromPool = ArrayPool<uint>.Shared.Rent(size)).Slice(0, size);
q2.Clear();

FastReducer reducer = new FastReducer(modulus, r, mu, q1, q2);
helper.AddQs(q1, q2);

if (rFromPool != null)
ArrayPool<uint>.Shared.Return(rFromPool);
FastReducer reducer = new FastReducer(helper);

PowCore(value, valueLength, power, reducer, bits, 1, temp).CopyTo(bits);

Expand Down
49 changes: 49 additions & 0 deletions src/libraries/System.Runtime.Numerics/tests/BigInteger/modpow.cs
Original file line number Diff line number Diff line change
Expand Up @@ -286,6 +286,55 @@ public static void ModPowBoundary()
VerifyModPowString(Math.Pow(2, 35) + " " + Math.Pow(2, 33) + " 2 tModPow");
}

[Fact]
[OuterLoop]
public static void ModPowFastReducerBoundary()
{
    // Runs VerifyModPowString over 40-byte operands whose interesting bits
    // sit in bytes 32..39, with "ReducerThreshold" faked to 8
    // (presumably so that the FastReducer path is exercised for these
    // sizes -- confirm against RunWithFakeThreshold).
    BigIntTools.Utils.RunWithFakeThreshold("ReducerThreshold", 8, () =>
    {
        const int ByteCount = 40;

        // Operands that change on every iteration.
        byte[] singleBit = new byte[ByteCount];      // only the current bit of the current byte
        byte[] onesPlusBits = new byte[ByteCount];   // bytes 0..31 = 0xff, bits above accumulate
        byte[] oneAndBit = new byte[ByteCount];      // byte 0 = 1 plus the current bit

        // Operands that stay fixed.
        byte[] ones36 = new byte[ByteCount];         // bytes 0..35 = 0xff
        byte[] byte36Set = new byte[ByteCount];      // byte 36 = 1
        byte[] byte0And36Set = new byte[ByteCount];  // byte 0 = 1 and byte 36 = 1

        Array.Fill(onesPlusBits, (byte)0xff, 0, 32);
        oneAndBit[0] = 1;
        Array.Fill(ones36, (byte)0xff, 0, 36);
        byte36Set[36] = 1;
        byte0And36Set[0] = 1;
        byte0And36Set[36] = 1;

        byte[][] fixedOperands = { ones36, byte36Set, byte0And36Set };
        byte[][] varyingOperands = { singleBit, onesPlusBits, oneAndBit };

        for (int pos = 32; pos < ByteCount; pos++)
        {
            for (int bit = 0; bit < 8; bit++)
            {
                byte mask = (byte)(1 << bit);
                singleBit[pos] = mask;
                onesPlusBits[pos] |= mask;
                oneAndBit[pos] = mask;

                // Every fixed operand is paired with every varying operand,
                // with the fixed operand cycling fastest.
                foreach (byte[] varying in varyingOperands)
                {
                    foreach (byte[] constant in fixedOperands)
                    {
                        VerifyModPowString(Print(constant) + "2 " + Print(varying) + "tModPow");
                    }
                }
            }

            // Clear the moving bit before advancing to the next byte;
            // onesPlusBits deliberately keeps its accumulated bits.
            singleBit[pos] = 0;
            oneAndBit[pos] = 0;
        }
    });
}

private static void VerifyModPowString(string opstring)
{
StackCalc sc = new StackCalc(opstring);
Expand Down