Skip to content

Commit bc21b44

Browse files
committed
move factorials to library file
1 parent 0f96754 commit bc21b44

File tree

3 files changed

+107
-105
lines changed

3 files changed

+107
-105
lines changed

cp-algo/math/factorials.hpp

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
#ifndef CP_ALGO_MATH_FACTORIALS_HPP
2+
#define CP_ALGO_MATH_FACTORIALS_HPP
3+
#include "cp-algo/util/checkpoint.hpp"
4+
#include "cp-algo/util/bump_alloc.hpp"
5+
#include "cp-algo/util/simd.hpp"
6+
#include "cp-algo/math/common.hpp"
7+
#include "cp-algo/number_theory/modint.hpp"
8+
#include <ranges>
9+
10+
namespace cp_algo::math {
11+
template<bool use_bump_alloc = false, int maxn = 100'000>
12+
auto facts(auto const& args) {
13+
constexpr int max_mod = 1'000'000'000;
14+
constexpr int accum = 4;
15+
constexpr int simd_size = 8;
16+
constexpr int block = 1 << 18;
17+
constexpr int subblock = block / simd_size;
18+
using base = std::decay_t<decltype(args[0])>;
19+
static_assert(modint_type<base>, "Base type must be a modint type");
20+
using T = std::array<int, 2>;
21+
using alloc = std::conditional_t<use_bump_alloc,
22+
bump_alloc<T, 30 * maxn>,
23+
big_alloc<T>>;
24+
std::basic_string<T, std::char_traits<T>, alloc> odd_args_per_block[max_mod / subblock];
25+
std::basic_string<T, std::char_traits<T>, alloc> reg_args_per_block[max_mod / subblock];
26+
constexpr int limit_reg = max_mod / 64;
27+
int limit_odd = 0;
28+
29+
std::vector<base, big_alloc<base>> res(size(args), 1);
30+
const int mod = base::mod();
31+
const int imod = -math::inv2(mod);
32+
for(auto [i, xy]: std::views::zip(args, res) | std::views::enumerate) {
33+
auto [x, y] = xy;
34+
int t = x.getr();
35+
if(t >= mod / 2) {
36+
t = mod - t - 1;
37+
y = t % 2 ? 1 : mod-1;
38+
}
39+
int pw = 0;
40+
while(t > limit_reg) {
41+
limit_odd = std::max(limit_odd, (t - 1) / 2);
42+
odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2});
43+
t /= 2;
44+
pw += t;
45+
}
46+
reg_args_per_block[t / subblock].push_back({int(i), t});
47+
y *= bpow(base(2), pw);
48+
}
49+
checkpoint("init");
50+
uint32_t b2x32 = (1ULL << 32) % mod;
51+
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
52+
base fact = 1;
53+
for(int b = 0; b <= limit; b += accum * block) {
54+
u32x8 cur[accum];
55+
static std::array<u32x8, subblock> prods[accum];
56+
for(int z = 0; z < accum; z++) {
57+
for(int j = 0; j < simd_size; j++) {
58+
cur[z][j] = uint32_t(b + z * block + j * subblock);
59+
cur[z][j] = proj(cur[z][j]);
60+
prods[z][0][j] = cur[z][j] + !cur[z][j];
61+
#pragma GCC diagnostic push
62+
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
63+
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
64+
#pragma GCC diagnostic pop
65+
}
66+
}
67+
for(int i = 1; i < block / simd_size; i++) {
68+
for(int z = 0; z < accum; z++) {
69+
cur[z] += step;
70+
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
71+
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
72+
}
73+
}
74+
checkpoint("inner loop");
75+
for(int z = 0; z < accum; z++) {
76+
for(int j = 0; j < simd_size; j++) {
77+
int bl = b + z * block + j * subblock;
78+
for(auto [i, x]: args_per_block[bl / subblock]) {
79+
res[i] *= fact * prods[z][x - bl][j];
80+
}
81+
fact *= base(prods[z].back()[j]);
82+
}
83+
}
84+
checkpoint("mul ans");
85+
}
86+
};
87+
uint32_t b2x33 = (1ULL << 33) % mod;
88+
process(limit_reg, reg_args_per_block, b2x32, std::identity{});
89+
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;});
90+
for(auto [i, x]: res | std::views::enumerate) {
91+
if (args[i] >= mod / 2) {
92+
x = x.inv();
93+
}
94+
}
95+
checkpoint("inv ans");
96+
return res;
97+
}
98+
}
99+
#endif // CP_ALGO_MATH_FACTORIALS_HPP

cp-algo/util/simd.hpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -41,13 +41,9 @@ namespace cp_algo {
4141
[[gnu::always_inline]] inline u64x4 low32(u64x4 x) {
4242
return x & uint32_t(-1);
4343
}
44-
[[gnu::always_inline]] inline auto rotr(auto x) {
45-
return decltype(x)(__builtin_shufflevector(u32x8(x), u32x8(x), 1, 2, 3, 0, 5, 6, 7, 4));
44+
[[gnu::always_inline]] inline auto swap_bytes(auto x) {
45+
return decltype(x)(__builtin_shufflevector(u32x8(x), u32x8(x), 1, 0, 3, 2, 5, 4, 7, 6));
4646
}
47-
[[gnu::always_inline]] inline auto rotl(auto x) {
48-
return decltype(x)(__builtin_shufflevector(u32x8(x), u32x8(x), 3, 0, 1, 2, 7, 4, 5, 6));
49-
}
50-
5147
[[gnu::always_inline]] inline u64x4 montgomery_reduce(u64x4 x, uint32_t mod, uint32_t imod) {
5248
#ifdef __AVX2__
5349
auto x_ninv = u64x4(_mm256_mul_epu32(__m256i(x), __m256i() + imod));
@@ -56,7 +52,7 @@ namespace cp_algo {
5652
auto x_ninv = x * imod;
5753
x += low32(x_ninv) * mod;
5854
#endif
59-
return rotr(x);
55+
return swap_bytes(x);
6056
}
6157

6258
[[gnu::always_inline]] inline u64x4 montgomery_mul(u64x4 x, u64x4 y, uint32_t mod, uint32_t imod) {
@@ -68,7 +64,7 @@ namespace cp_algo {
6864
}
6965
[[gnu::always_inline]] inline u32x8 montgomery_mul(u32x8 x, u32x8 y, uint32_t mod, uint32_t imod) {
7066
return u32x8(montgomery_mul(u64x4(x), u64x4(y), mod, imod)) |
71-
u32x8(rotl(montgomery_mul(u64x4(rotr(x)), u64x4(rotr(y)), mod, imod)));
67+
u32x8(swap_bytes(montgomery_mul(u64x4(swap_bytes(x)), u64x4(swap_bytes(y)), mod, imod)));
7268
}
7369
[[gnu::always_inline]] inline dx4 rotate_right(dx4 x) {
7470
static constexpr u64x4 shuffler = {3, 0, 1, 2};

verify/simd/many_facts.test.cpp

Lines changed: 4 additions & 97 deletions
Original file line numberDiff line numberDiff line change
@@ -1,111 +1,18 @@
11
// @brief Many Factorials
22
#define PROBLEM "https://judge.yosupo.jp/problem/many_factorials"
33
#pragma GCC optimize("Ofast,unroll-loops")
4+
#define CP_ALGO_CHECKPOINT
45
#include <bits/stdc++.h>
5-
//#define CP_ALGO_CHECKPOINT
66
#include "blazingio/blazingio.min.hpp"
7-
#include "cp-algo/util/checkpoint.hpp"
8-
#include "cp-algo/util/simd.hpp"
9-
#include "cp-algo/util/bump_alloc.hpp"
10-
#include "cp-algo/math/common.hpp"
7+
#include "cp-algo/math/factorials.hpp"
118

129
using namespace std;
13-
using namespace cp_algo;
14-
15-
constexpr int mod = 998244353;
16-
constexpr int imod = -math::inv2(mod);
17-
18-
template<bool use_bump_alloc = false, int maxn = 100'000>
19-
vector<int> facts(vector<int> const& args) {
20-
constexpr int accum = 4;
21-
constexpr int simd_size = 8;
22-
constexpr int block = 1 << 18;
23-
constexpr int subblock = block / simd_size;
24-
using T = array<int, 2>;
25-
using alloc = conditional_t<use_bump_alloc,
26-
bump_alloc<T, 30 * maxn>,
27-
allocator<T>>;
28-
basic_string<T, char_traits<T>, alloc> odd_args_per_block[mod / subblock];
29-
basic_string<T, char_traits<T>, alloc> reg_args_per_block[mod / subblock];
30-
constexpr int limit_reg = mod / 64;
31-
int limit_odd = 0;
32-
33-
vector<int> res(size(args), 1);
34-
auto prod_mod = [&](uint64_t a, uint64_t b) {
35-
return (a * b) % mod;
36-
};
37-
for(auto [i, xy]: views::zip(args, res) | views::enumerate) {
38-
auto [x, y] = xy;
39-
auto t = x;
40-
if(t >= mod / 2) {
41-
t = mod - t - 1;
42-
y = t % 2 ? 1 : mod - 1;
43-
}
44-
int pw = 0;
45-
while(t > limit_reg) {
46-
limit_odd = max(limit_odd, (t - 1) / 2);
47-
odd_args_per_block[(t - 1) / 2 / subblock].push_back({int(i), (t - 1) / 2});
48-
t /= 2;
49-
pw += t;
50-
}
51-
reg_args_per_block[t / subblock].push_back({int(i), t});
52-
y = int(y * math::bpow(2, pw, 1ULL, prod_mod) % mod);
53-
}
54-
cp_algo::checkpoint("init");
55-
uint32_t b2x32 = (1ULL << 32) % mod;
56-
auto process = [&](int limit, auto &args_per_block, auto step, auto &&proj) {
57-
uint64_t fact = 1;
58-
for(int b = 0; b <= limit; b += accum * block) {
59-
u32x8 cur[accum];
60-
static array<u32x8, subblock> prods[accum];
61-
for(int z = 0; z < accum; z++) {
62-
for(int j = 0; j < simd_size; j++) {
63-
cur[z][j] = uint32_t(b + z * block + j * subblock);
64-
cur[z][j] = proj(cur[z][j]);
65-
prods[z][0][j] = cur[z][j] + !cur[z][j];
66-
#pragma GCC diagnostic push
67-
#pragma GCC diagnostic ignored "-Wmaybe-uninitialized"
68-
cur[z][j] = uint32_t(uint64_t(cur[z][j]) * b2x32 % mod);
69-
#pragma GCC diagnostic pop
70-
}
71-
}
72-
for(int i = 1; i < block / simd_size; i++) {
73-
for(int z = 0; z < accum; z++) {
74-
cur[z] += step;
75-
cur[z] = cur[z] >= mod ? cur[z] - mod : cur[z];
76-
prods[z][i] = montgomery_mul(prods[z][i - 1], cur[z], mod, imod);
77-
}
78-
}
79-
cp_algo::checkpoint("inner loop");
80-
for(int z = 0; z < accum; z++) {
81-
for(int j = 0; j < simd_size; j++) {
82-
int bl = b + z * block + j * subblock;
83-
for(auto [i, x]: args_per_block[bl / subblock]) {
84-
auto ans = fact * prods[z][x - bl][j] % mod;
85-
res[i] = int(res[i] * ans % mod);
86-
}
87-
fact = fact * prods[z].back()[j] % mod;
88-
}
89-
}
90-
cp_algo::checkpoint("mul ans");
91-
}
92-
};
93-
uint32_t b2x33 = (1ULL << 33) % mod;
94-
process(limit_reg, reg_args_per_block, b2x32, identity{});
95-
process(limit_odd, odd_args_per_block, b2x33, [](uint32_t x) {return 2 * x + 1;});
96-
for(auto [i, x]: res | views::enumerate) {
97-
if (args[i] >= mod / 2) {
98-
x = int(math::bpow(x, mod - 2, 1ULL, prod_mod));
99-
}
100-
}
101-
cp_algo::checkpoint("inv ans");
102-
return res;
103-
}
10+
using base = cp_algo::math::modint<998244353>;
10411

10512
void solve() {
10613
int n;
10714
cin >> n;
108-
vector<int> args(n);
15+
vector<base> args(n);
10916
for(auto &x : args) {cin >> x;}
11017
cp_algo::checkpoint("read");
11118
auto res = facts(args);

0 commit comments

Comments
 (0)