Skip to content

Commit d131d01

Browse files
committed
update many_facts
1 parent 6ef9873 commit d131d01

File tree

1 file changed

+75
-62
lines changed

1 file changed

+75
-62
lines changed

verify/simd/many_facts.test.cpp

Lines changed: 75 additions & 62 deletions
Original file line number | Diff line number | Diff line change
@@ -2,9 +2,9 @@
22
#define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
44
#include <bits/stdc++.h>
5-
#define CP_ALGO_CHECKPOINT
6-
#include "cp-algo/util/checkpoint.hpp"
5+
//#define CP_ALGO_CHECKPOINT
76
#include "blazingio/blazingio.min.hpp"
7+
#include "cp-algo/util/checkpoint.hpp"
88
#include "cp-algo/util/simd.hpp"
99
#include "cp-algo/math/common.hpp"
1010

@@ -14,84 +14,97 @@ using namespace cp_algo;
1414
constexpr int mod = 998244353;
1515
constexpr int imod = -math::inv2(mod);
1616

17-
void facts_inplace(vector<int> &args) {
18-
constexpr int block = 1 << 16;
19-
static basic_string<size_t> args_per_block[mod / block];
20-
uint64_t limit = 0;
21-
for(auto [i, x]: args | views::enumerate) {
22-
if(x < mod / 2) {
23-
limit = max(limit, uint64_t(x));
24-
args_per_block[x / block].push_back(i);
25-
} else {
26-
limit = max(limit, uint64_t(mod - x - 1));
27-
args_per_block[(mod - x - 1) / block].push_back(i);
17+
vector<int> facts(vector<int> const& args) {
18+
constexpr int accum = 4;
19+
constexpr int simd_size = 8;
20+
constexpr int block = 1 << 18;
21+
constexpr int subblock = block / simd_size;
22+
static basic_string<array<int, 2>> odd_args_per_block[mod / subblock];
23+
static basic_string<array<int, 2>> reg_args_per_block[mod / subblock];
24+
constexpr int limit_reg = mod / 64;
25+
int limit_odd = 0;
26+
27+
vector<int> res(size(args), 1);
28+
auto prod_mod = [&](uint64_t a, uint64_t b) {
29+
return (a * b) % mod;
30+
};
31+
for(auto [i, xy]: views::zip(args, res) | views::enumerate) {
32+
auto [x, y] = xy;
33+
auto t = x;
34+
if(t >= mod / 2) {
35+
t = mod - t - 1;
36+
y = t % 2 ? 1 : mod - 1;
37+
}
38+
int pw = 0;
39+
while(t > limit_reg) {
40+
limit_odd = max(limit_odd, (t - 1) / 2);
41+
odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2});
42+
t /= 2;
43+
pw += t;
2844
}
45+
reg_args_per_block[t / subblock].push_back({int(i), t});
46+
y = int(y * math::bpow(2, pw, 1ULL, prod_mod) % mod);
2947
}
3048
cp_algo::checkpoint("init");
3149
uint32_t b2x32 = (1ULL << 32) % mod;
32-
uint64_t fact = 1;
33-
const int accum = 4;
34-
const int simd_size = 8;
35-
for(uint64_t b = 0; b <= limit; b += accum * block) {
36-
u32x8 cur[accum];
37-
static array<u32x8, block / simd_size> prods[accum];
38-
for(int z = 0; z < accum; z++) {
39-
for(int j = 0; j < simd_size; j++) {
40-
cur[z][j] = uint32_t(b + z * block + j * (block / simd_size));
41-
prods[z][0][j] = cur[z][j] + !(b || z || j);
42-
#pragma GCC diagnostic push
43-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
44-
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
45-
#pragma GCC diagnostic pop
46-
}
47-
}
48-
for(int i = 1; i < block / simd_size; i++) {
50+
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
51+
uint64_t fact = 1;
52+
for(int b = 0; b <= limit; b += accum * block) {
53+
u32x8 cur[accum];
54+
static array<u32x8, subblock> prods[accum];
4955
for(int z = 0; z < accum; z++) {
50-
cur[z] += b2x32;
51-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
52-
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
53-
}
54-
}
55-
cp_algo::checkpoint("inner loop");
56-
for(int z = 0; z < accum; z++) {
57-
uint64_t bl = b + z * block;
58-
for(auto i: args_per_block[bl / block]) {
59-
size_t x = args[i];
60-
if(x >= mod / 2) {
61-
x = mod - x - 1;
62-
}
63-
x -= bl;
64-
auto pre_blocks = x / (block / simd_size);
65-
auto in_block = x % (block / simd_size);
66-
auto ans = fact * prods[z][in_block][pre_blocks] % mod;
67-
for(size_t j = 0; j < pre_blocks; j++) {
68-
ans = ans * prods[z].back()[j] % mod;
56+
for(int j = 0; j < simd_size; j++) {
57+
cur[z][j] = uint32_t(b + z * block + j * subblock);
58+
cur[z][j] = proj(cur[z][j]);
59+
prods[z][0][j] = cur[z][j] + !cur[z][j];
60+
#pragma GCC diagnostic push
61+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
62+
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
63+
#pragma GCC diagnostic pop
6964
}
70-
if(args[i] >= mod / 2) {
71-
ans = math::bpow(ans, mod - 2, 1ULL, [](auto a, auto b){return a * b % mod;});
72-
args[i] = int(x % 2 ? ans : mod - ans);
73-
} else {
74-
args[i] = int(ans);
65+
}
66+
for(int i = 1; i < block / simd_size; i++) {
67+
for(int z = 0; z < accum; z++) {
68+
cur[z] += step;
69+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
70+
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
7571
}
7672
}
77-
args_per_block[bl / block].clear();
78-
for(int j = 0; j < simd_size; j++) {
79-
fact = fact * prods[z].back()[j] % mod;
73+
cp_algo::checkpoint("inner loop");
74+
for(int z = 0; z < accum; z++) {
75+
for(int j = 0; j < simd_size; j++) {
76+
int bl = b + z * block + j * subblock;
77+
for(auto [i, x]: args_per_block[bl / subblock]) {
78+
auto ans = fact * prods[z][x - bl][j] % mod;
79+
res[i] = int(res[i] * ans % mod);
80+
}
81+
fact = fact * prods[z].back()[j] % mod;
82+
}
8083
}
84+
cp_algo::checkpoint("mul ans");
85+
}
86+
};
87+
uint32_t b2x33 = (1ULL << 33) % mod;
88+
process(limit_reg, reg_args_per_block, b2x32, identity{});
89+
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;});
90+
for(auto [i, x]: res | views::enumerate) {
91+
if (args[i] >= mod / 2) {
92+
x = int(math::bpow(x, mod - 2, 1ULL, prod_mod));
8193
}
82-
cp_algo::checkpoint("write ans");
8394
}
95+
cp_algo::checkpoint("inv ans");
96+
return res;
8497
}
8598

8699
void solve() {
87100
int n;
88101
cin >> n;
89102
vector<int> args(n);
90103
for(auto &x : args) {cin >> x;}
91-
cp_algo::checkpoint("input read");
92-
facts_inplace(args);
93-
for(auto it: args) {cout << it << "\n";}
94-
cp_algo::checkpoint("output written");
104+
cp_algo::checkpoint("read");
105+
auto res = facts(args);
106+
for(auto it: res) {cout << it << "\n";}
107+
cp_algo::checkpoint("write");
95108
cp_algo::checkpoint<1>();
96109
}
97110

0 commit comments

Comments
 (0)