Skip to content

Commit 7596d34

Browse files
committed
add bulk_invs and pow_fixed helpers
1 parent a03e963 commit 7596d34

File tree

2 files changed

+42
-19
lines changed

2 files changed

+42
-19
lines changed

cp-algo/math/combinatorics.hpp

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -33,17 +33,41 @@ namespace cp_algo::math {
3333
}
3434
return F[n];
3535
}
36-
template<typename T>
37-
T small_inv(auto n) {
38-
static std::vector<T> F(maxn);
36+
// pow_fixed<T, base>(n): returns base^n as a T using two lookup tables of
// 2^16 entries each, so a query costs one multiplication:
//   prec_low[i]  = base^i
//   prec_high[i] = (base^(2^16))^i
//   base^n       = prec_low[n mod 2^16] * prec_high[n div 2^16]
// The tables are function-local statics filled on the first call for each
// (T, base) instantiation.
// NOTE(review): n is presumably non-negative; a negative n would index
// prec_high with a negative value -- confirm against callers.
// NOTE(review): in this diff view the lines marked '-' below (the loop over
// maxn and `return F[n]`) are the removed body of the old small_inv, not part
// of pow_fixed.
template<typename T, int base>
37+
T pow_fixed(int n) {
38+
static std::vector<T> prec_low(1 << 16);
39+
static std::vector<T> prec_high(1 << 16);
3940
static bool init = false;
4041
if(!init) {
41-
for(int i = 1; i < maxn; i++) {
42-
F[i] = rfact<T>(i) * fact<T>(i - 1);
43-
}
4442
init = true;
43+
prec_low[0] = prec_high[0] = T(1);
44+
T step_low = T(base);
45+
// presumably bpow(x, k) computes x^k, making step_high = base^(2^16) -- TODO confirm
T step_high = bpow(T(base), 1 << 16);
46+
// Each table entry is the previous one times its step: low advances by
// base, high advances by step_high.
for(int i = 1; i < (1 << 16); i++) {
47+
prec_low[i] = prec_low[i - 1] * step_low;
48+
prec_high[i] = prec_high[i - 1] * step_high;
49+
}
4550
}
46-
return F[n];
51+
// Low 16 bits of n select from prec_low, the remaining high bits from prec_high.
return prec_low[n & 0xFFFF] * prec_high[n >> 16];
52+
}
53+
// Computes the multiplicative inverse of every element of `args` with a
// single division (batch inversion).
//
// T     - result element type; must support T * elem, T(1) / T and T *= elem.
// args  - indexable range (operator[] and std::size); every element must be
//         invertible in T, otherwise the single division below is by zero.
// Returns a vector `res` of the same length with res[i] == T(1) / args[i].
// An empty `args` yields an empty vector.
template<typename T, typename Range>
std::vector<T> bulk_invs(Range const& args) {
    auto const count = std::size(args);
    if(count == 0) {
        // Nothing to invert; also avoids args[0], res.back() and the
        // unsigned wrap-around of count - 1 below.
        return {};
    }
    // Forward pass: res[i] = args[0] * args[1] * ... * args[i].
    std::vector<T> res(count, args[0]);
    for(size_t i = 1; i < count; i++) {
        res[i] = res[i - 1] * args[i];
    }
    // The only division: inverse of the product of all elements.
    auto all_invs = T(1) / res.back();
    // Backward pass. Invariant at the top of each iteration:
    //   all_invs == 1 / (args[0] * ... * args[i]).
    for(size_t i = count - 1; i > 0; i--) {
        // (1 / prefix[i]) * prefix[i - 1] == 1 / args[i]
        res[i] = all_invs * res[i - 1];
        // Drop args[i] from the inverted product.
        all_invs *= args[i];
    }
    res[0] = all_invs;
    return res;
}
67+
template<typename T>
68+
T small_inv(auto n) {
69+
static auto F = builk_invs<T>(std::views::iota(1) | std::views::take(maxn));
70+
return F[n - 1];
4771
}
4872
template<typename T>
4973
T binom_large(T n, auto r) {

cp-algo/math/factorials.hpp

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
#include "cp-algo/util/checkpoint.hpp"
44
#include "cp-algo/util/bump_alloc.hpp"
55
#include "cp-algo/util/simd.hpp"
6-
#include "cp-algo/math/common.hpp"
6+
#include "cp-algo/math/combinatorics.hpp"
77
#include "cp-algo/number_theory/modint.hpp"
88
#include <ranges>
99

@@ -37,38 +37,37 @@ namespace cp_algo::math {
3737
t = mod - t - 1;
3838
y = t % 2 ? 1 : mod-1;
3939
}
40-
int pw = 0;
40+
auto pw = 32ull * (t + 1);
4141
while(t > limit_reg) {
4242
limit_odd = std::max(limit_odd, (t - 1) / 2);
4343
odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2});
4444
t /= 2;
4545
pw += t;
4646
}
4747
reg_args_per_block[t / subblock].push_back({int(i), t});
48-
y *= bpow(base(2), pw);
48+
y *= pow_fixed<base, 2>(int(pw % (mod - 1)));
4949
}
5050
checkpoint("init");
51-
uint32_t b2x32 = (1ULL << 32) % mod;
51+
base bi2x32 = pow_fixed<base, 2>(32).inv();
5252
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
5353
base fact = 1;
5454
for(int b = 0; b <= limit; b += accum * block) {
5555
u32x8 cur[accum];
5656
static std::array<u32x8, subblock> prods[accum];
5757
for(int z = 0; z < accum; z++) {
5858
for(int j = 0; j < simd_size; j++) {
59+
#pragma GCC diagnostic push
60+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
5961
cur[z][j] = uint32_t(b + z * block + j * subblock);
6062
cur[z][j] = proj(cur[z][j]);
6163
prods[z][0][j] = cur[z][j] + !cur[z][j];
62-
#pragma GCC diagnostic push
63-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
64-
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
64+
prods[z][0][j] = uint32_t(uint64_t(prods[z][0][j]) * bi2x32.getr() % mod);
6565
#pragma GCC diagnostic pop
6666
}
6767
}
6868
for(int i = 1; i < block / simd_size; i++) {
6969
for(int z = 0; z < accum; z++) {
7070
cur[z] += step;
71-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
7271
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
7372
}
7473
}
@@ -85,12 +84,12 @@ namespace cp_algo::math {
8584
checkpoint("mul ans");
8685
}
8786
};
88-
uint32_t b2x33 = (1ULL << 33) % mod;
89-
process(limit_reg, reg_args_per_block, b2x32, std::identity{});
90-
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;});
87+
process(limit_reg, reg_args_per_block, 1, std::identity{});
88+
process(limit_odd, odd_args_per_block, 2, [](uint32_t x) {return 2 * x + 1;});
89+
auto invs = bulk_invs<base>(res);
9190
for(auto [i, x]: res | std::views::enumerate) {
9291
if (args[i] >= mod / 2) {
93-
x = x.inv();
92+
x = invs[i];
9493
}
9594
}
9695
checkpoint("inv ans");

0 commit comments

Comments
 (0)