Skip to content

Commit 6ef9873

Browse files
committed
checkpoints + better montgomery simd
1 parent a97ac75 commit 6ef9873

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

cp-algo/util/simd.hpp

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,16 @@ namespace cp_algo {
3838
};
3939
}
4040

41+
// Bitwise-ANDs x with uint32_t(-1), a scalar whose 32 bits are all set.
// NOTE(review): u64x4 is declared outside this view; presumably it is a
// vector of four 64-bit lanes, in which case this keeps the low 32 bits of
// each lane and zeroes the high 32 bits — confirm against the typedef.
[[gnu::always_inline]] inline u64x4 low32(u64x4 x) {
42+
return x & uint32_t(-1);
43+
}
44+
4145
// Reduction step on a u64x4: multiplies x by imod lane-wise as u32x8, adds
// that product times mod to x, and returns x >> 32.
// NOTE(review): presumably a Montgomery reduction where imod is the negated
// inverse of mod modulo 2^32 (the product is added, not subtracted) — confirm
// where imod is computed.
// NOTE(review): this text is a diff rendering; the line following a "NN-"
// marker appears to be the removed version and the line following a "NN+"
// marker the added version.
[[gnu::always_inline]] inline u64x4 montgomery_reduce(u64x4 x, uint32_t mod, uint32_t imod) {
4246
// x reinterpreted as u32x8 and multiplied element-wise by imod, then viewed
// as u64x4 again; presumably only the low 32 bits of each 64-bit lane are
// meaningful afterwards, which is why both branches below restrict to them.
auto x_ninv = u64x4(u32x8(x) * (u32x8() + imod));
4347
#ifdef __AVX2__
4448
// _mm256_mul_epu32 multiplies the low unsigned 32 bits of each 64-bit
// element, so no explicit masking of x_ninv is written here.
x += u64x4(_mm256_mul_epu32(__m256i(x_ninv), __m256i() + mod));
4549
#else
46-
x += x_ninv * mod;
50+
// Non-AVX2 path: mask x_ninv with low32 before the u64x4 multiply by mod.
x += low32(x_ninv) * mod;
4751
#endif
4852
return x >> 32;
4953
}
@@ -52,13 +56,13 @@ namespace cp_algo {
5256
#ifdef __AVX2__
5357
return montgomery_reduce(u64x4(_mm256_mul_epu32(__m256i(x), __m256i(y))), mod, imod);
5458
#else
55-
return montgomery_reduce(x * y, mod, imod);
59+
return montgomery_reduce(low32(x) * low32(y), mod, imod);
5660
#endif
5761
}
5862

5963
[[gnu::always_inline]] inline u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
60-
auto x0246 = u64x4(x) & uint32_t(-1);
61-
auto y0246 = u64x4(y) & uint32_t(-1);
64+
auto x0246 = u64x4(x);
65+
auto y0246 = u64x4(y);
6266
auto x1357 = u64x4(x) >> 32;
6367
auto y1357 = u64x4(y) >> 32;
6468
return u32x8(montgomery_mul(x0246, y0246, mod, imod)) |

verify/simd/many_facts.test.cpp

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
#define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#include <bits/stdc++.h>
5+
#define CP_ALGO_CHECKPOINT
6+
#include "cp-algo/util/checkpoint.hpp"
57
#include "blazingio/blazingio.min.hpp"
68
#include "cp-algo/util/simd.hpp"
79
#include "cp-algo/math/common.hpp"
@@ -25,6 +27,7 @@ void facts_inplace(vector<int> &args) {
2527
args_per_block[(mod - x - 1) / block].push_back(i);
2628
}
2729
}
30+
cp_algo::checkpoint("init");
2831
uint32_t b2x32 = (1ULL << 32) % mod;
2932
uint64_t fact = 1;
3033
const int accum = 4;
@@ -49,6 +52,7 @@ void facts_inplace(vector<int> &args) {
4952
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
5053
}
5154
}
55+
cp_algo::checkpoint("inner loop");
5256
for(int z = 0; z < accum; z++) {
5357
uint64_t bl = b + z * block;
5458
for(auto i: args_per_block[bl / block]) {
@@ -75,6 +79,7 @@ void facts_inplace(vector<int> &args) {
7579
fact = fact * prods[z].back()[j] % mod;
7680
}
7781
}
82+
cp_algo::checkpoint("write ans");
7883
}
7984
}
8085

@@ -83,8 +88,11 @@ void solve() {
8388
cin >> n;
8489
vector<int> args(n);
8590
for(auto &x : args) {cin >> x;}
91+
cp_algo::checkpoint("input read");
8692
facts_inplace(args);
8793
for(auto it: args) {cout << it << "\n";}
94+
cp_algo::checkpoint("output written");
95+
cp_algo::checkpoint<1>();
8896
}
8997

9098
signed main() {

0 commit comments

Comments
 (0)