Skip to content

Commit 4630bc2

Browse files
多项式
1 parent f861cf4 commit 4630bc2

File tree

2 files changed

+220
-75
lines changed

2 files changed

+220
-75
lines changed

数学/Fast-Fourier-Transform.cpp

Lines changed: 73 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,64 @@
11
namespace fft {
2-
const int N = 1 << 20, M = 31768;
2+
const int N = 1<<20, M = 31768;
33

44
struct Complex {
55
double x, y;
66

7-
Complex() { x = y = 0; }
7+
Complex () { x = y = 0; }
88

9-
Complex(double _x, double _y) { x = _x, y = _y; }
9+
Complex (double _x, double _y) { x = _x, y = _y; }
1010

11-
Complex operator+(const Complex &r) const {
12-
return Complex(x + r.x, y + r.y);
11+
Complex operator+ (const Complex &r) const {
12+
return Complex (x + r.x, y + r.y);
1313
}
1414

15-
Complex operator-(const Complex &r) const {
16-
return Complex(x - r.x, y - r.y);
15+
Complex operator- (const Complex &r) const {
16+
return Complex (x - r.x, y - r.y);
1717
}
1818

19-
Complex operator*(const double k) const {
20-
return Complex(x * k, y * k);
19+
Complex operator* (const double k) const {
20+
return Complex (x * k, y * k);
2121
}
2222

23-
Complex operator/(const double k) const {
24-
return Complex(x / k, y / k);
23+
Complex operator/ (const double k) const {
24+
return Complex (x / k, y / k);
2525
}
2626

27-
Complex operator*(const Complex &r) const {
28-
return Complex(x * r.x - y * r.y, x * r.y + y * r.x);
27+
Complex operator* (const Complex &r) const {
28+
return Complex (x * r.x - y * r.y, x * r.y + y * r.x);
2929
}
3030

31-
int operator=(const int a) {
32-
*this = Complex(a, 0);
31+
int operator= (const int a) {
32+
*this = Complex (a, 0);
3333
return a;
3434
}
3535

36-
Complex conj() const {
37-
return Complex(x, -y);
36+
Complex conj () const {
37+
return Complex (x, -y);
3838
}
3939
};
4040

41-
const double pi = acos(-1.0);
41+
const double pi = acos (-1.0);
4242
Complex w[N];
4343
int rev[N];
4444

45-
void init(int L) {
46-
int n = 1 << L;
45+
void init (int L) {
46+
int n = 1<<L;
4747
for (int i = 0; i < n; ++i) {
4848
double ang = 2 * pi * i / n;
49-
w[i] = Complex(cos(ang), sin(ang));
50-
rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
49+
w[i] = Complex (cos (ang), sin (ang));
50+
rev[i] = (rev[i>>1]>>1) | ((i & 1)<<(L - 1));
5151
}
5252
}
5353

54-
void trans(Complex P[], int n, int oper) {
54+
void trans (Complex P[], int n, int oper) {
5555
for (int i = 0; i < n; i++) {
5656
if (i < rev[i]) {
57-
std::swap(P[i], P[rev[i]]);
57+
std::swap (P[i], P[rev[i]]);
5858
}
5959
}
60-
for (int d = 0; (1 << d) < n; d++) {
61-
int m = 1 << d, m2 = m * 2, rm = n / m2;
60+
for (int d = 0; (1<<d) < n; d++) {
61+
int m = 1<<d, m2 = m * 2, rm = n / m2;
6262
for (int i = 0; i < n; i += m2) {
6363
for (int j = 0; j < m; j++) {
6464
Complex &P1 = P[i + j + m], &P2 = P[i + j];
@@ -77,91 +77,89 @@ namespace fft {
7777

7878
Complex A[N], B[N], C1[N], C2[N];
7979

80-
std::vector <ll> conv(const std::vector<int> &a, const std::vector<int> &b) {
81-
int n = a.size(), m = b.size(), L = 0, s = 1;
80+
std::vector<ll> conv (const std::vector<int> &a, const std::vector<int> &b) {
81+
int n = a.size (), m = b.size (), L = 0, s = 1;
8282
while (s <= n + m - 2) s <<= 1, ++L;
83-
init(L);
83+
init (L);
8484
for (int i = 0; i < s; ++i) {
85-
A[i] = i < n ? Complex(a[i], 0) : Complex();
86-
B[i] = i < m ? Complex(b[i], 0) : Complex();
85+
A[i] = i < n ? Complex (a[i], 0) : Complex ();
86+
B[i] = i < m ? Complex (b[i], 0) : Complex ();
8787
}
88-
trans(A, s, 1);
89-
trans(B, s, 1);
88+
trans (A, s, 1);
89+
trans (B, s, 1);
9090
for (int i = 0; i < s; ++i) {
9191
A[i] = A[i] * B[i];
9292
}
9393
for (int i = 0; i < s; ++i) {
94-
w[i] = w[i].conj();
94+
w[i] = w[i].conj ();
9595
}
96-
trans(A, s, -1);
97-
std::vector <ll> res(n + m - 1);
96+
trans (A, s, -1);
97+
std::vector<ll> res (n + m - 1);
9898
for (int i = 0; i < n + m - 1; ++i) {
99-
res[i] = (ll)(A[i].x + 0.5);
99+
res[i] = (ll) (A[i].x + 0.5);
100100
}
101101
return res;
102102
}
103103

104-
std::vector <ll> fast_conv(const std::vector<int> &a, const std::vector<int> &b) {
105-
int n = a.size(), m = b.size(), L = 0, s = 1;
104+
std::vector<ll> fast_conv (const std::vector<int> &a, const std::vector<int> &b) {
105+
int n = a.size (), m = b.size (), L = 0, s = 1;
106106
for (; s <= n + m - 2; s <<= 1, ++L);
107107
s >>= 1, --L;
108-
init(L);
108+
init (L);
109109
for (int i = 0; i < s; ++i) {
110-
A[i].x = (i << 1) < n ? a[i << 1] : 0;
111-
B[i].x = (i << 1) < m ? b[i << 1] : 0;
112-
A[i].y = (i << 1 | 1) < n ? a[i << 1 | 1] : 0;
113-
B[i].y = (i << 1 | 1) < m ? b[i << 1 | 1] : 0;
110+
A[i].x = (i<<1) < n ? a[i<<1] : 0;
111+
B[i].x = (i<<1) < m ? b[i<<1] : 0;
112+
A[i].y = (i<<1 | 1) < n ? a[i<<1 | 1] : 0;
113+
B[i].y = (i<<1 | 1) < m ? b[i<<1 | 1] : 0;
114114
}
115-
trans(A, s, 1);
116-
trans(B, s, 1);
115+
trans (A, s, 1);
116+
trans (B, s, 1);
117117
for (int i = 0; i < s; ++i) {
118118
int j = (s - i) & (s - 1);
119-
C1[i] = (Complex(4, 0) * (A[j] * B[j]).conj() -
120-
(A[j].conj() - A[i]) * (B[j].conj() - B[i]) * (w[i] + Complex(1, 0))) * Complex(0, 0.25);
119+
C1[i] = (Complex (4, 0) * (A[j] * B[j]).conj () -
120+
(A[j].conj () - A[i]) * (B[j].conj () - B[i]) * (w[i] + Complex (1, 0))) * Complex (0, 0.25);
121121
}
122-
std::reverse(w + 1, w + s);
123-
trans(C1, s, -1);
124-
std::vector <ll> res(n + m);
122+
std::reverse (w + 1, w + s);
123+
trans (C1, s, -1);
124+
std::vector<ll> res (n + m);
125125
for (int i = 0; i <= (n + m - 1) / 2; ++i) {
126-
res[i << 1] = ll(C1[i].y + 0.5);
127-
res[i << 1 | 1] = ll(C1[i].x + 0.5);
126+
res[i<<1] = ll (C1[i].y + 0.5);
127+
res[i<<1 | 1] = ll (C1[i].x + 0.5);
128128
}
129-
res.resize(n + m - 1);
129+
res.resize (n + m - 1);
130130
return res;
131131
}
132132

133133
// arbitrary modulo convolution
134-
// n,m = degree + 1
135-
// x^2 + 2x +1 => n = 3
136-
void conv(int a[], int b[], int n, int m, int mod, int res[]) {
134+
void conv (int a[], int b[], int n, int m, int mod, int res[]) {
137135
int s = 1, L = 0;
138136
while (s <= n + m - 2) s <<= 1, ++L;
139-
init(L);
137+
init (L);
140138
for (int i = 0; i < s; ++i) {
141-
A[i] = i < n ? Complex(a[i] / M, a[i] % M) : Complex();
142-
B[i] = i < m ? Complex(b[i] / M, b[i] % M) : Complex();
139+
A[i] = i < n ? Complex (a[i] / M, a[i] % M) : Complex ();
140+
B[i] = i < m ? Complex (b[i] / M, b[i] % M) : Complex ();
143141
}
144-
trans(A, s, 1);
145-
trans(B, s, 1);
142+
trans (A, s, 1);
143+
trans (B, s, 1);
146144
for (int i = 0; i < s; ++i) {
147145
int j = i ? s - i : i;
148-
Complex a1 = (A[i] + A[j].conj()) * Complex(0.5, 0);
149-
Complex a2 = (A[i] - A[j].conj()) * Complex(0, -0.5);
150-
Complex b1 = (B[i] + B[j].conj()) * Complex(0.5, 0);
151-
Complex b2 = (B[i] - B[j].conj()) * Complex(0, -0.5);
146+
Complex a1 = (A[i] + A[j].conj ()) * Complex (0.5, 0);
147+
Complex a2 = (A[i] - A[j].conj ()) * Complex (0, -0.5);
148+
Complex b1 = (B[i] + B[j].conj ()) * Complex (0.5, 0);
149+
Complex b2 = (B[i] - B[j].conj ()) * Complex (0, -0.5);
152150
Complex c11 = a1 * b1, c12 = a1 * b2;
153151
Complex c21 = a2 * b1, c22 = a2 * b2;
154-
C1[j] = c11 + c12 * Complex(0, 1);
155-
C2[j] = c21 + c22 * Complex(0, 1);
152+
C1[j] = c11 + c12 * Complex (0, 1);
153+
C2[j] = c21 + c22 * Complex (0, 1);
156154
}
157-
trans(C1, s, -1);
158-
trans(C2, s, -1);
155+
trans (C1, s, -1);
156+
trans (C2, s, -1);
159157
for (int i = 0; i < n + m - 1; ++i) {
160-
int x = ll(C1[i].x + 0.5) % mod;
161-
int y1 = ll(C1[i].y + 0.5) % mod;
162-
int y2 = ll(C2[i].x + 0.5) % mod;
163-
int z = ll(C2[i].y + 0.5) % mod;
164-
res[i] = ((ll) x * M * M + (ll)(y1 + y2) * M + z) % mod;
158+
int x = ll (C1[i].x + 0.5) % mod;
159+
int y1 = ll (C1[i].y + 0.5) % mod;
160+
int y2 = ll (C2[i].x + 0.5) % mod;
161+
int z = ll (C2[i].y + 0.5) % mod;
162+
res[i] = ((ll) x * M * M + (ll) (y1 + y2) * M + z) % mod;
165163
}
166164
}
167165
}

数学/Polynomial.cpp

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
using fft::N;
2+
3+
namespace Poly {
4+
int tmp1[N], tmp2[N], tmp3[N], mod = ;
5+
int iv[N];
6+
7+
void init () {
8+
iv[1] = 1;
9+
for (int i = 2; i < N; ++i) {
10+
iv[i] = (mod - (ll) mod / i * iv[mod % i] % mod);
11+
}
12+
}
13+
14+
15+
// res = 1 / poly
16+
void inv (int poly[], int res[], int n) {
17+
int deg = n - 1;
18+
std::vector<int> degs;
19+
while (deg > 0) {
20+
degs.push_back (deg);
21+
deg >>= 1;
22+
}
23+
std::reverse (degs.begin (), degs.end ());
24+
res[0] = inverse (poly[0], mod);
25+
deg = 1;
26+
for (int t: degs) {
27+
fft::conv (poly, res, t + 1, deg, mod, tmp1);
28+
fft::conv (res, tmp1 + deg, t + 1 - deg, t + 1 - deg, mod, tmp1);
29+
for (int i = 0; i < t + 1 - deg; ++i) {
30+
res[i + deg] = mod - tmp1[i];
31+
}
32+
deg = t + 1;
33+
}
34+
}
35+
36+
// res = ln(poly), poly[0] should be 1
37+
void log (int poly[], int res[], int n) {
38+
assert(poly[0] == 1);
39+
inv (poly, tmp2, n);
40+
for (int i = 0; i < n - 1; ++i) {
41+
res[i] = (ll) poly[i + 1] * (i + 1) % mod;
42+
}
43+
fft::conv (res, tmp2, n - 1, n, mod, res);
44+
for (int i = n - 1; i >= 1; --i) {
45+
res[i] = (ll) res[i - 1] * iv[i] % mod;
46+
}
47+
res[0] = 0;
48+
}
49+
50+
// res = exp(poly), poly[0] should be 0
51+
void exp (int poly[], int res[], int n) {
52+
assert(poly[0] == 0);
53+
while (n & n - 1)n += lowbit(n);
54+
if (n == 1) {
55+
res[0] = 1;
56+
return;
57+
}
58+
exp (poly, res, n>>1);
59+
log (res, tmp3, n);
60+
for (int i = 0; i < n; ++i) {
61+
tmp3[i] = poly[i] - tmp3[i];
62+
if (tmp3[i] < 0) tmp3[i] += mod;
63+
}
64+
if (++tmp3[0] == mod) tmp3[0] = 0;
65+
fft::conv (tmp3, res, n, n, mod, res);
66+
memset (res + n, 0, sizeof (*res) * n);
67+
}
68+
69+
// res = sqrt(poly), poly[0] should be 1
70+
void sqrt (int poly[], int res[], int n) {
71+
if (n == 1) {
72+
res[0] = 1;
73+
return;
74+
}
75+
sqrt (poly, res, n>>1);
76+
inv (res, tmp2, n);
77+
int s = n<<1;
78+
memcpy (tmp1, poly, sizeof (*poly) * n);
79+
memset (tmp1 + n, 0, sizeof (*tmp1) * n);
80+
memset (res + n, 0, sizeof (*res) * n);
81+
// NTT::trans(tmp1, s, 1); NTT::trans(res, s, 1); NTT::trans(tmp2, s, 1);
82+
for (int i = 0; i < s; ++i) {
83+
res[i] = ((ll) res[i] * res[i] + tmp1[i]) % mod * iv[2] % mod * tmp2[i] % mod;
84+
}
85+
// NTT::trans(res, s, -1)
86+
memset (res + n, 0, sizeof (*res) * n);
87+
}
88+
}
89+
90+
// polynomial arithmetic in O(n^2)
91+
// mod should be a prime greater than n
92+
namespace poly_brutal {
93+
// f = f * g mod x^n
94+
void conv (int n, ll f[], ll g[], ll mod) {
95+
for (int i = n - 1; i >= 0; --i) {
96+
ll tmp = 0;
97+
for (int j = 0; j <= i; ++j) {
98+
tmp += f[j] * g[i - j] % mod;
99+
}
100+
f[i] = tmp % mod;
101+
}
102+
}
103+
104+
// f = g^{-1} mod x^n, g[0] != 0
105+
void inv (int n, ll f[], ll g[], ll mod) {
106+
assert(g[0] != 0);
107+
ll inv0 = inverse (g[0], mod);
108+
for (int i = 0; i < n; ++i) {
109+
ll tmp = 0;
110+
for (int j = 0; j < i; ++j) {
111+
tmp += f[j] * g[i - j] % mod;
112+
}
113+
if (!i) f[i] = inv0;
114+
else f[i] = (mod - tmp % mod) * inv0 % mod;
115+
}
116+
}
117+
118+
// f = ln(g) mod x^n, g[0] = 1
119+
void log (int n, ll f[], ll g[], ll mod) {
120+
assert(g[0] == 1);
121+
inv (n, f, g, mod);
122+
for (int i = n - 1; i >= 0; --i) {
123+
ll tmp = 0;
124+
for (int j = 0; j <= i; ++j) {
125+
tmp += f[i - j] * (j == n ? 0 : g[j + 1]) % mod * (j + 1) % mod;
126+
}
127+
f[i] = tmp % mod;
128+
}
129+
for (int i = n - 1; i >= 1; --i) {
130+
f[i] = f[i - 1] * inverse (i, mod) % mod;
131+
}
132+
f[0] = 0;
133+
}
134+
135+
// f= exp(g) mod x^n, g[0] = 0
136+
void exp (int n, ll f[], ll g[], ll mod) {
137+
assert(g[0] == 0);
138+
f[0] = 1;
139+
for (int i = 1; i < n; ++i) {
140+
ll tmp = 0;
141+
for (int j = 1; j <= i; ++j) {
142+
tmp += j * g[j] % mod * f[i - j] % mod;
143+
}
144+
f[i] = tmp % mod * inverse (i, mod) % mod;
145+
}
146+
}
147+
}

0 commit comments

Comments
 (0)