From f2d202b17068d6a6228a4f75b80c27b04b42823a Mon Sep 17 00:00:00 2001 From: Jose Rojas Chaves Date: Mon, 11 Oct 2021 13:51:05 -0600 Subject: [PATCH] Joserochh/ntt avoid memcpy (#72) Avoiding memcpy calls on NTT * Avoiding memcpys on Fwd NTT * Avoiding memcpy on INV NTT * Fixing some line lengths * using only one out-of-place on first passes * Adding out-of-place for radix 4 NTT * Adding gpg issue * Adding test cases for out-of-place NTT * Removing commented code and testing GPG Signing * Fboemer/fix 32 bit invntt (#73) * Fix 32-bit AVX512DQ InvNTT * Refactor NTT tests for better coverage * Added performance tips to README (#74) * small fix on test case (missed during merge) Co-authored-by: Fabian Boemer --- CONTRIBUTING.md | 17 ++ benchmark/bench-ntt.cpp | 167 +++++++++++++-- hexl/ntt/fwd-ntt-avx512.cpp | 77 ++++--- hexl/ntt/fwd-ntt-avx512.hpp | 2 +- hexl/ntt/inv-ntt-avx512.cpp | 50 +++-- hexl/ntt/inv-ntt-avx512.hpp | 2 +- hexl/ntt/ntt-default.hpp | 155 +++++++------- hexl/ntt/ntt-internal.cpp | 26 +-- hexl/ntt/ntt-internal.hpp | 20 +- hexl/ntt/ntt-radix-2.cpp | 317 +++++++++++++++++++++-------- hexl/ntt/ntt-radix-4.cpp | 394 +++++++++++++++++++++++++++--------- test/test-ntt-avx512.cpp | 39 ++-- test/test-ntt.cpp | 75 ++++++- 13 files changed, 969 insertions(+), 372 deletions(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 66ca4956..6d851ca9 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -9,3 +9,20 @@ cmake --build build --target check unittest to make sure the formatting checks and all unit tests pass. Please sign your commits before making a pull request. See instructions [here](https://docs.github.com/en/github/authenticating-to-github/managing-commit-signature-verification/signing-commits) for how to sign commits. + +### Known Issues ### + +* ```Executable `cpplint` not found``` + + Make sure you install cpplint: ```pip install cpplint```. + If you install `cpplint` locally, make sure to add it to your PATH. 
+ +* ```/bin/sh: 1: pre-commit: not found``` + + Install `pre-commit`. More info at https://pre-commit.com/ + +* ``` + error: gpg failed to sign the data + fatal: failed to write commit object + ``` + Try adding ```export GPG_TTY=$(tty)``` to `~/.bashrc` diff --git a/benchmark/bench-ntt.cpp b/benchmark/bench-ntt.cpp index b10b902a..3ad08f07 100644 --- a/benchmark/bench-ntt.cpp +++ b/benchmark/bench-ntt.cpp @@ -21,7 +21,7 @@ namespace hexl { //================================================================= -static void BM_FwdNTTNativeRadix2(benchmark::State& state) { // NOLINT +static void BM_FwdNTTNativeRadix2InPlace(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; @@ -30,19 +30,43 @@ static void BM_FwdNTTNativeRadix2(benchmark::State& state) { // NOLINT for (auto _ : state) { ForwardTransformToBitReverseRadix2( - input.data(), ntt_size, modulus, ntt.GetRootOfUnityPowers().data(), + input.data(), input.data(), ntt_size, modulus, + ntt.GetRootOfUnityPowers().data(), ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); } } -BENCHMARK(BM_FwdNTTNativeRadix2) +BENCHMARK(BM_FwdNTTNativeRadix2InPlace) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) ->Args({16384}); //================================================================= -static void BM_FwdNTTNativeRadix4(benchmark::State& state) { // NOLINT +static void BM_FwdNTTNativeRadix2Copy(benchmark::State& state) { // NOLINT + size_t ntt_size = state.range(0); + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + AlignedVector64 output(ntt_size, 1); + NTT ntt(ntt_size, modulus); + + for (auto _ : state) { + ForwardTransformToBitReverseRadix2( + output.data(), input.data(), ntt_size, modulus, + ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); + } +} + +BENCHMARK(BM_FwdNTTNativeRadix2Copy) + 
->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); +//================================================================= + +static void BM_FwdNTTNativeRadix4InPlace(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; @@ -51,18 +75,43 @@ static void BM_FwdNTTNativeRadix4(benchmark::State& state) { // NOLINT for (auto _ : state) { ForwardTransformToBitReverseRadix4( - input.data(), ntt_size, modulus, ntt.GetRootOfUnityPowers().data(), + input.data(), input.data(), ntt_size, modulus, + ntt.GetRootOfUnityPowers().data(), ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); } } -BENCHMARK(BM_FwdNTTNativeRadix4) +BENCHMARK(BM_FwdNTTNativeRadix4InPlace) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) ->Args({16384}); //================================================================= +static void BM_FwdNTTNativeRadix4Copy(benchmark::State& state) { // NOLINT + size_t ntt_size = state.range(0); + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + AlignedVector64 output(ntt_size, 1); + NTT ntt(ntt_size, modulus); + + for (auto _ : state) { + ForwardTransformToBitReverseRadix4( + output.data(), input.data(), ntt_size, modulus, + ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); + } +} + +BENCHMARK(BM_FwdNTTNativeRadix4Copy) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); + +//================================================================= + #ifdef HEXL_HAS_AVX512IFMA // state[0] is the degree static void BM_FwdNTT_AVX512IFMA(benchmark::State& state) { // NOLINT @@ -80,7 +129,7 @@ static void BM_FwdNTT_AVX512IFMA(benchmark::State& state) { // NOLINT for (auto _ : state) { ForwardTransformToBitReverseAVX512( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), 
ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 2, 1); } } @@ -109,7 +158,7 @@ static void BM_FwdNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT for (auto _ : state) { ForwardTransformToBitReverseAVX512( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 4, 4); } } @@ -144,7 +193,7 @@ static void BM_FwdNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT ntt.GetAVX512Precon32RootOfUnityPowers(); for (auto _ : state) { ForwardTransformToBitReverseAVX512<32>( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 4, output_mod_factor); } } @@ -175,7 +224,7 @@ static void BM_FwdNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT ntt.GetAVX512Precon64RootOfUnityPowers(); for (auto _ : state) { ForwardTransformToBitReverseAVX512<64>( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 4, output_mod_factor); } } @@ -234,6 +283,28 @@ BENCHMARK(BM_FwdNTTCopy) ->Args({4096}) ->Args({16384}); +//================================================================= + +static void BM_InvNTTInPlace(benchmark::State& state) { // NOLINT + size_t ntt_size = state.range(0); + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + NTT ntt(ntt_size, modulus); + + for (auto _ : state) { + ntt.ComputeInverse(input.data(), input.data(), 2, 1); + } +} + +BENCHMARK(BM_InvNTTInPlace) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); + +//================================================================= + // state[0] is the degree static void BM_InvNTTCopy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); @@ -258,7 
+329,7 @@ BENCHMARK(BM_InvNTTCopy) // Inverse transforms -static void BM_InvNTTNativeRadix2(benchmark::State& state) { // NOLINT +static void BM_InvNTTNativeRadix2InPlace(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; @@ -269,13 +340,13 @@ static void BM_InvNTTNativeRadix2(benchmark::State& state) { // NOLINT const AlignedVector64 precon_root_of_unity = ntt.GetPrecon64InvRootOfUnityPowers(); for (auto _ : state) { - InverseTransformFromBitReverseRadix2(input.data(), ntt_size, modulus, - root_of_unity.data(), + InverseTransformFromBitReverseRadix2(input.data(), input.data(), ntt_size, + modulus, root_of_unity.data(), precon_root_of_unity.data(), 1, 1); } } -BENCHMARK(BM_InvNTTNativeRadix2) +BENCHMARK(BM_InvNTTNativeRadix2InPlace) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) @@ -283,24 +354,76 @@ BENCHMARK(BM_InvNTTNativeRadix2) //================================================================= -static void BM_InvNTTNativeRadix4(benchmark::State& state) { // NOLINT +static void BM_InvNTTNativeRadix2Copy(benchmark::State& state) { // NOLINT size_t ntt_size = state.range(0); size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + AlignedVector64 output(ntt_size, 1); + NTT ntt(ntt_size, modulus); + + const AlignedVector64 root_of_unity = ntt.GetInvRootOfUnityPowers(); + const AlignedVector64 precon_root_of_unity = + ntt.GetPrecon64InvRootOfUnityPowers(); + for (auto _ : state) { + InverseTransformFromBitReverseRadix2(output.data(), input.data(), ntt_size, + modulus, root_of_unity.data(), + precon_root_of_unity.data(), 1, 1); + } +} + +BENCHMARK(BM_InvNTTNativeRadix2Copy) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); + +//================================================================= + +static void BM_InvNTTNativeRadix4InPlace(benchmark::State& state) { 
// NOLINT + size_t ntt_size = state.range(0); + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + NTT ntt(ntt_size, modulus); + + const AlignedVector64 root_of_unity = ntt.GetInvRootOfUnityPowers(); + const AlignedVector64 precon_root_of_unity = + ntt.GetPrecon64InvRootOfUnityPowers(); + for (auto _ : state) { + InverseTransformFromBitReverseRadix4(input.data(), input.data(), ntt_size, + modulus, root_of_unity.data(), + precon_root_of_unity.data(), 1, 1); + } +} + +BENCHMARK(BM_InvNTTNativeRadix4InPlace) + ->Unit(benchmark::kMicrosecond) + ->Args({1024}) + ->Args({4096}) + ->Args({16384}); + +//================================================================= + +static void BM_InvNTTNativeRadix4Copy(benchmark::State& state) { // NOLINT + size_t ntt_size = state.range(0); + size_t modulus = GeneratePrimes(1, 45, true, ntt_size)[0]; + + auto input = GenerateInsecureUniformRandomValues(ntt_size, 0, modulus); + AlignedVector64 output(ntt_size, 1); NTT ntt(ntt_size, modulus); const AlignedVector64 root_of_unity = ntt.GetInvRootOfUnityPowers(); const AlignedVector64 precon_root_of_unity = ntt.GetPrecon64InvRootOfUnityPowers(); for (auto _ : state) { - InverseTransformFromBitReverseRadix4(input.data(), ntt_size, modulus, - root_of_unity.data(), + InverseTransformFromBitReverseRadix4(output.data(), input.data(), ntt_size, + modulus, root_of_unity.data(), precon_root_of_unity.data(), 1, 1); } } -BENCHMARK(BM_InvNTTNativeRadix4) +BENCHMARK(BM_InvNTTNativeRadix4Copy) ->Unit(benchmark::kMicrosecond) ->Args({1024}) ->Args({4096}) @@ -322,7 +445,7 @@ static void BM_InvNTT_AVX512IFMA(benchmark::State& state) { // NOLINT ntt.GetPrecon52InvRootOfUnityPowers(); for (auto _ : state) { InverseTransformFromBitReverseAVX512( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 1, 1); } } @@ -348,7 
+471,7 @@ static void BM_InvNTT_AVX512IFMALazy(benchmark::State& state) { // NOLINT ntt.GetPrecon52InvRootOfUnityPowers(); for (auto _ : state) { InverseTransformFromBitReverseAVX512( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), 2, 2); } } @@ -379,7 +502,7 @@ static void BM_InvNTT_AVX512DQ_32(benchmark::State& state) { // NOLINT for (auto _ : state) { InverseTransformFromBitReverseAVX512<32>( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), output_mod_factor, output_mod_factor); } } @@ -407,7 +530,7 @@ static void BM_InvNTT_AVX512DQ_64(benchmark::State& state) { // NOLINT for (auto _ : state) { InverseTransformFromBitReverseAVX512( - input.data(), ntt_size, modulus, root_of_unity.data(), + input.data(), input.data(), ntt_size, modulus, root_of_unity.data(), precon_root_of_unity.data(), output_mod_factor, output_mod_factor); } } diff --git a/hexl/ntt/fwd-ntt-avx512.cpp b/hexl/ntt/fwd-ntt-avx512.cpp index 19571fd3..6776156a 100644 --- a/hexl/ntt/fwd-ntt-avx512.cpp +++ b/hexl/ntt/fwd-ntt-avx512.cpp @@ -3,6 +3,7 @@ #include "ntt/fwd-ntt-avx512.hpp" +#include #include #include @@ -18,7 +19,7 @@ namespace hexl { #ifdef HEXL_HAS_AVX512IFMA template void ForwardTransformToBitReverseAVX512( - uint64_t* operand, uint64_t degree, uint64_t mod, + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, @@ -27,14 +28,14 @@ template void ForwardTransformToBitReverseAVX512( #ifdef HEXL_HAS_AVX512DQ template void ForwardTransformToBitReverseAVX512<32>( - uint64_t* operand, uint64_t degree, uint64_t mod, + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, const 
uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, uint64_t recursion_half); template void ForwardTransformToBitReverseAVX512( - uint64_t* operand, uint64_t degree, uint64_t mod, + uint64_t* result, const uint64_t* operand, uint64_t degree, uint64_t mod, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, @@ -183,33 +184,47 @@ void FwdT4(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, } } +// Out-of-place implementation template -void FwdT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, - uint64_t t, uint64_t m, const uint64_t* W, +void FwdT8(uint64_t* result, const uint64_t* operand, __m512i v_neg_modulus, + __m512i v_twice_mod, uint64_t t, uint64_t m, const uint64_t* W, const uint64_t* W_precon) { size_t j1 = 0; HEXL_LOOP_UNROLL_4 for (size_t i = 0; i < m; i++) { - uint64_t* X = operand + j1; - uint64_t* Y = X + t; + // Referencing operand + const uint64_t* X_op = operand + j1; + const uint64_t* Y_op = X_op + t; + + const __m512i* v_X_op_pt = reinterpret_cast(X_op); + const __m512i* v_Y_op_pt = reinterpret_cast(Y_op); + + // Referencing result + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + __m512i* v_X_r_pt = reinterpret_cast<__m512i*>(X_r); + __m512i* v_Y_r_pt = reinterpret_cast<__m512i*>(Y_r); + + // Weights and weights' preconditions __m512i v_W = _mm512_set1_epi64(static_cast(*W++)); __m512i v_W_precon = _mm512_set1_epi64(static_cast(*W_precon++)); - __m512i* v_X_pt = reinterpret_cast<__m512i*>(X); - __m512i* v_Y_pt = reinterpret_cast<__m512i*>(Y); - // assume 8 | t for (size_t j = t / 8; j > 0; --j) { - __m512i v_X = _mm512_loadu_si512(v_X_pt); - __m512i v_Y = _mm512_loadu_si512(v_Y_pt); + __m512i v_X = _mm512_loadu_si512(v_X_op_pt); + __m512i v_Y = _mm512_loadu_si512(v_Y_op_pt); FwdButterfly(&v_X, 
&v_Y, v_W, v_W_precon, v_neg_modulus, v_twice_mod); - _mm512_storeu_si512(v_X_pt++, v_X); - _mm512_storeu_si512(v_Y_pt++, v_Y); + _mm512_storeu_si512(v_X_r_pt++, v_X); + _mm512_storeu_si512(v_Y_r_pt++, v_Y); + + // Increase operand pointers as well + v_X_op_pt++; + v_Y_op_pt++; } j1 += (t << 1); } @@ -217,7 +232,7 @@ void FwdT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, template void ForwardTransformToBitReverseAVX512( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, @@ -262,17 +277,23 @@ void ForwardTransformToBitReverseAVX512( size_t t = (n >> 1); size_t m = 1; size_t W_idx = (m << recursion_depth) + (recursion_half * m); + + // Copy for out-of-place in case m is <= base_ntt_size from start + if (result != operand) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + // First iteration assumes input in [0,p) if (m < (n >> 3)) { const uint64_t* W = &root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; if ((input_mod_factor <= 2) && (recursion_depth == 0)) { - FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, - W_precon); + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); } else { - FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, - W_precon); + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); } t >>= 1; @@ -282,8 +303,8 @@ void ForwardTransformToBitReverseAVX512( for (; m < (n >> 3); m <<= 1) { const uint64_t* W = &root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; - FwdT8(operand, v_neg_modulus, v_twice_mod, t, m, W, - W_precon); + FwdT8(result, result, v_neg_modulus, v_twice_mod, t, m, + W, W_precon); t >>= 1; W_idx <<= 1; } @@ -324,27 +345,27 @@ void 
ForwardTransformToBitReverseAVX512( size_t new_W_idx = compute_new_W_idx(W_idx); const uint64_t* W = &root_of_unity_powers[new_W_idx]; const uint64_t* W_precon = &precon_root_of_unity_powers[new_W_idx]; - FwdT4(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); + FwdT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); m <<= 1; W_idx <<= 1; new_W_idx = compute_new_W_idx(W_idx); W = &root_of_unity_powers[new_W_idx]; W_precon = &precon_root_of_unity_powers[new_W_idx]; - FwdT2(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); + FwdT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); m <<= 1; W_idx <<= 1; new_W_idx = compute_new_W_idx(W_idx); W = &root_of_unity_powers[new_W_idx]; W_precon = &precon_root_of_unity_powers[new_W_idx]; - FwdT1(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); + FwdT1(result, v_neg_modulus, v_twice_mod, m, W, W_precon); } if (output_mod_factor == 1) { // n power of two at least 8 => n divisible by 8 HEXL_CHECK(n % 8 == 0, "n " << n << " not a power of 2"); - __m512i* v_X_pt = reinterpret_cast<__m512i*>(operand); + __m512i* v_X_pt = reinterpret_cast<__m512i*>(result); for (size_t i = 0; i < n; i += 8) { __m512i v_X = _mm512_loadu_si512(v_X_pt); @@ -367,16 +388,16 @@ void ForwardTransformToBitReverseAVX512( const uint64_t* W = &root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_root_of_unity_powers[W_idx]; - FwdT8(operand, v_neg_modulus, v_twice_mod, t, 1, W, + FwdT8(result, operand, v_neg_modulus, v_twice_mod, t, 1, W, W_precon); ForwardTransformToBitReverseAVX512( - operand, n / 2, modulus, root_of_unity_powers, + result, result, n / 2, modulus, root_of_unity_powers, precon_root_of_unity_powers, input_mod_factor, output_mod_factor, recursion_depth + 1, recursion_half * 2); ForwardTransformToBitReverseAVX512( - &operand[n / 2], n / 2, modulus, root_of_unity_powers, + &result[n / 2], &result[n / 2], n / 2, modulus, root_of_unity_powers, precon_root_of_unity_powers, input_mod_factor, output_mod_factor, 
recursion_depth + 1, recursion_half * 2 + 1); } diff --git a/hexl/ntt/fwd-ntt-avx512.hpp b/hexl/ntt/fwd-ntt-avx512.hpp index e8c5a232..e2f25b01 100644 --- a/hexl/ntt/fwd-ntt-avx512.hpp +++ b/hexl/ntt/fwd-ntt-avx512.hpp @@ -34,7 +34,7 @@ namespace hexl { /// performance on larger transform sizes. template void ForwardTransformToBitReverseAVX512( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth = 0, diff --git a/hexl/ntt/inv-ntt-avx512.cpp b/hexl/ntt/inv-ntt-avx512.cpp index 4d00be50..a94d4e91 100644 --- a/hexl/ntt/inv-ntt-avx512.cpp +++ b/hexl/ntt/inv-ntt-avx512.cpp @@ -5,6 +5,7 @@ #include +#include #include #include @@ -20,8 +21,8 @@ namespace hexl { #ifdef HEXL_HAS_AVX512IFMA template void InverseTransformFromBitReverseAVX512( - uint64_t* operand, uint64_t degree, uint64_t modulus, - const uint64_t* inv_root_of_unity_powers, + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, uint64_t recursion_half); @@ -29,15 +30,15 @@ template void InverseTransformFromBitReverseAVX512( #ifdef HEXL_HAS_AVX512DQ template void InverseTransformFromBitReverseAVX512<32>( - uint64_t* operand, uint64_t degree, uint64_t modulus, - const uint64_t* inv_root_of_unity_powers, + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, uint64_t recursion_half); template void InverseTransformFromBitReverseAVX512( - uint64_t* operand, uint64_t degree, uint64_t modulus, 
- const uint64_t* inv_root_of_unity_powers, + uint64_t* result, const uint64_t* operand, uint64_t degree, + uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, uint64_t recursion_half); @@ -217,7 +218,7 @@ void InvT8(uint64_t* operand, __m512i v_neg_modulus, __m512i v_twice_mod, template void InverseTransformFromBitReverseAVX512( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth, @@ -256,16 +257,20 @@ void InverseTransformFromBitReverseAVX512( static const size_t base_ntt_size = 1024; if (n <= base_ntt_size) { // Perform breadth-first InvNTT + if (operand != result) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + // Extract t=1, t=2, t=4 loops separately { // t = 1 const uint64_t* W = &inv_root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; if ((input_mod_factor == 1) && (recursion_depth == 0)) { - InvT1(operand, v_neg_modulus, v_twice_mod, m, W, + InvT1(result, v_neg_modulus, v_twice_mod, m, W, W_precon); } else { - InvT1(operand, v_neg_modulus, v_twice_mod, m, W, + InvT1(result, v_neg_modulus, v_twice_mod, m, W, W_precon); } @@ -278,7 +283,7 @@ void InverseTransformFromBitReverseAVX512( // t = 2 W = &inv_root_of_unity_powers[W_idx]; W_precon = &precon_inv_root_of_unity_powers[W_idx]; - InvT2(operand, v_neg_modulus, v_twice_mod, m, W, W_precon); + InvT2(result, v_neg_modulus, v_twice_mod, m, W, W_precon); t <<= 1; m >>= 1; @@ -288,7 +293,7 @@ void InverseTransformFromBitReverseAVX512( // t = 4 W = &inv_root_of_unity_powers[W_idx]; W_precon = &precon_inv_root_of_unity_powers[W_idx]; - InvT4(operand, v_neg_modulus, v_twice_mod, m, 
W, W_precon); + InvT4(result, v_neg_modulus, v_twice_mod, m, W, W_precon); t <<= 1; m >>= 1; W_idx_delta >>= 1; @@ -298,7 +303,7 @@ void InverseTransformFromBitReverseAVX512( for (; m > 1;) { W = &inv_root_of_unity_powers[W_idx]; W_precon = &precon_inv_root_of_unity_powers[W_idx]; - InvT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); t <<= 1; m >>= 1; W_idx_delta >>= 1; @@ -307,13 +312,14 @@ void InverseTransformFromBitReverseAVX512( } } else { InverseTransformFromBitReverseAVX512( - operand, n / 2, modulus, inv_root_of_unity_powers, + result, operand, n / 2, modulus, inv_root_of_unity_powers, precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor, recursion_depth + 1, 2 * recursion_half); InverseTransformFromBitReverseAVX512( - &operand[n / 2], n / 2, modulus, inv_root_of_unity_powers, - precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor, - recursion_depth + 1, 2 * recursion_half + 1); + &result[n / 2], &operand[n / 2], n / 2, modulus, + inv_root_of_unity_powers, precon_inv_root_of_unity_powers, + input_mod_factor, output_mod_factor, recursion_depth + 1, + 2 * recursion_half + 1); uint64_t W_idx_delta = m * ((1ULL << (recursion_depth + 1)) - recursion_half); @@ -325,7 +331,7 @@ void InverseTransformFromBitReverseAVX512( if (m == 2) { const uint64_t* W = &inv_root_of_unity_powers[W_idx]; const uint64_t* W_precon = &precon_inv_root_of_unity_powers[W_idx]; - InvT8(operand, v_neg_modulus, v_twice_mod, t, m, W, W_precon); + InvT8(result, v_neg_modulus, v_twice_mod, t, m, W, W_precon); t <<= 1; m >>= 1; W_idx_delta >>= 1; @@ -335,8 +341,8 @@ void InverseTransformFromBitReverseAVX512( // Final loop through data if (recursion_depth == 0) { - HEXL_VLOG(4, "AVX512 intermediate operand " - << std::vector(operand, operand + n)); + HEXL_VLOG(4, "AVX512 intermediate result " + << std::vector(result, result + n)); const uint64_t W = inv_root_of_unity_powers[W_idx]; 
MultiplyFactor mf_inv_n(InverseMod(n, modulus), BitShift, modulus); @@ -350,7 +356,7 @@ void InverseTransformFromBitReverseAVX512( HEXL_VLOG(4, "inv_n_w " << inv_n_w); - uint64_t* X = operand; + uint64_t* X = result; uint64_t* Y = X + (n >> 1); __m512i v_inv_n = _mm512_set1_epi64(static_cast(inv_n)); @@ -416,8 +422,8 @@ void InverseTransformFromBitReverseAVX512( _mm512_storeu_si512(v_Y_pt++, v_Y); } - HEXL_VLOG(5, "AVX512 returning operand " - << std::vector(operand, operand + n)); + HEXL_VLOG(5, "AVX512 returning result " + << std::vector(result, result + n)); } } diff --git a/hexl/ntt/inv-ntt-avx512.hpp b/hexl/ntt/inv-ntt-avx512.hpp index f0731a16..36831073 100644 --- a/hexl/ntt/inv-ntt-avx512.hpp +++ b/hexl/ntt/inv-ntt-avx512.hpp @@ -34,7 +34,7 @@ namespace hexl { /// performance on larger transform sizes. template void InverseTransformFromBitReverseAVX512( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor, uint64_t recursion_depth = 0, diff --git a/hexl/ntt/ntt-default.hpp b/hexl/ntt/ntt-default.hpp index 82892d4f..be68f2f2 100644 --- a/hexl/ntt/ntt-default.hpp +++ b/hexl/ntt/ntt-default.hpp @@ -10,10 +10,13 @@ namespace intel { namespace hexl { -/// @brief The Harvey butterfly: assume \p X, \p Y in [0, 4q), and return X', Y' -/// in [0, 4q) such that X' = X + WY, Y' = X - WY (mod q). -/// @param[in,out] X Butterfly data -/// @param[in,out] Y Butterfly data +/// @brief Out of place Harvey butterfly: assume \p X_op, \p Y_op in [0, 4q), +/// and return X_r, Y_r in [0, 4q) such that X_r = X_op + WY_op, Y_r = X_op - +/// WY_op (mod q). 
+/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data /// @param[in] W Root of unity /// @param[in] W_precon Preconditioned \p W for BitShift-bit Barrett /// reduction @@ -22,53 +25,58 @@ namespace hexl { /// @param[in] twice_modulus Twice the modulus, i.e. 2*q represented as 8 64-bit /// signed integers in SIMD form /// @details See Algorithm 4 of https://arxiv.org/pdf/1205.2926.pdf -inline void FwdButterflyRadix2(uint64_t* X, uint64_t* Y, uint64_t W, - uint64_t W_precon, uint64_t modulus, +inline void FwdButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, uint64_t twice_modulus) { HEXL_VLOG(5, "FwdButterflyRadix2"); - HEXL_VLOG(5, "Inputs: X " << *X << ", Y " << *Y << ", W " << W << ", modulus " - << modulus); - uint64_t tx = ReduceMod<2>(*X, twice_modulus); - uint64_t T = MultiplyModLazy<64>(*Y, W, W_precon, modulus); + HEXL_VLOG(5, "Inputs: X_op " << *X_op << ", Y_op " << *Y_op << ", W " << W + << ", modulus " << modulus); + uint64_t tx = ReduceMod<2>(*X_op, twice_modulus); + uint64_t T = MultiplyModLazy<64>(*Y_op, W, W_precon, modulus); HEXL_VLOG(5, "T " << T); - *X = tx + T; - *Y = tx + twice_modulus - T; + *X_r = tx + T; + *Y_r = tx + twice_modulus - T; - HEXL_VLOG(5, "Output X " << *X << ", Y " << *Y); + HEXL_VLOG(5, "Output X " << *X_r << ", Y " << *Y_r); } -// Assume X, Y in [0, n*q) and return X', Y' in [0, (n+2)*q) -// such that X' = X + WY mod q and Y' = X - WY mod q -inline void FwdButterflyRadix4Lazy(uint64_t* X, uint64_t* Y, uint64_t W, - uint64_t W_precon, uint64_t modulus, - uint64_t twice_modulus) { +// Assume X, Y in [0, n*q) and return X_r, Y_r in [0, (n+2)*q) +// such that X_r = X_op + WY_op mod q and Y_r = X_op - WY_op mod q +inline void FwdButterflyRadix4Lazy(uint64_t* X_r, uint64_t* Y_r, + const uint64_t X_op, const uint64_t Y_op, + uint64_t W, uint64_t W_precon, 
+ uint64_t modulus, uint64_t twice_modulus) { HEXL_VLOG(3, "FwdButterflyRadix4Lazy"); - HEXL_VLOG(3, "Inputs: X " << *X << ", Y " << *Y << ", W " << W << ", modulus " - << modulus); + HEXL_VLOG(3, "Inputs: X_op " << X_op << ", Y_op " << Y_op << ", W " << W + << ", modulus " << modulus); - uint64_t tx = *X; - uint64_t T = MultiplyModLazy<64>(*Y, W, W_precon, modulus); + uint64_t T = MultiplyModLazy<64>(Y_op, W, W_precon, modulus); HEXL_VLOG(3, "T " << T); - *X = tx + T; - *Y = tx + twice_modulus - T; + *X_r = X_op + T; + *Y_r = X_op + twice_modulus - T; - HEXL_VLOG(3, "Outputs: X " << *X << ", Y " << *Y); + HEXL_VLOG(3, "Outputs: X_r " << *X_r << ", Y_r " << *Y_r); } // Assume X0, X1, X2, X3 in [0, 4q) and return X0, X1, X2, X3 in [0, 4q) -inline void FwdButterflyRadix4(uint64_t* X0, uint64_t* X1, uint64_t* X2, - uint64_t* X3, uint64_t W1, uint64_t W1_precon, - uint64_t W2, uint64_t W2_precon, uint64_t W3, - uint64_t W3_precon, uint64_t modulus, - uint64_t twice_modulus, - uint64_t four_times_modulus) { +inline void FwdButterflyRadix4( + uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, uint64_t* X_r3, + const uint64_t* X_op0, const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, uint64_t W3_precon, uint64_t modulus, + uint64_t twice_modulus, uint64_t four_times_modulus) { HEXL_VLOG(3, "FwdButterflyRadix4"); HEXL_UNUSED(four_times_modulus); - FwdButterflyRadix2(X0, X2, W1, W1_precon, modulus, twice_modulus); - FwdButterflyRadix2(X1, X3, W1, W1_precon, modulus, twice_modulus); - FwdButterflyRadix2(X0, X1, W2, W2_precon, modulus, twice_modulus); - FwdButterflyRadix2(X2, X3, W3, W3_precon, modulus, twice_modulus); + FwdButterflyRadix2(X_r0, X_r2, X_op0, X_op2, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r1, X_r3, X_op1, X_op3, W1, W1_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r0, X_r1, X_r0, X_r1, W2, W2_precon, modulus, + twice_modulus); 
+ FwdButterflyRadix2(X_r2, X_r3, X_r2, X_r3, W3, W3_precon, modulus, + twice_modulus); // Alternate implementation // // Returns Xs in [0, 6q) @@ -86,54 +94,65 @@ inline void FwdButterflyRadix4(uint64_t* X0, uint64_t* X1, uint64_t* X2, // *X3 = ReduceMod<2>(*X3, four_times_modulus); } -/// @brief The Harvey butterfly: assume X, Y in [0, 2q), and return X', Y' in -/// [0, 2q) such that X' = X + Y (mod q), Y' = W(X - Y) (mod q). -/// @param[in,out] X Butterfly data -/// @param[in,out] Y Butterfly data +/// @brief Out-of-place Harvey butterfly: assume X_op, Y_op in [0, 2q), and +/// return X_r, Y_r in [0, 2q) such that X_r = X_op + Y_op (mod q), +/// Y_r = W(X_op - Y_op) (mod q). +/// @param[out] X_r Butterfly data +/// @param[out] Y_r Butterfly data +/// @param[in] X_op Butterfly data +/// @param[in] Y_op Butterfly data /// @param[in] W Root of unity /// @param[in] W_precon Preconditioned root of unity for 64-bit Barrett /// reduction -/// @param[in] neg_modulus Negative modulus, i.e. (-q) represented as 8 64-bit +/// @param[in] modulus Modulus, i.e. (q) represented as 8 64-bit /// signed integers in SIMD form /// @param[in] twice_modulus Twice the modulus, i.e. 
2*q represented as 8 64-bit /// signed integers in SIMD form /// @details See Algorithm 3 of https://arxiv.org/pdf/1205.2926.pdf -inline void InvButterflyRadix2(uint64_t* X, uint64_t* Y, uint64_t W, - uint64_t W_precon, uint64_t modulus, +inline void InvButterflyRadix2(uint64_t* X_r, uint64_t* Y_r, + const uint64_t* X_op, const uint64_t* Y_op, + uint64_t W, uint64_t W_precon, uint64_t modulus, uint64_t twice_modulus) { - HEXL_VLOG(4, "InvButterflyRadix2 X " << *X << ", Y " << *Y << " W " << W - << " W_precon " << W_precon - << " modulus " << modulus); - uint64_t tx = *X + *Y; - uint64_t ty = *X + twice_modulus - *Y; - - *X = ReduceMod<2>(tx, twice_modulus); - *Y = MultiplyModLazy<64>(ty, W, W_precon, modulus); - - HEXL_VLOG(4, "InvButterflyRadix2 returning X " << *X << ", Y " << *Y); + HEXL_VLOG(4, "InvButterflyRadix2 X_op " + << *X_op << ", Y_op " << *Y_op << " W " << W << " W_precon " + << W_precon << " modulus " << modulus); + uint64_t tx = *X_op + *Y_op; + *Y_r = *X_op + twice_modulus - *Y_op; + *X_r = ReduceMod<2>(tx, twice_modulus); + *Y_r = MultiplyModLazy<64>(*Y_r, W, W_precon, modulus); + + HEXL_VLOG(4, "InvButterflyRadix2 returning X_r " << *X_r << ", Y_r " << *Y_r); } // Assume X0, X1, X2, X3 in [0, 2q) and return X0, X1, X2, X3 in [0, 2q) -inline void InvButterflyRadix4(uint64_t* X0, uint64_t* X1, uint64_t* X2, - uint64_t* X3, uint64_t W1, uint64_t W1_precon, - uint64_t W2, uint64_t W2_precon, uint64_t W3, +inline void InvButterflyRadix4(uint64_t* X_r0, uint64_t* X_r1, uint64_t* X_r2, + uint64_t* X_r3, const uint64_t* X_op0, + const uint64_t* X_op1, const uint64_t* X_op2, + const uint64_t* X_op3, uint64_t W1, + uint64_t W1_precon, uint64_t W2, + uint64_t W2_precon, uint64_t W3, uint64_t W3_precon, uint64_t modulus, uint64_t twice_modulus) { - HEXL_VLOG(4, "InvButterflyRadix4 " // - << "X0 " << *X0 << ", X1 " << *X1 << ", X2 " << *X2 << " X3 " - << *X3 // - << " W1 " << W1 << " W1_precon " << W1_precon // - << " W2 " << W2 << " W2_precon " << W2_precon 
// - << " W3 " << W3 << " W3_precon " << W3_precon // + HEXL_VLOG(4, "InvButterflyRadix4 " // + << "X_op0 " << *X_op0 << ", X_op1 " << *X_op1 // + << ", X_op2 " << *X_op2 << " X_op3 " << *X_op3 // + << " W1 " << W1 << " W1_precon " << W1_precon // + << " W2 " << W2 << " W2_precon " << W2_precon // + << " W3 " << W3 << " W3_precon " << W3_precon // << " modulus " << modulus); - InvButterflyRadix2(X0, X1, W1, W1_precon, modulus, twice_modulus); - InvButterflyRadix2(X2, X3, W2, W2_precon, modulus, twice_modulus); - InvButterflyRadix2(X0, X2, W3, W3_precon, modulus, twice_modulus); - InvButterflyRadix2(X1, X3, W3, W3_precon, modulus, twice_modulus); - - HEXL_VLOG(4, "InvButterflyRadix4 returning X0 " - << *X0 << ", X1 " << *X1 << ", X2 " << *X2 << " X3 " << *X3); + InvButterflyRadix2(X_r0, X_r1, X_op0, X_op1, W1, W1_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r2, X_r3, X_op2, X_op3, W2, W2_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r0, X_r2, X_r0, X_r2, W3, W3_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r1, X_r3, X_r1, X_r3, W3, W3_precon, modulus, + twice_modulus); + + HEXL_VLOG(4, "InvButterflyRadix4 returning X0 " << *X_r0 << ", X_r1 " << *X_r1 + << ", X_r2 " << *X_r2 // + << " X_r3 " << *X_r3); } } // namespace hexl diff --git a/hexl/ntt/ntt-internal.cpp b/hexl/ntt/ntt-internal.cpp index 80833862..d54d8dc0 100644 --- a/hexl/ntt/ntt-internal.cpp +++ b/hexl/ntt/ntt-internal.cpp @@ -199,10 +199,6 @@ void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, operand, m_degree, m_q * input_mod_factor, "value in operand exceeds bound " << m_q * input_mod_factor); - if (result != operand) { - std::memcpy(result, operand, m_degree * sizeof(uint64_t)); - } - #ifdef HEXL_HAS_AVX512IFMA if (has_avx512ifma && (m_q < s_max_fwd_ifma_modulus && (m_degree >= 16))) { const uint64_t* root_of_unity_powers = GetAVX512RootOfUnityPowers().data(); @@ -211,7 +207,7 @@ void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, 
HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA FwdNTT"); ForwardTransformToBitReverseAVX512( - result, m_degree, m_q, root_of_unity_powers, + result, operand, m_degree, m_q, root_of_unity_powers, precon_root_of_unity_powers, input_mod_factor, output_mod_factor); return; } @@ -226,7 +222,7 @@ void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, const uint64_t* precon_root_of_unity_powers = GetAVX512Precon32RootOfUnityPowers().data(); ForwardTransformToBitReverseAVX512<32>( - result, m_degree, m_q, root_of_unity_powers, + result, operand, m_degree, m_q, root_of_unity_powers, precon_root_of_unity_powers, input_mod_factor, output_mod_factor); } else { HEXL_VLOG(3, "Calling 64-bit AVX512-DQ FwdNTT"); @@ -236,7 +232,7 @@ void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, GetAVX512Precon64RootOfUnityPowers().data(); ForwardTransformToBitReverseAVX512( - result, m_degree, m_q, root_of_unity_powers, + result, operand, m_degree, m_q, root_of_unity_powers, precon_root_of_unity_powers, input_mod_factor, output_mod_factor); } return; @@ -249,8 +245,8 @@ void NTT::ComputeForward(uint64_t* result, const uint64_t* operand, GetPrecon64RootOfUnityPowers().data(); ForwardTransformToBitReverseRadix2( - result, m_degree, m_q, root_of_unity_powers, precon_root_of_unity_powers, - input_mod_factor, output_mod_factor); + result, operand, m_degree, m_q, root_of_unity_powers, + precon_root_of_unity_powers, input_mod_factor, output_mod_factor); } void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, @@ -265,10 +261,6 @@ void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, HEXL_CHECK_BOUNDS(operand, m_degree, m_q * input_mod_factor, "operand exceeds bound " << m_q * input_mod_factor); - if (operand != result) { - std::memcpy(result, operand, m_degree * sizeof(uint64_t)); - } - #ifdef HEXL_HAS_AVX512IFMA if (has_avx512ifma && (m_q < s_max_inv_ifma_modulus) && (m_degree >= 16)) { HEXL_VLOG(3, "Calling 52-bit AVX512-IFMA InvNTT"); @@ -276,7 
+268,7 @@ void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, const uint64_t* precon_inv_root_of_unity_powers = GetPrecon52InvRootOfUnityPowers().data(); InverseTransformFromBitReverseAVX512( - result, m_degree, m_q, inv_root_of_unity_powers, + result, operand, m_degree, m_q, inv_root_of_unity_powers, precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); return; } @@ -291,7 +283,7 @@ void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, const uint64_t* precon_inv_root_of_unity_powers = GetPrecon32InvRootOfUnityPowers().data(); InverseTransformFromBitReverseAVX512<32>( - result, m_degree, m_q, inv_root_of_unity_powers, + result, operand, m_degree, m_q, inv_root_of_unity_powers, precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); } else { HEXL_VLOG(3, "Calling 64-bit AVX512 InvNTT"); @@ -301,7 +293,7 @@ void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, GetPrecon64InvRootOfUnityPowers().data(); InverseTransformFromBitReverseAVX512( - result, m_degree, m_q, inv_root_of_unity_powers, + result, operand, m_degree, m_q, inv_root_of_unity_powers, precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); } return; @@ -313,7 +305,7 @@ void NTT::ComputeInverse(uint64_t* result, const uint64_t* operand, const uint64_t* precon_inv_root_of_unity_powers = GetPrecon64InvRootOfUnityPowers().data(); InverseTransformFromBitReverseRadix2( - result, m_degree, m_q, inv_root_of_unity_powers, + result, operand, m_degree, m_q, inv_root_of_unity_powers, precon_inv_root_of_unity_powers, input_mod_factor, output_mod_factor); } diff --git a/hexl/ntt/ntt-internal.hpp b/hexl/ntt/ntt-internal.hpp index f5e158b3..b1449ce1 100644 --- a/hexl/ntt/ntt-internal.hpp +++ b/hexl/ntt/ntt-internal.hpp @@ -18,7 +18,8 @@ namespace intel { namespace hexl { /// @brief Radix-2 native C++ NTT implementation of the forward NTT -/// @param[in, out] operand Input data. 
Overwritten with NTT output +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. /// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a /// power of two. /// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n @@ -31,13 +32,14 @@ namespace hexl { /// @param[in] output_mod_factor Upper bound for result; result must be in [0, /// output_mod_factor * q) void ForwardTransformToBitReverseRadix2( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); /// @brief Radix-4 native C++ NTT implementation of the forward NTT -/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. /// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a /// power of two. /// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n @@ -50,7 +52,7 @@ void ForwardTransformToBitReverseRadix2( /// @param[in] output_mod_factor Upper bound for result; result must be in [0, /// output_mod_factor * q) void ForwardTransformToBitReverseRadix4( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); @@ -67,7 +69,8 @@ void ReferenceForwardTransformToBitReverse( const uint64_t* root_of_unity_powers); /// @brief Radix-2 native C++ NTT implementation of the inverse NTT -/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. 
/// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a /// power of two. /// @param[in] modulus Prime modulus q. Must satisfy q == 1 mod 2n @@ -80,13 +83,14 @@ void ReferenceForwardTransformToBitReverse( /// @param[in] output_mod_factor Upper bound for result; result must be in [0, /// output_mod_factor * q) void InverseTransformFromBitReverseRadix2( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); /// @brief Radix-4 native C++ NTT implementation of the inverse NTT -/// @param[in, out] operand Input data. Overwritten with NTT output +/// @param[out] result Output data. Overwritten with NTT output +/// @param[in] operand Input data. /// @param[in] n Size of the transform, i.e. the polynomial degree. Must be a /// power of two. /// @param[in] modulus Prime modulus q. 
Must satisfy q == 1 mod 2n @@ -99,7 +103,7 @@ void InverseTransformFromBitReverseRadix2( /// @param[in] output_mod_factor Upper bound for result; result must be in [0, /// output_mod_factor * q) void InverseTransformFromBitReverseRadix4( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor = 1, uint64_t output_mod_factor = 1); diff --git a/hexl/ntt/ntt-radix-2.cpp b/hexl/ntt/ntt-radix-2.cpp index 3039f2c0..1549740b 100644 --- a/hexl/ntt/ntt-radix-2.cpp +++ b/hexl/ntt/ntt-radix-2.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2020-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include <cstring> + #include "hexl/logging/logging.hpp" #include "hexl/ntt/ntt.hpp" #include "hexl/number-theory/number-theory.hpp" @@ -13,7 +15,7 @@ namespace intel { namespace hexl { void ForwardTransformToBitReverseRadix2( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -34,7 +36,88 @@ void ForwardTransformToBitReverseRadix2( uint64_t twice_modulus = modulus << 1; size_t t = (n >> 1); - for (size_t m = 1; m < n; m <<= 1) { + // In case of out-of-place operation do first pass and convert to in-place + { + const uint64_t W = root_of_unity_powers[1]; + const uint64_t W_precon = precon_root_of_unity_powers[1]; + + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + + // First pass for out-of-place case + switch (t) { + case 8: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + 
FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 4: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 2: { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + break; + } + case 1: { + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); + break; + } + default: { + HEXL_LOOP_UNROLL_8 + for (size_t j = 0; j < t; j += 8) { + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, 
Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + } + } + } + t >>= 1; + } + + // Continue with in-place operation + for (size_t m = 2; m < n; m <<= 1) { size_t j1 = 0; switch (t) { case 8: { @@ -45,16 +128,27 @@ void ForwardTransformToBitReverseRadix2( const uint64_t W = root_of_unity_powers[m + i]; const uint64_t W_precon = precon_root_of_unity_powers[m + i]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -66,12 +160,19 @@ void ForwardTransformToBitReverseRadix2( const uint64_t W = root_of_unity_powers[m + i]; const uint64_t W_precon = 
precon_root_of_unity_powers[m + i]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -83,10 +184,15 @@ void ForwardTransformToBitReverseRadix2( const uint64_t W = root_of_unity_powers[m + i]; const uint64_t W_precon = precon_root_of_unity_powers[m + i]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -98,9 +204,13 @@ void ForwardTransformToBitReverseRadix2( const uint64_t W = root_of_unity_powers[m + i]; const uint64_t W_precon = precon_root_of_unity_powers[m + i]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - FwdButterflyRadix2(X, Y, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + + FwdButterflyRadix2(X_r, Y_r, X_op, Y_op, W, 
W_precon, modulus, + twice_modulus); } break; } @@ -112,18 +222,29 @@ void ForwardTransformToBitReverseRadix2( const uint64_t W = root_of_unity_powers[m + i]; const uint64_t W_precon = precon_root_of_unity_powers[m + i]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < t; j += 8) { - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); } } } @@ -132,9 +253,9 @@ void ForwardTransformToBitReverseRadix2( } if (output_mod_factor == 1) { for (size_t i = 0; i < n; ++i) { - operand[i] = ReduceMod<4>(operand[i], modulus, &twice_modulus); - HEXL_CHECK(operand[i] < modulus, "Incorrect modulus reduction in NTT " 
- << operand[i] << " >= " << modulus); + result[i] = ReduceMod<4>(result[i], modulus, &twice_modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); } } } @@ -170,7 +291,7 @@ void ReferenceForwardTransformToBitReverse( } void InverseTransformFromBitReverseRadix2( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -203,10 +324,12 @@ void InverseTransformFromBitReverseRadix2( const uint64_t W = inv_root_of_unity_powers[root_index]; const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - - InvButterflyRadix2(X, Y, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand + j1; + const uint64_t* Y_op = X_op + t; + InvButterflyRadix2(X_r, Y_r, X_op, Y_op, W, W_precon, modulus, + twice_modulus); } break; } @@ -218,11 +341,14 @@ void InverseTransformFromBitReverseRadix2( const uint64_t W = inv_root_of_unity_powers[root_index]; const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -234,13 +360,18 @@ void InverseTransformFromBitReverseRadix2( const uint64_t W = inv_root_of_unity_powers[root_index]; const uint64_t W_precon = 
precon_inv_root_of_unity_powers[root_index]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -252,17 +383,26 @@ void InverseTransformFromBitReverseRadix2( const uint64_t W = inv_root_of_unity_powers[root_index]; const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; - - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, 
Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } break; } @@ -274,18 +414,29 @@ void InverseTransformFromBitReverseRadix2( const uint64_t W = inv_root_of_unity_powers[root_index]; const uint64_t W_precon = precon_inv_root_of_unity_powers[root_index]; - uint64_t* X = operand + j1; - uint64_t* Y = X + t; + uint64_t* X_r = result + j1; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = X_r; + const uint64_t* Y_op = Y_r; + HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < t; j += 8) { - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); - InvButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, 
Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, + modulus, twice_modulus); } } } @@ -293,6 +444,12 @@ void InverseTransformFromBitReverseRadix2( t <<= 1; } + // When n == 2 only the final stage butterfly is needed. Copy the input here + // in the out-of-place case. + if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + // Fold multiplication by N^{-1} to final stage butterfly const uint64_t W = inv_root_of_unity_powers[n - 1]; @@ -302,7 +459,7 @@ void InverseTransformFromBitReverseRadix2( uint64_t inv_n_w_precon = MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); - uint64_t* X = operand; + uint64_t* X = result; uint64_t* Y = X + n_div_2; for (size_t j = 0; j < n_div_2; ++j) { // Assume X, Y in [0, 2q) and compute @@ -317,9 +474,9 @@ void InverseTransformFromBitReverseRadix2( if (output_mod_factor == 1) { // Reduce from [0, 2q) to [0,q) for (size_t i = 0; i < n; ++i) { - operand[i] = ReduceMod<2>(operand[i], modulus); - HEXL_CHECK(operand[i] < modulus, "Incorrect modulus reduction in InvNTT" - << operand[i] << " >= " << modulus); + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + << result[i] << " >= " << modulus); } } } diff --git a/hexl/ntt/ntt-radix-4.cpp b/hexl/ntt/ntt-radix-4.cpp index 925896d7..9696f101 100644 --- a/hexl/ntt/ntt-radix-4.cpp +++ b/hexl/ntt/ntt-radix-4.cpp @@ -1,6 +1,8 @@ // Copyright (C) 2020-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 +#include + #include "hexl/logging/logging.hpp" #include "hexl/ntt/ntt.hpp" #include "hexl/number-theory/number-theory.hpp" @@ -13,7 +15,7 @@ namespace intel { namespace hexl { void ForwardTransformToBitReverseRadix4( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, 
uint64_t modulus, const uint64_t* root_of_unity_powers, const uint64_t* precon_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -53,20 +55,149 @@ void ForwardTransformToBitReverseRadix4( const uint64_t W = root_of_unity_powers[1]; const uint64_t W_precon = precon_root_of_unity_powers[1]; - uint64_t* X = operand; - uint64_t* Y = X + t; + uint64_t* X_r = result; + uint64_t* Y_r = X_r + t; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + t; + HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < t; j++) { - FwdButterflyRadix2(X++, Y++, W, W_precon, modulus, twice_modulus); + FwdButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, W, W_precon, modulus, + twice_modulus); } // Data in [0, 4q) + HEXL_VLOG(3, "after radix 2 outputs " + << std::vector(result, result + n)); } - HEXL_VLOG(3, "after radix 2 outputs " - << std::vector(operand, operand + n)); + uint64_t m_start = 2; + size_t t = n >> 3; + if (is_power_of_4) { + t = n >> 2; + + uint64_t* X_r0 = result; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand; + const uint64_t* X_op1 = operand + t; + const uint64_t* X_op2 = operand + 2 * t; + const uint64_t* X_op3 = operand + 3 * t; - uint64_t m_start = is_power_of_4 ? 
1 : 2; - size_t t = (n >> m_start) >> 1; + uint64_t W1_ind = 1; + uint64_t W2_ind = 2 * W1_ind; + uint64_t W3_ind = 2 * W1_ind + 1; + + const uint64_t W1 = root_of_unity_powers[W1_ind]; + const uint64_t W2 = root_of_unity_powers[W2_ind]; + const uint64_t W3 = root_of_unity_powers[W3_ind]; + + const uint64_t W1_precon = precon_root_of_unity_powers[W1_ind]; + const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; + const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; + + switch (t) { + case 4: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + case 1: { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + break; + } + default: { + for (size_t j = 0; j < t; j += 16) { + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, 
W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, 
X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, + four_times_modulus); + } + } + } + t >>= 2; + m_start = 4; + } + + // uint64_t m_start = is_power_of_4 ? 1 : 2; + // size_t t = (n >> m_start) >> 1; for (size_t m = m_start; m < n; m <<= 2) { HEXL_VLOG(3, "m " << m); @@ -80,10 +211,14 @@ void ForwardTransformToBitReverseRadix4( if (i != 0) { X0_offset += 4 * t; } - uint64_t* X0 = operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; uint64_t W1_ind = m + i; uint64_t W2_ind = 2 * W1_ind; @@ -97,16 +232,20 @@ void ForwardTransformToBitReverseRadix4( const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, 
twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0, X1, X2, X3, W1, W1_precon, W2, W2_precon, W3, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, W3_precon, modulus, twice_modulus, four_times_modulus); } @@ -118,10 +257,14 @@ void ForwardTransformToBitReverseRadix4( if (i != 0) { X0_offset += 4 * t; } - uint64_t* X0 = operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; uint64_t W1_ind = m + i; uint64_t W2_ind = 2 * W1_ind; @@ -135,7 +278,8 @@ void ForwardTransformToBitReverseRadix4( const uint64_t W2_precon = precon_root_of_unity_powers[W2_ind]; const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; - FwdButterflyRadix4(X0, X1, X2, X3, W1, W1_precon, W2, W2_precon, W3, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, W3_precon, modulus, twice_modulus, four_times_modulus); } @@ -146,10 +290,14 @@ void ForwardTransformToBitReverseRadix4( if (i != 0) { X0_offset += 4 * t; } - uint64_t* X0 = operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const 
uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; uint64_t W1_ind = m + i; uint64_t W2_ind = 2 * W1_ind; @@ -164,53 +312,69 @@ void ForwardTransformToBitReverseRadix4( const uint64_t W3_precon = precon_root_of_unity_powers[W3_ind]; for (size_t j = 0; j < t; j += 16) { - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, 
W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + 
FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); - FwdButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus, + FwdButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus, four_times_modulus); } } @@ -221,22 +385,22 @@ void ForwardTransformToBitReverseRadix4( if (output_mod_factor == 1) { for (size_t i = 0; i < n; ++i) { - if (operand[i] >= twice_modulus) { - operand[i] -= twice_modulus; + if (result[i] >= twice_modulus) { + result[i] -= twice_modulus; } - if (operand[i] >= modulus) { - operand[i] -= modulus; + if (result[i] >= modulus) { + result[i] -= modulus; } - HEXL_CHECK(operand[i] < modulus, "Incorrect modulus reduction in NTT " - << operand[i] << " >= " << modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in NTT " + << result[i] << " >= " << modulus); } } - HEXL_VLOG(3, "outputs " << std::vector(operand, operand + n)); + HEXL_VLOG(3, "outputs " << std::vector(result, result + n)); } void InverseTransformFromBitReverseRadix4( - uint64_t* operand, uint64_t n, uint64_t modulus, + uint64_t* result, const uint64_t* operand, uint64_t n, uint64_t modulus, const uint64_t* 
inv_root_of_unity_powers, const uint64_t* precon_inv_root_of_unity_powers, uint64_t input_mod_factor, uint64_t output_mod_factor) { @@ -258,16 +422,21 @@ void InverseTransformFromBitReverseRadix4( bool is_power_of_4 = IsPowerOfFour(n); // Radix-2 step for powers of 4 if (is_power_of_4) { - uint64_t* X = operand; - uint64_t* Y = X + 1; + uint64_t* X_r = result; + uint64_t* Y_r = X_r + 1; + const uint64_t* X_op = operand; + const uint64_t* Y_op = X_op + 1; const uint64_t* W = inv_root_of_unity_powers + 1; const uint64_t* W_precon = precon_inv_root_of_unity_powers + 1; HEXL_LOOP_UNROLL_8 for (size_t j = 0; j < n / 2; j++) { - InvButterflyRadix2(X++, Y++, *W++, *W_precon++, modulus, twice_modulus); - X++; - Y++; + InvButterflyRadix2(X_r++, Y_r++, X_op++, Y_op++, *W++, *W_precon++, + modulus, twice_modulus); + X_r++; + Y_r++; + X_op++; + Y_op++; } // Data in [0, 2q) } @@ -294,10 +463,14 @@ void InverseTransformFromBitReverseRadix4( X0_offset += 4 * t; } - uint64_t* X0 = operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = operand + X0_offset; + const uint64_t* X_op1 = X_op0 + t; + const uint64_t* X_op2 = X_op0 + 2 * t; + const uint64_t* X_op3 = X_op0 + 3 * t; uint64_t W1_ind = w1_root_index++; uint64_t W2_ind = w1_root_index++; @@ -311,7 +484,8 @@ void InverseTransformFromBitReverseRadix4( const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; - InvButterflyRadix4(X0, X1, X2, X3, W1, W1_precon, W2, W2_precon, W3, + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, W3_precon, modulus, twice_modulus); } break; @@ -323,10 +497,14 @@ void InverseTransformFromBitReverseRadix4( X0_offset += 4 * t; } - uint64_t* X0 = 
operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; uint64_t W1_ind = w1_root_index++; uint64_t W2_ind = w1_root_index++; @@ -340,13 +518,17 @@ void InverseTransformFromBitReverseRadix4( const uint64_t W2_precon = precon_inv_root_of_unity_powers[W2_ind]; const uint64_t W3_precon = precon_inv_root_of_unity_powers[W3_ind]; - InvButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus); - InvButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus); - InvButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, twice_modulus); - InvButterflyRadix4(X0, X1, X2, X3, W1, W1_precon, W2, W2_precon, W3, + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, + W3_precon, modulus, twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, W3, W3_precon, modulus, twice_modulus); } break; @@ -359,10 +541,14 @@ void InverseTransformFromBitReverseRadix4( X0_offset += 4 * t; } - uint64_t* X0 = operand + X0_offset; - uint64_t* X1 = X0 + t; - uint64_t* X2 = X0 + 2 * t; - uint64_t* X3 = X0 + 3 * t; + uint64_t* X_r0 = result + X0_offset; + uint64_t* X_r1 = X_r0 + t; + 
uint64_t* X_r2 = X_r0 + 2 * t; + uint64_t* X_r3 = X_r0 + 3 * t; + const uint64_t* X_op0 = X_r0; + const uint64_t* X_op1 = X_r1; + const uint64_t* X_op2 = X_r2; + const uint64_t* X_op3 = X_r3; uint64_t W1_ind = w1_root_index++; uint64_t W2_ind = w1_root_index++; @@ -378,9 +564,9 @@ void InverseTransformFromBitReverseRadix4( for (size_t j = 0; j < t; j++) { HEXL_VLOG(4, "j " << j); - InvButterflyRadix4(X0++, X1++, X2++, X3++, W1, W1_precon, W2, - W2_precon, W3, W3_precon, modulus, - twice_modulus); + InvButterflyRadix4(X_r0++, X_r1++, X_r2++, X_r3++, X_op0++, X_op1++, + X_op2++, X_op3++, W1, W1_precon, W2, W2_precon, + W3, W3_precon, modulus, twice_modulus); } } } @@ -390,8 +576,14 @@ void InverseTransformFromBitReverseRadix4( w3_root_index += m / 2; } + // When n == 2 only the final stage butterfly is needed, and it reads from + // result, so copy operand into result first for out-of-place calls. + if (result != operand && n == 2) { + std::memcpy(result, operand, n * sizeof(uint64_t)); + } + HEXL_VLOG(4, "Starting final invNTT stage"); - HEXL_VLOG(4, "operand " << std::vector(operand, operand + n)); + HEXL_VLOG(4, "operand " << std::vector(result, result + n)); // Fold multiplication by N^{-1} to final stage butterfly const uint64_t W = inv_root_of_unity_powers[n - 1]; @@ -403,7 +595,7 @@ void InverseTransformFromBitReverseRadix4( uint64_t inv_n_w_precon = MultiplyFactor(inv_n_w, 64, modulus).BarrettFactor(); - uint64_t* X = operand; + uint64_t* X = result; uint64_t* Y = X + n_div_2; for (size_t j = 0; j < n_div_2; ++j) { // Assume X, Y in [0, 2q) and compute @@ -419,9 +611,9 @@ void InverseTransformFromBitReverseRadix4( if (output_mod_factor == 1) { // Reduce from [0, 2q) to [0,q) for (size_t i = 0; i < n; ++i) { - operand[i] = ReduceMod<2>(operand[i], modulus); - HEXL_CHECK(operand[i] < modulus, "Incorrect modulus reduction in InvNTT" - << operand[i] << " >= " << modulus); + result[i] = ReduceMod<2>(result[i], modulus); + HEXL_CHECK(result[i] < modulus, "Incorrect modulus reduction in InvNTT" + <<
result[i] << " >= " << modulus); } } } diff --git a/test/test-ntt-avx512.cpp b/test/test-ntt-avx512.cpp index 7a76db27..8f2ce4c0 100644 --- a/test/test-ntt-avx512.cpp +++ b/test/test-ntt-avx512.cpp @@ -187,13 +187,13 @@ TEST_P(NttAVX512Test, FwdNTT_AVX512IFMA) { m_ntt.GetRootOfUnityPowers().data()); ForwardTransformToBitReverseAVX512<52>( - input_ifma.data(), m_N, m_ntt.GetModulus(), + input_ifma.data(), input_ifma.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon52RootOfUnityPowers().data(), 1, 1); // Compute lazy ForwardTransformToBitReverseAVX512<52>( - input_ifma_lazy.data(), m_N, m_ntt.GetModulus(), + input_ifma_lazy.data(), input_ifma_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon52RootOfUnityPowers().data(), 2, 4); for (auto& elem : input_ifma_lazy) { @@ -219,17 +219,18 @@ TEST_P(NttAVX512Test, InvNTT_AVX512IFMA) { // Compute reference InverseTransformFromBitReverseRadix2( - input64.data(), m_N, m_modulus, m_ntt.GetInvRootOfUnityPowers().data(), + input64.data(), input64.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 1); InverseTransformFromBitReverseAVX512<52>( - input_ifma.data(), m_N, m_ntt.GetModulus(), + input_ifma.data(), input_ifma.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon52InvRootOfUnityPowers().data(), 1, 1); // Compute lazy InverseTransformFromBitReverseAVX512<52>( - input_ifma_lazy.data(), m_N, m_ntt.GetModulus(), + input_ifma_lazy.data(), input_ifma_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon52InvRootOfUnityPowers().data(), 1, 2); for (auto& elem : input_ifma_lazy) { @@ -255,17 +256,18 @@ TEST_P(NttAVX512Test, FwdNTT_AVX512_32) { AlignedVector64 input_avx_lazy = input; ForwardTransformToBitReverseRadix2( - input.data(), m_N, m_modulus, m_ntt.GetRootOfUnityPowers().data(), + 
input.data(), input.data(), m_N, m_modulus, + m_ntt.GetRootOfUnityPowers().data(), m_ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); ForwardTransformToBitReverseAVX512<32>( - input_avx.data(), m_N, m_ntt.GetModulus(), + input_avx.data(), input_avx.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon32RootOfUnityPowers().data(), 2, 1); // Compute lazy ForwardTransformToBitReverseAVX512<32>( - input_avx_lazy.data(), m_N, m_ntt.GetModulus(), + input_avx_lazy.data(), input_avx_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon32RootOfUnityPowers().data(), 2, 4); for (auto& elem : input_avx_lazy) { @@ -290,17 +292,18 @@ TEST_P(NttAVX512Test, FwdNTT_AVX512_64) { AlignedVector64 input_avx_lazy = input; ForwardTransformToBitReverseRadix2( - input.data(), m_N, m_modulus, m_ntt.GetRootOfUnityPowers().data(), + input.data(), input.data(), m_N, m_modulus, + m_ntt.GetRootOfUnityPowers().data(), m_ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); ForwardTransformToBitReverseAVX512<64>( - input_avx.data(), m_N, m_ntt.GetModulus(), + input_avx.data(), input_avx.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon64RootOfUnityPowers().data(), 2, 1); // Compute lazy ForwardTransformToBitReverseAVX512<64>( - input_avx_lazy.data(), m_N, m_ntt.GetModulus(), + input_avx_lazy.data(), input_avx_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetAVX512RootOfUnityPowers().data(), m_ntt.GetAVX512Precon64RootOfUnityPowers().data(), 2, 4); for (auto& elem : input_avx_lazy) { @@ -326,17 +329,18 @@ TEST_P(NttAVX512Test, InvNTT_AVX512_32) { AlignedVector64 input_avx_lazy = input; InverseTransformFromBitReverseRadix2( - input.data(), m_N, m_modulus, m_ntt.GetInvRootOfUnityPowers().data(), + input.data(), input.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 1); 
InverseTransformFromBitReverseAVX512<32>( - input_avx.data(), m_N, m_ntt.GetModulus(), + input_avx.data(), input_avx.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon32InvRootOfUnityPowers().data(), 1, 1); // Compute lazy InverseTransformFromBitReverseAVX512<32>( - input_avx_lazy.data(), m_N, m_ntt.GetModulus(), + input_avx_lazy.data(), input_avx_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon32InvRootOfUnityPowers().data(), 1, 2); for (auto& elem : input_avx_lazy) { @@ -361,17 +365,18 @@ TEST_P(NttAVX512Test, InvNTT_AVX512_64) { AlignedVector64 input_avx_lazy = input; InverseTransformFromBitReverseRadix2( - input.data(), m_N, m_modulus, m_ntt.GetInvRootOfUnityPowers().data(), + input.data(), input.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 1); InverseTransformFromBitReverseAVX512<64>( - input_avx.data(), m_N, m_ntt.GetModulus(), + input_avx.data(), input_avx.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 1); // Compute lazy InverseTransformFromBitReverseAVX512<64>( - input_avx_lazy.data(), m_N, m_ntt.GetModulus(), + input_avx_lazy.data(), input_avx_lazy.data(), m_N, m_ntt.GetModulus(), m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 1, 2); for (auto& elem : input_avx_lazy) { diff --git a/test/test-ntt.cpp b/test/test-ntt.cpp index 362e04b6..70543691 100644 --- a/test/test-ntt.cpp +++ b/test/test-ntt.cpp @@ -282,14 +282,73 @@ TEST_P(DegreeModulusInputOutput, API) { } AssertEqual(input, input_copy); - auto input_radix4 = input; - InverseTransformFromBitReverseRadix4( - input_radix4.data(), N, modulus, ntt.GetInvRootOfUnityPowers().data(), + // In-place Fwd Radix2 + auto input_radix2 = input_copy; + ForwardTransformToBitReverseRadix2( + input_radix2.data(), input_radix2.data(), N, modulus, + 
ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); + + AssertEqual(input_radix2, exp_output); + + // In-place Inv Radix2 + InverseTransformFromBitReverseRadix2( + input_radix2.data(), input_radix2.data(), N, modulus, + ntt.GetInvRootOfUnityPowers().data(), ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); + AssertEqual(input_radix2, input_copy); + + // Out-of-place Fwd Radix2 + input_radix2 = input_copy; + ForwardTransformToBitReverseRadix2(out_buffer.data(), input_radix2.data(), N, + modulus, ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), + 2, 1); + + AssertEqual(out_buffer, exp_output); + + // Out-of-place Inv Radix2 InverseTransformFromBitReverseRadix2( - input.data(), N, modulus, ntt.GetInvRootOfUnityPowers().data(), + input_radix2.data(), out_buffer.data(), N, modulus, + ntt.GetInvRootOfUnityPowers().data(), ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); + + AssertEqual(input_radix2, input_copy); + + // In-place Fwd Radix4 + auto input_radix4 = input_copy; + ForwardTransformToBitReverseRadix4( + input_radix4.data(), input_radix4.data(), N, modulus, + ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); + + AssertEqual(input_radix4, exp_output); + + // In-place Inv Radix4 + InverseTransformFromBitReverseRadix4( + input_radix4.data(), input_radix4.data(), N, modulus, + ntt.GetInvRootOfUnityPowers().data(), + ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); + + AssertEqual(input_radix4, input_copy); + + // Out-of-place Fwd Radix4 + input_radix4 = input_copy; + ForwardTransformToBitReverseRadix4(out_buffer.data(), input_radix4.data(), N, + modulus, ntt.GetRootOfUnityPowers().data(), + ntt.GetPrecon64RootOfUnityPowers().data(), + 2, 1); + + AssertEqual(out_buffer, exp_output); + + // Out-of-place Inv Radix4 + InverseTransformFromBitReverseRadix4( + input_radix4.data(), out_buffer.data(), N, modulus, + ntt.GetInvRootOfUnityPowers().data(), + 
ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); + + AssertEqual(input_radix4, input_copy); } INSTANTIATE_TEST_SUITE_P( @@ -362,7 +421,8 @@ TEST_P(NttNativeTest, ForwardRadix4Random) { auto input_radix4 = input; ForwardTransformToBitReverseRadix4( - input_radix4.data(), m_N, m_modulus, m_ntt.GetRootOfUnityPowers().data(), + input_radix4.data(), input_radix4.data(), m_N, m_modulus, + m_ntt.GetRootOfUnityPowers().data(), m_ntt.GetPrecon64RootOfUnityPowers().data(), 2, 1); ReferenceForwardTransformToBitReverse(input.data(), m_N, m_modulus, @@ -376,11 +436,12 @@ TEST_P(NttNativeTest, InverseRadix4Random) { auto input_radix4 = input; InverseTransformFromBitReverseRadix2( - input.data(), m_N, m_modulus, m_ntt.GetInvRootOfUnityPowers().data(), + input.data(), input.data(), m_N, m_modulus, + m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1); InverseTransformFromBitReverseRadix4( - input_radix4.data(), m_N, m_modulus, + input_radix4.data(), input_radix4.data(), m_N, m_modulus, m_ntt.GetInvRootOfUnityPowers().data(), m_ntt.GetPrecon64InvRootOfUnityPowers().data(), 2, 1);