1
- // given first m items init[0..m-1] and coefficents trans[0..m-1] or
2
- // given first 2 *m items init[0..2m-1], it will compute trans[0..m-1]
3
- // for you. trans[0..m] should be given as that
1
+ #include < cstdio>
2
+ #include < vector>
3
+ #include < cassert>
4
+ #include < functional>
5
+ #include < algorithm>
6
+ #include < random>
7
+
8
+ // 1. Given first m items init[0..m-1] and coefficents trans[0..m-1]
9
+ //
10
+ // ```cpp
11
+ // std::vector<int64> init = {1, 1};
12
+ // std::vector<int64> trans = {1, 1};
13
+ // const int mod = 1e9 + 7; // 998244353; 1000000006;
14
+ // LinearRecurrence lr{init, trans, mod};
15
+ // std::cout << lr.calc(1000000000000000000ll) << std::endl;
16
+ // ```
17
+ //
18
+ // 2. Given first 2 * m items init[0..2m-1], it will compute trans[0..m-1]
19
+ // trans[0..m] will be given as that:
4
20
// init[m] = sum_{i=0}^{m-1} init[i] * trans[i]
21
+ // you should make sure that init[0] is not zero and init[i] is in [0, mod - 1]
22
+ //
23
+ // ```cpp
24
+ // std::vector<int64> init = {1, 1, 2, 3, 5, 8, 13};
25
+ // const int prime_mod = 998244353;
26
+ // LinearRecurrence lr1{init, prime_mod};
27
+ // std::cout << lr.calc(1000000000000000000ll) << std::endl;
28
+ // const int non_prime_mod = 1000000006;
29
+ // LinearRecurrence lr2{init, non_prime_mod, false};
30
+ // std::cout << lr.calc(1000000000000000000ll) << std::endl;
31
+ // ```
5
32
struct LinearRecurrence {
6
33
using int64 = long long ;
7
34
using vec = std::vector<int64>;
8
35
9
- static void extand (vec &a, size_t d, int64 value = 0 ) {
10
- if (d <= a.size ()) return ;
36
+ static void extend (vec &a, size_t d, int64 value = 0 ) {
37
+ if (d <= a.size ()) return ;
11
38
a.resize (d, value);
12
39
}
13
40
14
41
static vec BerlekampMassey (const vec &s, int64 mod) {
15
- std::function < int64 (int64) > inverse = [&](int64 a) {
42
+ std::function< int64 (int64)> inverse = [&](int64 a) {
16
43
return a == 1 ? 1 : (int64) (mod - mod / a) * inverse (mod % a) % mod;
17
44
};
18
45
vec A = {1 }, B = {1 };
19
46
int64 b = s[0 ];
20
- for (size_t i = 1 , m = 1 ; i < s.size (); ++i, m++) {
47
+ assert (b != 0 );
48
+ for (size_t i = 1 , m = 1 ; i < s.size (); ++i, m++) {
21
49
int64 d = 0 ;
22
- for (size_t j = 0 ; j < A.size (); ++j) {
50
+ for (size_t j = 0 ; j < A.size (); ++j) {
23
51
d += A[j] * s[i - j] % mod;
24
52
}
25
- if (!(d %= mod)) continue ;
26
- if (2 * (A.size () - 1 ) <= i) {
53
+ if (!(d %= mod)) continue ;
54
+ if (2 * (A.size () - 1 ) <= i) {
27
55
auto temp = A;
28
- extand (A, B.size () + m);
56
+ extend (A, B.size () + m);
29
57
int64 coef = d * inverse (b) % mod;
30
- for (size_t j = 0 ; j < B.size (); ++j) {
58
+ for (size_t j = 0 ; j < B.size (); ++j) {
31
59
A[j + m] -= coef * B[j] % mod;
32
- if (A[j + m] < 0 ) A[j + m] += mod;
60
+ if (A[j + m] < 0 ) A[j + m] += mod;
33
61
}
34
62
B = temp, b = d, m = 0 ;
35
63
} else {
36
- extand (A, B.size () + m);
64
+ extend (A, B.size () + m);
37
65
int64 coef = d * inverse (b) % mod;
38
- for (size_t j = 0 ; j < B.size (); ++j) {
66
+ for (size_t j = 0 ; j < B.size (); ++j) {
39
67
A[j + m] -= coef * B[j] % mod;
40
- if (A[j + m] < 0 ) A[j + m] += mod;
68
+ if (A[j + m] < 0 ) A[j + m] += mod;
41
69
}
42
70
}
43
71
}
44
72
return A;
45
73
}
46
74
47
75
static void exgcd (int64 a, int64 b, int64 &g, int64 &x, int64 &y) {
48
- if (!b) x = 1 , y = 0 , g = a;
76
+ if (!b) x = 1 , y = 0 , g = a;
49
77
else {
50
78
exgcd (b, a % b, g, y, x);
51
79
y -= x * (a / b);
@@ -55,8 +83,8 @@ struct LinearRecurrence {
55
83
static int64 crt (const vec &c, const vec &m) {
56
84
int n = c.size ();
57
85
int64 M = 1 , ans = 0 ;
58
- for (int i = 0 ; i < n; ++i) M *= m[i];
59
- for (int i = 0 ; i < n; ++i) {
86
+ for (int i = 0 ; i < n; ++i) M *= m[i];
87
+ for (int i = 0 ; i < n; ++i) {
60
88
int64 x, y, g, tm = M / m[i];
61
89
exgcd (tm, m[i], g, x, y);
62
90
ans = (ans + tm * x * c[i] % M) % M;
@@ -77,23 +105,25 @@ struct LinearRecurrence {
77
105
};
78
106
auto prime_power = [&](const vec &s, int64 mod, int64 p, int64 e) {
79
107
// linear feedback shift register mod p^e, p is prime
80
- std::vector <vec> a (e), b (e), an (e), bn (e), ao (e), bo (e);
81
- vec t (e), u (e), r (e), to (e, 1 ), uo (e), pw (e + 1 );;
82
- pw[0 ] = 1 ;
83
- for (int i = pw[0 ] = 1 ; i <= e; ++i) pw[i] = pw[i - 1 ] * p;
84
- for (int64 i = 0 ; i < e; ++i) {
108
+ std::vector<vec> a (e), b (e), an (e), bn (e), ao (e), bo (e);
109
+ vec t (e), u (e), r (e), to (e, 1 ), uo (e), pw (e + 1 , 1 );;
110
+ for (int i = 1 ; i <= e; ++i) {
111
+ pw[i] = pw[i - 1 ] * p;
112
+ assert (pw[i] <= mod);
113
+ }
114
+ for (int64 i = 0 ; i < e; ++i) {
85
115
a[i] = {pw[i]}, an[i] = {pw[i]};
86
116
b[i] = {0 }, bn[i] = {s[0 ] * pw[i] % mod};
87
117
t[i] = s[0 ] * pw[i] % mod;
88
- if (t[i] == 0 ) {
118
+ if (t[i] == 0 ) {
89
119
t[i] = 1 , u[i] = e;
90
120
} else {
91
- for (u[i] = 0 ; t[i] % p == 0 ; t[i] /= p, ++u[i]);
121
+ for (u[i] = 0 ; t[i] % p == 0 ; t[i] /= p, ++u[i]);
92
122
}
93
123
}
94
- for (size_t k = 1 ; k < s.size (); ++k) {
95
- for (int g = 0 ; g < e; ++g) {
96
- if (L (an[g], bn[g]) > L (a[g], b[g])) {
124
+ for (size_t k = 1 ; k < s.size (); ++k) {
125
+ for (int g = 0 ; g < e; ++g) {
126
+ if (L (an[g], bn[g]) > L (a[g], b[g])) {
97
127
ao[g] = a[e - 1 - u[g]];
98
128
bo[g] = b[e - 1 - u[g]];
99
129
to[g] = t[e - 1 - u[g]];
@@ -102,64 +132,65 @@ struct LinearRecurrence {
102
132
}
103
133
}
104
134
a = an, b = bn;
105
- for (int o = 0 ; o < e; ++o) {
135
+ for (int o = 0 ; o < e; ++o) {
106
136
int64 d = 0 ;
107
- for (size_t i = 0 ; i < a[o].size () && i <= k; ++i) {
137
+ for (size_t i = 0 ; i < a[o].size () && i <= k; ++i) {
108
138
d = (d + a[o][i] * s[k - i]) % mod;
109
139
}
110
- if (d == 0 ) {
140
+ if (d == 0 ) {
111
141
t[o] = 1 , u[o] = e;
112
142
} else {
113
- for (u[o] = 0 , t[o] = d; t[o] % p == 0 ; t[o] /= p, ++u[o]);
143
+ for (u[o] = 0 , t[o] = d; t[o] % p == 0 ; t[o] /= p, ++u[o]);
114
144
int g = e - 1 - u[o];
115
- if (L (a[g], b[g]) == 0 ) {
116
- extand (bn[o], k + 1 );
145
+ if (L (a[g], b[g]) == 0 ) {
146
+ extend (bn[o], k + 1 );
117
147
bn[o][k] = (bn[o][k] + d) % mod;
118
148
} else {
119
149
int64 coef = t[o] * inverse (to[g], mod) % mod * pw[u[o] - uo[g]] % mod;
120
150
int m = k - r[g];
121
- extand (an[o], ao[g].size () + m);
122
- extand (bn[o], bo[g].size () + m);
123
- for (size_t i = 0 ; i < ao[g].size (); ++i) {
151
+ assert (m >= 0 );
152
+ extend (an[o], ao[g].size () + m);
153
+ extend (bn[o], bo[g].size () + m);
154
+ for (size_t i = 0 ; i < ao[g].size (); ++i) {
124
155
an[o][i + m] -= coef * ao[g][i] % mod;
125
- if (an[o][i + m] < 0 ) an[o][i + m] += mod;
156
+ if (an[o][i + m] < 0 ) an[o][i + m] += mod;
126
157
}
127
- while (an[o].size () && an[o].back () == 0 ) an[o].pop_back ();
128
- for (size_t i = 0 ; i < bo[g].size (); ++i) {
158
+ while (an[o].size () && an[o].back () == 0 ) an[o].pop_back ();
159
+ for (size_t i = 0 ; i < bo[g].size (); ++i) {
129
160
bn[o][i + m] -= coef * bo[g][i] % mod;
130
- if (bn[o][i + m] < 0 ) bn[o][i + m] -= mod;
161
+ if (bn[o][i + m] < 0 ) bn[o][i + m] -= mod;
131
162
}
132
- while (bn[o].size () && bn[o].back () == 0 ) bn[o].pop_back ();
163
+ while (bn[o].size () && bn[o].back () == 0 ) bn[o].pop_back ();
133
164
}
134
165
}
135
166
}
136
167
}
137
168
return std::make_pair (an[0 ], bn[0 ]);
138
169
};
139
170
140
- std::vector <std::tuple<int64, int64, int >> fac;
141
- for (int64 i = 2 ; i * i <= mod; ++i)
142
- if (mod % i == 0 ) {
171
+ std::vector<std::tuple<int64, int64, int >> fac;
172
+ for (int64 i = 2 ; i * i <= mod; ++i)
173
+ if (mod % i == 0 ) {
143
174
int64 cnt = 0 , pw = 1 ;
144
- while (mod % i == 0 ) mod /= i, ++cnt, pw *= i;
175
+ while (mod % i == 0 ) mod /= i, ++cnt, pw *= i;
145
176
fac.emplace_back (pw, i, cnt);
146
177
}
147
- if (mod > 1 ) fac.emplace_back (mod, mod, 1 );
148
- std::vector <vec> as;
178
+ if (mod > 1 ) fac.emplace_back (mod, mod, 1 );
179
+ std::vector<vec> as;
149
180
size_t n = 0 ;
150
- for (auto &&x: fac) {
181
+ for (auto &&x: fac) {
151
182
int64 mod, p, e;
152
183
vec a, b;
153
184
std::tie (mod, p, e) = x;
154
185
auto ss = s;
155
- for (auto &&x: ss) x %= mod;
186
+ for (auto &&x: ss) x %= mod;
156
187
std::tie (a, b) = prime_power (ss, mod, p, e);
157
188
as.emplace_back (a);
158
189
n = std::max (n, a.size ());
159
190
}
160
191
vec a (n), c (as.size ()), m (as.size ());
161
- for (size_t i = 0 ; i < n; ++i) {
162
- for (size_t j = 0 ; j < as.size (); ++j) {
192
+ for (size_t i = 0 ; i < n; ++i) {
193
+ for (size_t j = 0 ; j < as.size (); ++j) {
163
194
m[j] = std::get<0 >(fac[j]);
164
195
c[j] = i < as[j].size () ? as[j][i] : 0 ;
165
196
}
@@ -172,46 +203,48 @@ struct LinearRecurrence {
172
203
init (s), trans(c), mod(mod), m(s.size()) {}
173
204
174
205
LinearRecurrence (const vec &s, int64 mod, bool is_prime = true ) : mod(mod) {
206
+ assert (s.size () % 2 == 0 );
175
207
vec A;
176
- if (is_prime) A = BerlekampMassey (s, mod);
208
+ if (is_prime) A = BerlekampMassey (s, mod);
177
209
else A = ReedsSloane (s, mod);
178
- if (A. empty ()) A = { 0 } ;
179
- m = A. size () - 1 ;
210
+ m = s. size () / 2 ;
211
+ A. resize (m + 1 , 0 ) ;
180
212
trans.resize (m);
181
- for (int i = 0 ; i < m; ++i) {
213
+ for (int i = 0 ; i < m; ++i) {
182
214
trans[i] = (mod - A[i + 1 ]) % mod;
183
215
}
216
+ if (m == 0 ) m = 1 , trans = {1 };
184
217
std::reverse (trans.begin (), trans.end ());
185
218
init = {s.begin (), s.begin () + m};
186
219
}
187
220
188
221
int64 calc (int64 n) {
189
- if (mod == 1 ) return 0 ;
190
- if (n < m) return init[n];
222
+ if (mod == 1 ) return 0 ;
223
+ if (n < m) return init[n];
191
224
vec v (m), u (m << 1 );
192
- int msk = !!n;
193
- for (int64 m = n; m > 1 ; m >>= 1 ) msk <<= 1 ;
225
+ int64 msk = !!n;
226
+ for (int64 m = n; m > 1 ; m >>= 1 ) msk <<= 1 ;
194
227
v[0 ] = 1 % mod;
195
- for ( int x = 0 ; msk; msk >>= 1 , x <<= 1 ) {
228
+ for (int64 x = 0 ; msk; msk >>= 1 , x <<= 1 ) {
196
229
std::fill_n (u.begin (), m * 2 , 0 );
197
230
x |= !!(n & msk);
198
- if (x < m) u[x] = 1 % mod;
231
+ if (x < m) u[x] = 1 % mod;
199
232
else {// can be optimized by fft/ntt
200
- for (int i = 0 ; i < m; ++i) {
201
- for (int j = 0 , t = i + (x & 1 ); j < m; ++j, ++t) {
233
+ for (int i = 0 ; i < m; ++i) {
234
+ for (int j = 0 , t = i + (x & 1 ); j < m; ++j, ++t) {
202
235
u[t] = (u[t] + v[i] * v[j]) % mod;
203
236
}
204
237
}
205
- for (int i = m * 2 - 1 ; i >= m; --i) {
206
- for (int j = 0 , t = i - m; j < m; ++j, ++t) {
238
+ for (int i = m * 2 - 1 ; i >= m; --i) {
239
+ for (int j = 0 , t = i - m; j < m; ++j, ++t) {
207
240
u[t] = (u[t] + trans[j] * u[i]) % mod;
208
241
}
209
242
}
210
243
}
211
244
v = {u.begin (), u.begin () + m};
212
245
}
213
246
int64 ret = 0 ;
214
- for (int i = 0 ; i < m; ++i) {
247
+ for (int i = 0 ; i < m; ++i) {
215
248
ret = (ret + v[i] * init[i]) % mod;
216
249
}
217
250
return ret;
@@ -221,3 +254,53 @@ struct LinearRecurrence {
221
254
int64 mod;
222
255
int m;
223
256
};
257
+
258
+ void verify () {
259
+ using int64 = long long ;
260
+ std::mt19937 gen{0 };
261
+ for (int cas = 1 ; cas <= 1000 ; ++cas) {
262
+ std::uniform_int_distribution<int > dis_mod (1 , 1 << 31 );
263
+ std::uniform_int_distribution<int > dis_n (1 , 100 );
264
+ int n = dis_n (gen), mod = dis_mod (gen);
265
+ std::uniform_int_distribution<int > dis_a (0 , mod - 1 );
266
+ std::vector<int64> a (n), c (n);
267
+ for (int i = 0 ; i < n; ++i) {
268
+ a[i] = dis_a (gen);
269
+ c[i] = dis_a (gen);
270
+ }
271
+ for (int i = n; i <= n * 2 ; ++i) {
272
+ int64 sum = 0 ;
273
+ for (int j = 0 ; j < n; ++j) {
274
+ sum += c[j] * a[i - 1 - j] % mod;
275
+ }
276
+ a.push_back (sum % mod);
277
+ }
278
+ auto u = a.back ();
279
+ a.pop_back ();
280
+ LinearRecurrence lr{a, mod, false };
281
+ a.push_back (u);
282
+ for (size_t i = 0 ; i < a.size (); ++i) {
283
+ assert (lr.calc (i) == a[i]);
284
+ }
285
+ }
286
+ }
287
+
288
+ // http://www.spoj.com/problems/FINDLR/
289
+ void solve () {
290
+ int T;
291
+ scanf (" %d" , &T);
292
+ for (int cas = 1 ; cas <= T; ++cas) {
293
+ int n, mod;
294
+ scanf (" %d%d" , &n, &mod);
295
+ std::vector<LinearRecurrence::int64> a (n * 2 );
296
+ for (int i = 0 ; i < n * 2 ; ++i) scanf (" %lld" , &a[i]);
297
+ LinearRecurrence lr{a, mod, false };
298
+ printf (" %lld\n " , lr.calc (n * 2 ));
299
+ }
300
+ }
301
+
302
+ int main () {
303
+ verify ();
304
+ // solve();
305
+ return 0 ;
306
+ }
0 commit comments