Skip to content

Commit 244db65

Browse files
authored
Create fft_goo.cpp
1 parent 6544dd3 commit 244db65

File tree

1 file changed

+131
-0
lines changed

1 file changed

+131
-0
lines changed

docs/ch30/fft_goo.cpp

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
2+
3+
#include <complex>
4+
#include <iostream>
5+
#include <valarray>
6+
7+
8+
using namespace std;
9+
10+
typedef complex<double> base;
11+
12+
void fft(vector<base> &a, bool inv)
13+
{
14+
int n = a.size(), j = 0;
15+
vector<base> roots(n/2);
16+
17+
for(int i=1; i<n; i++)
18+
{
19+
int bit = (n >> 1);
20+
while(j >= bit)
21+
{
22+
j -= bit;
23+
bit >>= 1;
24+
}
25+
j += bit;
26+
if(i < j)
27+
swap(a[i], a[j]);
28+
}
29+
30+
double ang = 2 * acos(-1) / n * (inv ? -1 : 1);
31+
for(int i=0; i<n/2; i++)
32+
{
33+
roots[i] = base(cos(ang * i), sin(ang * i));
34+
}
35+
/* In NTT, let prr = primitive root. Then,
36+
int ang = ipow(prr, (mod - 1) / n);
37+
if(inv) ang = ipow(ang, mod - 2);
38+
for(int i=0; i<n/2; i++){
39+
roots[i] = (i ? (1ll * roots[i-1] * ang % mod) : 1);
40+
}
41+
XOR Convolution : set roots[*] = 1.
42+
OR Convolution : set roots[*] = 1, and do following:
43+
if (!inv) {
44+
a[j + k] = u + v;
45+
a[j + k + i/2] = u;
46+
} else {
47+
a[j + k] = v;
48+
a[j + k + i/2] = u - v;
49+
}
50+
*/
51+
for(int i=2; i<=n; i<<=1)
52+
{
53+
int step = n / i;
54+
for(int j=0; j<n; j+=i)
55+
{
56+
for(int k=0; k<i/2; k++)
57+
{
58+
base u = a[j+k],
59+
v = a[j+k+i/2] * roots[step * k];
60+
a[j+k] = u+v;
61+
a[j+k+i/2] = u-v;
62+
}
63+
}
64+
}
65+
if(inv)
66+
for(int i=0; i<n; i++)
67+
a[i] /= n; // skip for OR convolution.
68+
}
69+
70+
71+
vector<lint> multiply(vector<lint> &v, vector<lint> &w)
72+
{
73+
vector<base> fv(v.begin(), v.end()), fw(w.begin(), w.end());
74+
int n =2;
75+
while(n < v.size() + w.size())
76+
n <<=1;
77+
fv.resize(n);
78+
fw.resize(n);
79+
fft(fv,0);
80+
fft(fw,0);
81+
for(int i=0; i<n; i++)
82+
fv[i] *= fw[i];
83+
fft(fv,1);
84+
vector<lint> ret(n);
85+
for(int i=0; i<n; i++)
86+
ret[i] = (lint)round(fv[i].real());
87+
return ret;
88+
}
89+
90+
vector<lint> multiply(vector<lint> &v, vector<lint> &w, lint mod)
91+
{
92+
int n =2;
93+
while(n < v.size() + w.size())
94+
n <<=1s;
95+
vector<base> v1(n), v2(n), r1(n), r2(n);
96+
for(int i=0; i<v.size(); i++)
97+
{
98+
v1[i] = base(v[i] >> 15, v[i] & 32767);
99+
}
100+
for(int i=0; i<w.size(); i++)
101+
{
102+
v2[i] = base(w[i] >> 15, w[i] & 32767);
103+
}
104+
fft(v1,0);
105+
fft(v2,0);
106+
for(int i=0; i<n; i++)
107+
{
108+
int j = (i ? (n - i) : i);
109+
base ans1 = (v1[i] + conj(v1[j])) * base(0.5,0);
110+
base ans2 = (v1[i] - conj(v1[j])) * base(0, -0.5);
111+
base ans3 = (v2[i] + conj(v2[j])) * base(0.5,0);
112+
base ans4 = (v2[i] - conj(v2[j])) * base(0, -0.5);
113+
r1[i] = (ans1 * ans3) + (ans1 * ans4) * base(0,1);
114+
r2[i] = (ans2 * ans3) + (ans2 * ans4) * base(0,1);
115+
}
116+
fft(r1,1);
117+
fft(r2,1);
118+
vector<lint> ret(n);
119+
for(int i=0; i<n; i++)
120+
{
121+
lint av = (lint)round(r1[i].real());
122+
lint bv = (lint)round(r1[i].imag()) + (lint)round(r2[i].real());
123+
lint cv = (lint)round(r2[i].imag());
124+
av %= mod, bv %= mod, cv %= mod;
125+
ret[i] = (av << 30) + (bv << 15) + cv;
126+
ret[i] %= mod;
127+
ret[i] += mod;
128+
ret[i] %= mod;
129+
}
130+
return ret;
131+
}

0 commit comments

Comments
 (0)