Skip to content

Commit 8fbf80f

Browse files
Update BerlekampMassey.cpp
1 parent d2ad842 commit 8fbf80f

File tree

1 file changed

+151
-68
lines changed

1 file changed

+151
-68
lines changed

BerlekampMassey.cpp

Lines changed: 151 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -1,51 +1,79 @@
1-
// given first m items init[0..m-1] and coefficents trans[0..m-1] or
2-
// given first 2 *m items init[0..2m-1], it will compute trans[0..m-1]
3-
// for you. trans[0..m] should be given as that
1+
#include <cstdio>
2+
#include <vector>
3+
#include <cassert>
4+
#include <functional>
5+
#include <algorithm>
6+
#include <random>
7+
8+
// 1. Given first m items init[0..m-1] and coefficents trans[0..m-1]
9+
//
10+
// ```cpp
11+
// std::vector<int64> init = {1, 1};
12+
// std::vector<int64> trans = {1, 1};
13+
// const int mod = 1e9 + 7; // 998244353; 1000000006;
14+
// LinearRecurrence lr{init, trans, mod};
15+
// std::cout << lr.calc(1000000000000000000ll) << std::endl;
16+
// ```
17+
//
18+
// 2. Given first 2 * m items init[0..2m-1], it will compute trans[0..m-1]
19+
// trans[0..m] will be given as that:
420
// init[m] = sum_{i=0}^{m-1} init[i] * trans[i]
21+
// you should make sure that init[0] is not zero and init[i] is in [0, mod - 1]
22+
//
23+
// ```cpp
24+
// std::vector<int64> init = {1, 1, 2, 3, 5, 8, 13};
25+
// const int prime_mod = 998244353;
26+
// LinearRecurrence lr1{init, prime_mod};
27+
// std::cout << lr.calc(1000000000000000000ll) << std::endl;
28+
// const int non_prime_mod = 1000000006;
29+
// LinearRecurrence lr2{init, non_prime_mod, false};
30+
// std::cout << lr.calc(1000000000000000000ll) << std::endl;
31+
// ```
532
struct LinearRecurrence {
633
using int64 = long long;
734
using vec = std::vector<int64>;
835

9-
static void extand(vec &a, size_t d, int64 value = 0) {
10-
if (d <= a.size()) return;
36+
static void extend(vec &a, size_t d, int64 value = 0) {
37+
if(d <= a.size()) return;
1138
a.resize(d, value);
1239
}
1340

1441
static vec BerlekampMassey(const vec &s, int64 mod) {
15-
std::function < int64(int64) > inverse = [&](int64 a) {
42+
std::function<int64(int64)> inverse = [&](int64 a) {
1643
return a == 1 ? 1 : (int64) (mod - mod / a) * inverse(mod % a) % mod;
1744
};
1845
vec A = {1}, B = {1};
1946
int64 b = s[0];
20-
for (size_t i = 1, m = 1; i < s.size(); ++i, m++) {
47+
assert(b != 0);
48+
for(size_t i = 1, m = 1; i < s.size(); ++i, m++) {
2149
int64 d = 0;
22-
for (size_t j = 0; j < A.size(); ++j) {
50+
for(size_t j = 0; j < A.size(); ++j) {
2351
d += A[j] * s[i - j] % mod;
2452
}
25-
if (!(d %= mod)) continue;
26-
if (2 * (A.size() - 1) <= i) {
53+
if(!(d %= mod)) continue;
54+
if(2 * (A.size() - 1) <= i) {
2755
auto temp = A;
28-
extand(A, B.size() + m);
56+
extend(A, B.size() + m);
2957
int64 coef = d * inverse(b) % mod;
30-
for (size_t j = 0; j < B.size(); ++j) {
58+
for(size_t j = 0; j < B.size(); ++j) {
3159
A[j + m] -= coef * B[j] % mod;
32-
if (A[j + m] < 0) A[j + m] += mod;
60+
if(A[j + m] < 0) A[j + m] += mod;
3361
}
3462
B = temp, b = d, m = 0;
3563
} else {
36-
extand(A, B.size() + m);
64+
extend(A, B.size() + m);
3765
int64 coef = d * inverse(b) % mod;
38-
for (size_t j = 0; j < B.size(); ++j) {
66+
for(size_t j = 0; j < B.size(); ++j) {
3967
A[j + m] -= coef * B[j] % mod;
40-
if (A[j + m] < 0) A[j + m] += mod;
68+
if(A[j + m] < 0) A[j + m] += mod;
4169
}
4270
}
4371
}
4472
return A;
4573
}
4674

4775
static void exgcd(int64 a, int64 b, int64 &g, int64 &x, int64 &y) {
48-
if (!b) x = 1, y = 0, g = a;
76+
if(!b) x = 1, y = 0, g = a;
4977
else {
5078
exgcd(b, a % b, g, y, x);
5179
y -= x * (a / b);
@@ -55,8 +83,8 @@ struct LinearRecurrence {
5583
static int64 crt(const vec &c, const vec &m) {
5684
int n = c.size();
5785
int64 M = 1, ans = 0;
58-
for (int i = 0; i < n; ++i) M *= m[i];
59-
for (int i = 0; i < n; ++i) {
86+
for(int i = 0; i < n; ++i) M *= m[i];
87+
for(int i = 0; i < n; ++i) {
6088
int64 x, y, g, tm = M / m[i];
6189
exgcd(tm, m[i], g, x, y);
6290
ans = (ans + tm * x * c[i] % M) % M;
@@ -77,23 +105,25 @@ struct LinearRecurrence {
77105
};
78106
auto prime_power = [&](const vec &s, int64 mod, int64 p, int64 e) {
79107
// linear feedback shift register mod p^e, p is prime
80-
std::vector <vec> a(e), b(e), an(e), bn(e), ao(e), bo(e);
81-
vec t(e), u(e), r(e), to(e, 1), uo(e), pw(e + 1);;
82-
pw[0] = 1;
83-
for (int i = pw[0] = 1; i <= e; ++i) pw[i] = pw[i - 1] * p;
84-
for (int64 i = 0; i < e; ++i) {
108+
std::vector<vec> a(e), b(e), an(e), bn(e), ao(e), bo(e);
109+
vec t(e), u(e), r(e), to(e, 1), uo(e), pw(e + 1, 1);;
110+
for(int i = 1; i <= e; ++i) {
111+
pw[i] = pw[i - 1] * p;
112+
assert(pw[i] <= mod);
113+
}
114+
for(int64 i = 0; i < e; ++i) {
85115
a[i] = {pw[i]}, an[i] = {pw[i]};
86116
b[i] = {0}, bn[i] = {s[0] * pw[i] % mod};
87117
t[i] = s[0] * pw[i] % mod;
88-
if (t[i] == 0) {
118+
if(t[i] == 0) {
89119
t[i] = 1, u[i] = e;
90120
} else {
91-
for (u[i] = 0; t[i] % p == 0; t[i] /= p, ++u[i]);
121+
for(u[i] = 0; t[i] % p == 0; t[i] /= p, ++u[i]);
92122
}
93123
}
94-
for (size_t k = 1; k < s.size(); ++k) {
95-
for (int g = 0; g < e; ++g) {
96-
if (L(an[g], bn[g]) > L(a[g], b[g])) {
124+
for(size_t k = 1; k < s.size(); ++k) {
125+
for(int g = 0; g < e; ++g) {
126+
if(L(an[g], bn[g]) > L(a[g], b[g])) {
97127
ao[g] = a[e - 1 - u[g]];
98128
bo[g] = b[e - 1 - u[g]];
99129
to[g] = t[e - 1 - u[g]];
@@ -102,64 +132,65 @@ struct LinearRecurrence {
102132
}
103133
}
104134
a = an, b = bn;
105-
for (int o = 0; o < e; ++o) {
135+
for(int o = 0; o < e; ++o) {
106136
int64 d = 0;
107-
for (size_t i = 0; i < a[o].size() && i <= k; ++i) {
137+
for(size_t i = 0; i < a[o].size() && i <= k; ++i) {
108138
d = (d + a[o][i] * s[k - i]) % mod;
109139
}
110-
if (d == 0) {
140+
if(d == 0) {
111141
t[o] = 1, u[o] = e;
112142
} else {
113-
for (u[o] = 0, t[o] = d; t[o] % p == 0; t[o] /= p, ++u[o]);
143+
for(u[o] = 0, t[o] = d; t[o] % p == 0; t[o] /= p, ++u[o]);
114144
int g = e - 1 - u[o];
115-
if (L(a[g], b[g]) == 0) {
116-
extand(bn[o], k + 1);
145+
if(L(a[g], b[g]) == 0) {
146+
extend(bn[o], k + 1);
117147
bn[o][k] = (bn[o][k] + d) % mod;
118148
} else {
119149
int64 coef = t[o] * inverse(to[g], mod) % mod * pw[u[o] - uo[g]] % mod;
120150
int m = k - r[g];
121-
extand(an[o], ao[g].size() + m);
122-
extand(bn[o], bo[g].size() + m);
123-
for (size_t i = 0; i < ao[g].size(); ++i) {
151+
assert(m >= 0);
152+
extend(an[o], ao[g].size() + m);
153+
extend(bn[o], bo[g].size() + m);
154+
for(size_t i = 0; i < ao[g].size(); ++i) {
124155
an[o][i + m] -= coef * ao[g][i] % mod;
125-
if (an[o][i + m] < 0) an[o][i + m] += mod;
156+
if(an[o][i + m] < 0) an[o][i + m] += mod;
126157
}
127-
while (an[o].size() && an[o].back() == 0) an[o].pop_back();
128-
for (size_t i = 0; i < bo[g].size(); ++i) {
158+
while(an[o].size() && an[o].back() == 0) an[o].pop_back();
159+
for(size_t i = 0; i < bo[g].size(); ++i) {
129160
bn[o][i + m] -= coef * bo[g][i] % mod;
130-
if (bn[o][i + m] < 0) bn[o][i + m] -= mod;
161+
if(bn[o][i + m] < 0) bn[o][i + m] -= mod;
131162
}
132-
while (bn[o].size() && bn[o].back() == 0) bn[o].pop_back();
163+
while(bn[o].size() && bn[o].back() == 0) bn[o].pop_back();
133164
}
134165
}
135166
}
136167
}
137168
return std::make_pair(an[0], bn[0]);
138169
};
139170

140-
std::vector <std::tuple<int64, int64, int>> fac;
141-
for (int64 i = 2; i * i <= mod; ++i)
142-
if (mod % i == 0) {
171+
std::vector<std::tuple<int64, int64, int>> fac;
172+
for(int64 i = 2; i * i <= mod; ++i)
173+
if(mod % i == 0) {
143174
int64 cnt = 0, pw = 1;
144-
while (mod % i == 0) mod /= i, ++cnt, pw *= i;
175+
while(mod % i == 0) mod /= i, ++cnt, pw *= i;
145176
fac.emplace_back(pw, i, cnt);
146177
}
147-
if (mod > 1) fac.emplace_back(mod, mod, 1);
148-
std::vector <vec> as;
178+
if(mod > 1) fac.emplace_back(mod, mod, 1);
179+
std::vector<vec> as;
149180
size_t n = 0;
150-
for (auto &&x: fac) {
181+
for(auto &&x: fac) {
151182
int64 mod, p, e;
152183
vec a, b;
153184
std::tie(mod, p, e) = x;
154185
auto ss = s;
155-
for (auto &&x: ss) x %= mod;
186+
for(auto &&x: ss) x %= mod;
156187
std::tie(a, b) = prime_power(ss, mod, p, e);
157188
as.emplace_back(a);
158189
n = std::max(n, a.size());
159190
}
160191
vec a(n), c(as.size()), m(as.size());
161-
for (size_t i = 0; i < n; ++i) {
162-
for (size_t j = 0; j < as.size(); ++j) {
192+
for(size_t i = 0; i < n; ++i) {
193+
for(size_t j = 0; j < as.size(); ++j) {
163194
m[j] = std::get<0>(fac[j]);
164195
c[j] = i < as[j].size() ? as[j][i] : 0;
165196
}
@@ -172,46 +203,48 @@ struct LinearRecurrence {
172203
init(s), trans(c), mod(mod), m(s.size()) {}
173204

174205
LinearRecurrence(const vec &s, int64 mod, bool is_prime = true) : mod(mod) {
206+
assert(s.size() % 2 == 0);
175207
vec A;
176-
if (is_prime) A = BerlekampMassey(s, mod);
208+
if(is_prime) A = BerlekampMassey(s, mod);
177209
else A = ReedsSloane(s, mod);
178-
if (A.empty()) A = {0};
179-
m = A.size() - 1;
210+
m = s.size() / 2;
211+
A.resize(m + 1, 0);
180212
trans.resize(m);
181-
for (int i = 0; i < m; ++i) {
213+
for(int i = 0; i < m; ++i) {
182214
trans[i] = (mod - A[i + 1]) % mod;
183215
}
216+
if(m == 0) m = 1, trans = {1};
184217
std::reverse(trans.begin(), trans.end());
185218
init = {s.begin(), s.begin() + m};
186219
}
187220

188221
int64 calc(int64 n) {
189-
if (mod == 1) return 0;
190-
if (n < m) return init[n];
222+
if(mod == 1) return 0;
223+
if(n < m) return init[n];
191224
vec v(m), u(m << 1);
192-
int msk = !!n;
193-
for (int64 m = n; m > 1; m >>= 1) msk <<= 1;
225+
int64 msk = !!n;
226+
for(int64 m = n; m > 1; m >>= 1) msk <<= 1;
194227
v[0] = 1 % mod;
195-
for (int x = 0; msk; msk >>= 1, x <<= 1) {
228+
for(int64 x = 0; msk; msk >>= 1, x <<= 1) {
196229
std::fill_n(u.begin(), m * 2, 0);
197230
x |= !!(n & msk);
198-
if (x < m) u[x] = 1 % mod;
231+
if(x < m) u[x] = 1 % mod;
199232
else {// can be optimized by fft/ntt
200-
for (int i = 0; i < m; ++i) {
201-
for (int j = 0, t = i + (x & 1); j < m; ++j, ++t) {
233+
for(int i = 0; i < m; ++i) {
234+
for(int j = 0, t = i + (x & 1); j < m; ++j, ++t) {
202235
u[t] = (u[t] + v[i] * v[j]) % mod;
203236
}
204237
}
205-
for (int i = m * 2 - 1; i >= m; --i) {
206-
for (int j = 0, t = i - m; j < m; ++j, ++t) {
238+
for(int i = m * 2 - 1; i >= m; --i) {
239+
for(int j = 0, t = i - m; j < m; ++j, ++t) {
207240
u[t] = (u[t] + trans[j] * u[i]) % mod;
208241
}
209242
}
210243
}
211244
v = {u.begin(), u.begin() + m};
212245
}
213246
int64 ret = 0;
214-
for (int i = 0; i < m; ++i) {
247+
for(int i = 0; i < m; ++i) {
215248
ret = (ret + v[i] * init[i]) % mod;
216249
}
217250
return ret;
@@ -221,3 +254,53 @@ struct LinearRecurrence {
221254
int64 mod;
222255
int m;
223256
};
257+
258+
void verify() {
259+
using int64 = long long;
260+
std::mt19937 gen{0};
261+
for(int cas = 1; cas <= 1000; ++cas) {
262+
std::uniform_int_distribution<int> dis_mod(1, 1 << 31);
263+
std::uniform_int_distribution<int> dis_n(1, 100);
264+
int n = dis_n(gen), mod = dis_mod(gen);
265+
std::uniform_int_distribution<int> dis_a(0, mod - 1);
266+
std::vector<int64> a(n), c(n);
267+
for(int i = 0; i < n; ++i) {
268+
a[i] = dis_a(gen);
269+
c[i] = dis_a(gen);
270+
}
271+
for(int i = n; i <= n * 2; ++i) {
272+
int64 sum = 0;
273+
for(int j = 0; j < n; ++j) {
274+
sum += c[j] * a[i - 1 - j] % mod;
275+
}
276+
a.push_back(sum % mod);
277+
}
278+
auto u = a.back();
279+
a.pop_back();
280+
LinearRecurrence lr{a, mod, false};
281+
a.push_back(u);
282+
for(size_t i = 0; i < a.size(); ++i) {
283+
assert(lr.calc(i) == a[i]);
284+
}
285+
}
286+
}
287+
288+
// http://www.spoj.com/problems/FINDLR/
289+
void solve() {
290+
int T;
291+
scanf("%d", &T);
292+
for(int cas = 1; cas <= T; ++cas) {
293+
int n, mod;
294+
scanf("%d%d", &n, &mod);
295+
std::vector<LinearRecurrence::int64> a(n * 2);
296+
for(int i = 0; i < n * 2; ++i) scanf("%lld", &a[i]);
297+
LinearRecurrence lr{a, mod, false};
298+
printf("%lld\n", lr.calc(n * 2));
299+
}
300+
}
301+
302+
int main() {
303+
verify();
304+
//solve();
305+
return 0;
306+
}

0 commit comments

Comments
 (0)