|
1 | 1 | // @brief Many Factorials
|
2 | 2 | #define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
|
3 | 3 | #pragma GCC optimize("Ofast,unroll-loops")
|
| 4 | +#define CP_ALGO_CHECKPOINT |
4 | 5 | #include <bits/stdc++.h>
|
5 |
| -//#define CP_ALGO_CHECKPOINT |
6 | 6 | #include "blazingio/blazingio.min.hpp"
|
7 |
| -#include "cp-algo/util/checkpoint.hpp" |
8 |
| -#include "cp-algo/util/simd.hpp" |
9 |
| -#include "cp-algo/util/bump_alloc.hpp" |
10 |
| -#include "cp-algo/math/common.hpp" |
| 7 | +#include "cp-algo/math/factorials.hpp" |
11 | 8 |
|
12 | 9 | using namespace std;
|
13 |
| -using namespace cp_algo; |
14 |
| - |
15 |
// Modulus used for every modular product in this file.
| -constexpr int mod = 998244353; |
16 |
// Negated result of math::inv2(mod); passed as the last argument of montgomery_mul.
// NOTE(review): presumably the inverse of mod modulo a power of two, as Montgomery
// reduction needs - confirm against the definition of math::inv2.
| -constexpr int imod = -math::inv2(mod); |
17 |
| - |
18 |
// facts: for every entry of args, computes the factorial of that entry modulo mod
// and returns the results in input order (res[i] belongs to args[i]).
//
// Template parameters:
//   use_bump_alloc - when true the per-bucket query lists use bump_alloc<T, 30 * maxn>,
//                    otherwise std::allocator<T>.
//   maxn           - only used to size that bump allocator; presumably an upper bound
//                    on args.size() - TODO confirm with callers.
//
// Outline of what the body does:
//   1. An argument t >= mod / 2 is reflected to mod - t - 1; the result is seeded with
//      a sign (1 or mod - 1) and raised to the power mod - 2 at the very end.
//   2. While t > limit_reg the argument is halved, using
//        t! = 2^(t/2) * (t/2)! * (1 * 3 * 5 * ... over odd numbers <= t),
//      so the odd part is queued as query (t - 1) / 2 and the exponent of two grows.
//   3. The remaining t is queued as a plain factorial query.
//   4. process() sweeps the value range once per query family, keeping running
//      products in SIMD lanes and multiplying the matching prefix into each query.
| -template<bool use_bump_alloc = false, int maxn = 100'000> |
19 |
| -vector<int> facts(vector<int> const& args) { |
20 |
// Number of independent blocks (product chains) advanced together per outer iteration.
| - constexpr int accum = 4; |
21 |
// Lanes per SIMD vector; matches the u32x8 type used in process().
| - constexpr int simd_size = 8; |
22 |
// A block of 2^18 consecutive values is split into simd_size subblocks, one per lane.
| - constexpr int block = 1 << 18; |
23 |
| - constexpr int subblock = block / simd_size; |
24 |
// Query record: {index into res, value whose prefix product is requested}.
| - using T = array<int, 2>; |
25 |
| - using alloc = conditional_t<use_bump_alloc, |
26 |
| - bump_alloc<T, 30 * maxn>, |
27 |
| - allocator<T>>; |
28 |
// Queries bucketed by value / subblock. odd_* buckets feed the sweep over odd numbers,
// reg_* buckets feed the sweep over consecutive integers.
| - basic_string<T, char_traits<T>, alloc> odd_args_per_block[mod / subblock]; |
29 |
| - basic_string<T, char_traits<T>, alloc> reg_args_per_block[mod / subblock]; |
30 |
// Largest value answered by the plain sweep; larger arguments are halved first.
| - constexpr int limit_reg = mod / 64; |
31 |
// Largest odd-family query seen so far; bounds the second sweep.
| - int limit_odd = 0; |
32 |
| - |
33 |
// Every answer starts at 1 and is multiplied by partial factors as they become known.
| - vector<int> res(size(args), 1); |
34 |
// 64-bit modular product, used as the multiplication callback for math::bpow.
| - auto prod_mod = [&](uint64_t a, uint64_t b) { |
35 |
| - return (a * b) % mod; |
36 |
| - }; |
37 |
// Decompose every argument into queued queries; y aliases res[i] through the zip view.
| - for(auto [i, xy]: views::zip(args, res) | views::enumerate) { |
38 |
| - auto [x, y] = xy; |
39 |
| - auto t = x; |
40 |
// Reflect the upper half of the range and remember the sign in y; the matching
// inversion happens after both sweeps.
| - if(t >= mod / 2) { |
41 |
| - t = mod - t - 1; |
42 |
| - y = t % 2 ? 1 : mod - 1; |
43 |
| - } |
44 |
// pw accumulates the exponent of two split off by the halving steps.
| - int pw = 0; |
45 |
| - while(t > limit_reg) { |
46 |
| - limit_odd = max(limit_odd, (t - 1) / 2); |
47 |
// (t - 1) / 2 is the index k of the largest odd number 2k + 1 that is <= t.
| - odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2}); |
48 |
| - t /= 2; |
49 |
| - pw += t; |
50 |
| - } |
51 |
| - reg_args_per_block[t / subblock].push_back({int(i), t}); |
52 |
// Fold the collected power of two into the answer right away.
| - y = int(y * math::bpow(2, pw, 1ULL, prod_mod) % mod); |
53 |
| - } |
54 |
| - cp_algo::checkpoint("init"); |
55 |
// 2^32 mod mod. NOTE(review): presumably the Montgomery representation of 1, assuming
// montgomery_mul divides by 2^32 - confirm against its definition.
| - uint32_t b2x32 = (1ULL << 32) % mod; |
56 |
// process: one sweep over values 0..limit.
//   limit          - largest queried value for this family.
//   args_per_block - query buckets indexed by value / subblock.
//   step           - amount added to every lane per inner iteration (already scaled).
//   proj           - maps a sweep index to the number actually multiplied in.
| - auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) { |
57 |
// Product of all subblocks that have been completely processed so far.
| - uint64_t fact = 1; |
58 |
| - for(int b = 0; b <= limit; b += accum * block) { |
59 |
| - u32x8 cur[accum]; |
60 |
// Static storage duration keeps this large table off the stack.
| - static array<u32x8, subblock> prods[accum]; |
61 |
// Seed every lane with the first value of its subblock.
| - for(int z = 0; z < accum; z++) { |
62 |
| - for(int j = 0; j < simd_size; j++) { |
63 |
| - cur[z][j] = uint32_t(b + z * block + j * subblock); |
64 |
| - cur[z][j] = proj(cur[z][j]); |
65 |
// A zero starting value is replaced by 1 so it does not wipe out the lane product.
| - prods[z][0][j] = cur[z][j] + !cur[z][j]; |
66 |
// Scale the lane value by 2^32 mod mod so it can be fed to montgomery_mul.
| - #pragma GCC diagnostic push |
67 |
| - #pragma GCC diagnostic ignored "-Wmaybe-uninitialized" |
68 |
| - cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod); |
69 |
| - #pragma GCC diagnostic pop |
70 |
| - } |
71 |
| - } |
72 |
// Advance all lanes in lock step; prods[z][i] holds the running product of the
// first i + 1 values of each lane subblock.
| - for(int i = 1; i < block / simd_size; i++) { |
73 |
| - for(int z = 0; z < accum; z++) { |
74 |
| - cur[z] += step; |
75 |
// Conditional subtraction keeps each lane below mod after the addition.
| - cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z]; |
76 |
| - prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod); |
77 |
| - } |
78 |
| - } |
79 |
| - cp_algo::checkpoint("inner loop"); |
80 |
// Walk the subblocks in increasing value order, answer their queries, then extend
// the carried product by the whole subblock.
| - for(int z = 0; z < accum; z++) { |
81 |
| - for(int j = 0; j < simd_size; j++) { |
82 |
| - int bl = b + z * block + j * subblock; |
83 |
| - for(auto [i, x]: args_per_block[bl / subblock]) { |
84 |
// x - bl is the offset of the query inside this lane subblock.
| - auto ans = fact * prods[z][x - bl][j] % mod; |
85 |
| - res[i] = int(res[i] * ans % mod); |
86 |
| - } |
87 |
| - fact = fact * prods[z].back()[j] % mod; |
88 |
| - } |
89 |
| - } |
90 |
| - cp_algo::checkpoint("mul ans"); |
91 |
| - } |
92 |
| - }; |
93 |
// 2^33 mod mod: twice the step of the plain sweep, because the odd sweep visits 2x + 1.
| - uint32_t b2x33 = (1ULL << 33) % mod; |
94 |
// First sweep: consecutive integers. Second sweep: odd numbers only.
| - process(limit_reg, reg_args_per_block, b2x32, identity{}); |
95 |
| - process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;}); |
96 |
// Arguments that were reflected above are finished by raising the accumulated value
// to the power mod - 2 (a modular inverse when mod is prime).
| - for(auto [i, x]: res | views::enumerate) { |
97 |
| - if (args[i] >= mod / 2) { |
98 |
| - x = int(math::bpow(x, mod - 2, 1ULL, prod_mod)); |
99 |
| - } |
100 |
| - } |
101 |
| - cp_algo::checkpoint("inv ans"); |
102 |
| - return res; |
103 |
| -} |
// Element type of the argument vector read in solve().
// NOTE(review): presumably a modular integer modulo 998244353 - confirm against cp_algo::math::modint.
| 10 | +using base = cp_algo::math::modint<998244353>; |
104 | 11 |
|
105 | 12 | void solve() {
|
106 | 13 | int n;
|
107 | 14 | cin >> n;
|
108 |
| - vector<int> args(n); |
| 15 | + vector<base> args(n); |
109 | 16 | for(auto &x : args) {cin >> x;}
|
110 | 17 | cp_algo::checkpoint("read");
|
111 | 18 | auto res = facts(args);
|
|
0 commit comments