|
| 1 | +#ifndef ATCODER_CONVOLUTION_HPP |
| 2 | +#define ATCODER_CONVOLUTION_HPP 1 |
| 3 | + |
| 4 | +#include <algorithm> |
| 5 | +#include <array> |
| 6 | +#include <atcoder/internal_bit> |
| 7 | +#include <atcoder/modint> |
| 8 | +#include <cassert> |
| 9 | +#include <type_traits> |
| 10 | +#include <vector> |
| 11 | + |
| 12 | +namespace atcoder { |
| 13 | + |
| 14 | +namespace internal { |
| 15 | + |
| 16 | +template <class mint, internal::is_static_modint_t<mint>* = nullptr> |
| 17 | +void butterfly(std::vector<mint>& a) { |
| 18 | + static constexpr int g = internal::primitive_root<mint::mod()>; |
| 19 | + int n = int(a.size()); |
| 20 | + int h = internal::ceil_pow2(n); |
| 21 | + |
| 22 | + static bool first = true; |
| 23 | + static mint sum_e[30]; // sum_e[i] = ies[0] * ... * ies[i - 1] * es[i] |
| 24 | + if (first) { |
| 25 | + first = false; |
| 26 | + mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 |
| 27 | + int cnt2 = bsf(mint::mod() - 1); |
| 28 | + mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); |
| 29 | + for (int i = cnt2; i >= 2; i--) { |
| 30 | + // e^(2^i) == 1 |
| 31 | + es[i - 2] = e; |
| 32 | + ies[i - 2] = ie; |
| 33 | + e *= e; |
| 34 | + ie *= ie; |
| 35 | + } |
| 36 | + mint now = 1; |
| 37 | + for (int i = 0; i < cnt2 - 2; i++) { |
| 38 | + sum_e[i] = es[i] * now; |
| 39 | + now *= ies[i]; |
| 40 | + } |
| 41 | + } |
| 42 | + for (int ph = 1; ph <= h; ph++) { |
| 43 | + int w = 1 << (ph - 1), p = 1 << (h - ph); |
| 44 | + mint now = 1; |
| 45 | + for (int s = 0; s < w; s++) { |
| 46 | + int offset = s << (h - ph + 1); |
| 47 | + for (int i = 0; i < p; i++) { |
| 48 | + auto l = a[i + offset]; |
| 49 | + auto r = a[i + offset + p] * now; |
| 50 | + a[i + offset] = l + r; |
| 51 | + a[i + offset + p] = l - r; |
| 52 | + } |
| 53 | + now *= sum_e[bsf(~(unsigned int)(s))]; |
| 54 | + } |
| 55 | + } |
| 56 | +} |
| 57 | + |
| 58 | +template <class mint, internal::is_static_modint_t<mint>* = nullptr> |
| 59 | +void butterfly_inv(std::vector<mint>& a) { |
| 60 | + static constexpr int g = internal::primitive_root<mint::mod()>; |
| 61 | + int n = int(a.size()); |
| 62 | + int h = internal::ceil_pow2(n); |
| 63 | + |
| 64 | + static bool first = true; |
| 65 | + static mint sum_ie[30]; // sum_ie[i] = es[0] * ... * es[i - 1] * ies[i] |
| 66 | + if (first) { |
| 67 | + first = false; |
| 68 | + mint es[30], ies[30]; // es[i]^(2^(2+i)) == 1 |
| 69 | + int cnt2 = bsf(mint::mod() - 1); |
| 70 | + mint e = mint(g).pow((mint::mod() - 1) >> cnt2), ie = e.inv(); |
| 71 | + for (int i = cnt2; i >= 2; i--) { |
| 72 | + // e^(2^i) == 1 |
| 73 | + es[i - 2] = e; |
| 74 | + ies[i - 2] = ie; |
| 75 | + e *= e; |
| 76 | + ie *= ie; |
| 77 | + } |
| 78 | + mint now = 1; |
| 79 | + for (int i = 0; i < cnt2 - 2; i++) { |
| 80 | + sum_ie[i] = ies[i] * now; |
| 81 | + now *= es[i]; |
| 82 | + } |
| 83 | + } |
| 84 | + |
| 85 | + for (int ph = h; ph >= 1; ph--) { |
| 86 | + int w = 1 << (ph - 1), p = 1 << (h - ph); |
| 87 | + mint inow = 1; |
| 88 | + for (int s = 0; s < w; s++) { |
| 89 | + int offset = s << (h - ph + 1); |
| 90 | + for (int i = 0; i < p; i++) { |
| 91 | + auto l = a[i + offset]; |
| 92 | + auto r = a[i + offset + p]; |
| 93 | + a[i + offset] = l + r; |
| 94 | + a[i + offset + p] = |
| 95 | + (unsigned long long)(mint::mod() + l.val() - r.val()) * |
| 96 | + inow.val(); |
| 97 | + } |
| 98 | + inow *= sum_ie[bsf(~(unsigned int)(s))]; |
| 99 | + } |
| 100 | + } |
| 101 | +} |
| 102 | + |
| 103 | +} // namespace internal |
| 104 | + |
| 105 | +template <class mint, internal::is_static_modint_t<mint>* = nullptr> |
| 106 | +std::vector<mint> convolution(std::vector<mint> a, std::vector<mint> b) { |
| 107 | + int n = int(a.size()), m = int(b.size()); |
| 108 | + if (!n || !m) return {}; |
| 109 | + if (std::min(n, m) <= 60) { |
| 110 | + if (n < m) { |
| 111 | + std::swap(n, m); |
| 112 | + std::swap(a, b); |
| 113 | + } |
| 114 | + std::vector<mint> ans(n + m - 1); |
| 115 | + for (int i = 0; i < n; i++) { |
| 116 | + for (int j = 0; j < m; j++) { |
| 117 | + ans[i + j] += a[i] * b[j]; |
| 118 | + } |
| 119 | + } |
| 120 | + return ans; |
| 121 | + } |
| 122 | + int z = 1 << internal::ceil_pow2(n + m - 1); |
| 123 | + a.resize(z); |
| 124 | + internal::butterfly(a); |
| 125 | + b.resize(z); |
| 126 | + internal::butterfly(b); |
| 127 | + for (int i = 0; i < z; i++) { |
| 128 | + a[i] *= b[i]; |
| 129 | + } |
| 130 | + internal::butterfly_inv(a); |
| 131 | + a.resize(n + m - 1); |
| 132 | + mint iz = mint(z).inv(); |
| 133 | + for (int i = 0; i < n + m - 1; i++) a[i] *= iz; |
| 134 | + return a; |
| 135 | +} |
| 136 | + |
| 137 | +template <unsigned int mod = 998244353, |
| 138 | + class T, |
| 139 | + std::enable_if_t<internal::is_integral<T>::value>* = nullptr> |
| 140 | +std::vector<T> convolution(const std::vector<T>& a, const std::vector<T>& b) { |
| 141 | + int n = int(a.size()), m = int(b.size()); |
| 142 | + if (!n || !m) return {}; |
| 143 | + |
| 144 | + using mint = static_modint<mod>; |
| 145 | + std::vector<mint> a2(n), b2(m); |
| 146 | + for (int i = 0; i < n; i++) { |
| 147 | + a2[i] = mint(a[i]); |
| 148 | + } |
| 149 | + for (int i = 0; i < m; i++) { |
| 150 | + b2[i] = mint(b[i]); |
| 151 | + } |
| 152 | + auto c2 = convolution(move(a2), move(b2)); |
| 153 | + std::vector<T> c(n + m - 1); |
| 154 | + for (int i = 0; i < n + m - 1; i++) { |
| 155 | + c[i] = c2[i].val(); |
| 156 | + } |
| 157 | + return c; |
| 158 | +} |
| 159 | + |
| 160 | +std::vector<long long> convolution_ll(const std::vector<long long>& a, |
| 161 | + const std::vector<long long>& b) { |
| 162 | + int n = int(a.size()), m = int(b.size()); |
| 163 | + if (!n || !m) return {}; |
| 164 | + |
| 165 | + static constexpr unsigned long long MOD1 = 754974721; // 2^24 |
| 166 | + static constexpr unsigned long long MOD2 = 167772161; // 2^25 |
| 167 | + static constexpr unsigned long long MOD3 = 469762049; // 2^26 |
| 168 | + static constexpr unsigned long long M2M3 = MOD2 * MOD3; |
| 169 | + static constexpr unsigned long long M1M3 = MOD1 * MOD3; |
| 170 | + static constexpr unsigned long long M1M2 = MOD1 * MOD2; |
| 171 | + static constexpr unsigned long long M1M2M3 = MOD1 * MOD2 * MOD3; |
| 172 | + |
| 173 | + static constexpr unsigned long long i1 = |
| 174 | + internal::inv_gcd(MOD2 * MOD3, MOD1).second; |
| 175 | + static constexpr unsigned long long i2 = |
| 176 | + internal::inv_gcd(MOD1 * MOD3, MOD2).second; |
| 177 | + static constexpr unsigned long long i3 = |
| 178 | + internal::inv_gcd(MOD1 * MOD2, MOD3).second; |
| 179 | + |
| 180 | + auto c1 = convolution<MOD1>(a, b); |
| 181 | + auto c2 = convolution<MOD2>(a, b); |
| 182 | + auto c3 = convolution<MOD3>(a, b); |
| 183 | + |
| 184 | + std::vector<long long> c(n + m - 1); |
| 185 | + for (int i = 0; i < n + m - 1; i++) { |
| 186 | + unsigned long long x = 0; |
| 187 | + x += (c1[i] * i1) % MOD1 * M2M3; |
| 188 | + x += (c2[i] * i2) % MOD2 * M1M3; |
| 189 | + x += (c3[i] * i3) % MOD3 * M1M2; |
| 190 | + // B = 2^63, -B <= x, r(real value) < B |
| 191 | + // (x, x - M, x - 2M, or x - 3M) = r (mod 2B) |
| 192 | + // r = c1[i] (mod MOD1) |
| 193 | + // focus on MOD1 |
| 194 | + // r = x, x - M', x - 2M', x - 3M' (M' = M % 2^64) (mod 2B) |
| 195 | + // r = x, |
| 196 | + // x - M' + (0 or 2B), |
| 197 | + // x - 2M' + (0, 2B or 4B), |
| 198 | + // x - 3M' + (0, 2B, 4B or 6B) (without mod!) |
| 199 | + // (r - x) = 0, (0) |
| 200 | + // - M' + (0 or 2B), (1) |
| 201 | + // -2M' + (0 or 2B or 4B), (2) |
| 202 | + // -3M' + (0 or 2B or 4B or 6B) (3) (mod MOD1) |
| 203 | + // we checked that |
| 204 | + // ((1) mod MOD1) mod 5 = 2 |
| 205 | + // ((2) mod MOD1) mod 5 = 3 |
| 206 | + // ((3) mod MOD1) mod 5 = 4 |
| 207 | + long long diff = |
| 208 | + c1[i] - internal::safe_mod((long long)(x), (long long)(MOD1)); |
| 209 | + if (diff < 0) diff += MOD1; |
| 210 | + static constexpr unsigned long long offset[5] = { |
| 211 | + 0, 0, M1M2M3, 2 * M1M2M3, 3 * M1M2M3}; |
| 212 | + x -= offset[diff % 5]; |
| 213 | + c[i] = x; |
| 214 | + } |
| 215 | + |
| 216 | + return c; |
| 217 | +} |
| 218 | + |
| 219 | +} // namespace atcoder |
| 220 | + |
| 221 | +#endif // ATCODER_CONVOLUTION_HPP |
0 commit comments