3
3
#include " cp-algo/util/checkpoint.hpp"
4
4
#include " cp-algo/util/bump_alloc.hpp"
5
5
#include " cp-algo/util/simd.hpp"
6
- #include " cp-algo/math/common .hpp"
6
+ #include " cp-algo/math/combinatorics .hpp"
7
7
#include " cp-algo/number_theory/modint.hpp"
8
8
#include < ranges>
9
9
@@ -37,38 +37,37 @@ namespace cp_algo::math {
37
37
t = mod - t - 1 ;
38
38
y = t % 2 ? 1 : mod-1 ;
39
39
}
40
- int pw = 0 ;
40
+ auto pw = 32ull * (t + 1 ) ;
41
41
while (t > limit_reg) {
42
42
limit_odd = std::max (limit_odd, (t - 1 ) / 2 );
43
43
odd_args_per_block[(t - 1 ) / 2 / subblock].push_back ({int (i), (t - 1 ) / 2 });
44
44
t /= 2 ;
45
45
pw += t;
46
46
}
47
47
reg_args_per_block[t / subblock].push_back ({int (i), t});
48
- y *= bpow ( base ( 2 ), pw );
48
+ y *= pow_fixed< base, 2 >( int (pw % (mod - 1 )) );
49
49
}
50
50
checkpoint (" init" );
51
- uint32_t b2x32 = ( 1ULL << 32 ) % mod ;
51
+ base bi2x32 = pow_fixed<base, 2 >( 32 ). inv () ;
52
52
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
53
53
base fact = 1 ;
54
54
for (int b = 0 ; b <= limit; b += accum * block) {
55
55
u32x8 cur[accum];
56
56
static std::array<u32x8, subblock> prods[accum];
57
57
for (int z = 0 ; z < accum; z++) {
58
58
for (int j = 0 ; j < simd_size; j++) {
59
+ #pragma GCC diagnostic push
60
+ #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
59
61
cur[z][j] = uint32_t (b + z * block + j * subblock);
60
62
cur[z][j] = proj (cur[z][j]);
61
63
prods[z][0 ][j] = cur[z][j] + !cur[z][j];
62
- #pragma GCC diagnostic push
63
- #pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
64
- cur[z][j] = uint32_t (uint64_t (cur[z][j]) * b2x32 % mod);
64
+ prods[z][0 ][j] = uint32_t (uint64_t (prods[z][0 ][j]) * bi2x32.getr () % mod);
65
65
#pragma GCC diagnostic pop
66
66
}
67
67
}
68
68
for (int i = 1 ; i < block / simd_size; i++) {
69
69
for (int z = 0 ; z < accum; z++) {
70
70
cur[z] += step;
71
- cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
72
71
prods[z][i] = montgomery_mul (prods[z][i - 1 ], cur[z], mod, imod);
73
72
}
74
73
}
@@ -85,12 +84,12 @@ namespace cp_algo::math {
85
84
checkpoint (" mul ans" );
86
85
}
87
86
};
88
- uint32_t b2x33 = ( 1ULL << 33 ) % mod ;
89
- process (limit_reg, reg_args_per_block, b2x32, std::identity{ });
90
- process (limit_odd, odd_args_per_block, b2x33, []( uint32_t x) { return 2 * x + 1 ;} );
87
+ process (limit_reg, reg_args_per_block, 1 , std::identity{}) ;
88
+ process (limit_odd, odd_args_per_block, 2 , []( uint32_t x) { return 2 * x + 1 ; });
89
+ auto invs = bulk_invs<base>(res );
91
90
for (auto [i, x]: res | std::views::enumerate) {
92
91
if (args[i] >= mod / 2 ) {
93
- x = x. inv () ;
92
+ x = invs[i] ;
94
93
}
95
94
}
96
95
checkpoint (" inv ans" );
0 commit comments