Skip to content

Commit a62370b

Browse files
committed
fixes
1 parent 569e6ea commit a62370b

File tree

6 files changed

+46
-32
lines changed

6 files changed

+46
-32
lines changed

cp-algo/number_theory/discrete_log.hpp

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,19 +4,21 @@
44
#include <optional>
55
namespace cp_algo::math {
66
// Find min non-negative x s.t. a*b^x = c (mod m)
7-
std::optional<int64_t> discrete_log(int64_t b, int64_t c, int64_t m, int64_t a = 1) {
7+
template<typename _Int>
8+
std::optional<_Int> discrete_log(_Int b, _Int c, _Int m, _Int a = 1) {
89
if(std::abs(a - c) % m == 0) {
910
return 0;
1011
}
11-
if(std::gcd(a, m) != std::gcd(a * b, m)) {
12-
auto res = discrete_log(b, c, m, a * b % m);
12+
if(std::gcd(a, m) != std::gcd(int64_t(a) * b, int64_t(m))) {
13+
auto res = discrete_log(b, c, m, _Int(int64_t(a) * b % m));
1314
return res ? std::optional(*res + 1) : res;
1415
}
1516
// a * b^x is periodic here
16-
using base = dynamic_modint<>;
17-
return base::with_mod(m, [&]() -> std::optional<uint64_t> {
17+
using Int = std::make_signed_t<_Int>;
18+
using base = dynamic_modint<Int>;
19+
return base::with_mod(m, [&]() -> std::optional<_Int> {
1820
int sqrtmod = std::max(1, (int)std::sqrt(m) / 2);
19-
std::unordered_map<int64_t, int> small;
21+
std::unordered_map<_Int, int> small;
2022
base cur = a;
2123
for(int i = 0; i < sqrtmod; i++) {
2224
small[cur.getr()] = i;

cp-algo/number_theory/euler.hpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,19 @@
22
#define CP_ALGO_NUMBER_THEORY_EULER_HPP
33
#include "factorize.hpp"
44
namespace cp_algo::math {
5-
int64_t euler_phi(int64_t m) {
5+
auto euler_phi(auto m) {
66
auto primes = to<std::vector>(factorize(m));
77
std::ranges::sort(primes);
88
auto [from, to] = std::ranges::unique(primes);
99
primes.erase(from, to);
10-
int64_t ans = m;
10+
auto ans = m;
1111
for(auto it: primes) {
1212
ans -= ans / it;
1313
}
1414
return ans;
1515
}
1616
template<modint_type base>
17-
int64_t period(base x) {
17+
auto period(base x) {
1818
auto ans = euler_phi(base::mod());
1919
base x0 = bpow(x, ans);
2020
for(auto t: factorize(ans)) {
@@ -24,8 +24,10 @@ namespace cp_algo::math {
2424
}
2525
return ans;
2626
}
27-
int64_t primitive_root(int64_t p) {
28-
using base = dynamic_modint<>;
27+
template<typename _Int>
28+
_Int primitive_root(_Int p) {
29+
using Int = std::make_signed_t<_Int>;
30+
using base = dynamic_modint<Int>;
2931
return base::with_mod(p, [p](){
3032
base t = 1;
3133
while(period(t) != p - 1) {

cp-algo/number_theory/factorize.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,10 @@
55
#include <generator>
66
namespace cp_algo::math {
77
// https://en.wikipedia.org/wiki/Pollard%27s_rho_algorithm
8-
auto proper_divisor(uint64_t m) {
9-
using base = dynamic_modint<>;
8+
template<typename _Int>
9+
auto proper_divisor(_Int m) {
10+
using Int = std::make_signed_t<_Int>;
11+
using base = dynamic_modint<Int>;
1012
return m % 2 == 0 ? 2 : base::with_mod(m, [&]() {
1113
base t = random::rng();
1214
auto f = [&](auto x) {
@@ -31,7 +33,8 @@ namespace cp_algo::math {
3133
return g.getr();
3234
});
3335
}
34-
std::generator<uint64_t> factorize(uint64_t m) {
36+
template<typename Int>
37+
std::generator<Int> factorize(Int m) {
3538
if(is_prime(m)) {
3639
co_yield m;
3740
} else if(m > 1) {

cp-algo/number_theory/primality.hpp

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,17 @@
55
#include <bit>
66
namespace cp_algo::math {
77
// https://en.wikipedia.org/wiki/Miller–Rabin_primality_test
8-
bool is_prime(uint64_t m) {
8+
template<typename _Int>
9+
bool is_prime(_Int m) {
10+
using Int = std::make_signed_t<_Int>;
11+
using UInt = std::make_unsigned_t<Int>;
912
if(m == 1 || m % 2 == 0) {
1013
return m == 2;
1114
}
1215
// m - 1 = 2^s * d
13-
int s = std::countr_zero(m - 1);
16+
int s = std::countr_zero(UInt(m - 1));
1417
auto d = (m - 1) >> s;
15-
using base = dynamic_modint<>;
18+
using base = dynamic_modint<Int>;
1619
auto test = [&](base x) {
1720
x = bpow(x, d);
1821
if(std::abs(x.rem()) <= 1) {

cp-algo/number_theory/two_squares.hpp

Lines changed: 18 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@
77
#include <vector>
88
#include <map>
99
namespace cp_algo::math {
10-
using gaussint = complex<int64_t>;
11-
gaussint two_squares_prime_any(int64_t p) {
10+
template<typename T>
11+
using gaussint = complex<T>;
12+
template<typename _Int>
13+
auto two_squares_prime_any(_Int p) {
1214
if(p == 2) {
13-
return gaussint(1, 1);
15+
return gaussint<_Int>{1, 1};
1416
}
1517
assert(p % 4 == 1);
16-
using base = dynamic_modint<>;
18+
using Int = std::make_signed_t<_Int>;
19+
using base = dynamic_modint<Int>;
1720
return base::with_mod(p, [&](){
1821
base g = primitive_root(p);
1922
int64_t i = bpow(g, (p - 1) / 4).getr();
@@ -25,49 +28,50 @@ namespace cp_algo::math {
2528
q0 = std::exchange(q1, q0 + d * q1);
2629
r = std::exchange(m, r % m);
2730
} while(q1 < p / q1);
28-
return gaussint(q0, (base(i) * base(q0)).rem());
31+
return gaussint<_Int>{q0, (base(i) * base(q0)).rem()};
2932
});
3033
}
3134

32-
std::vector<gaussint> two_squares_all(int64_t n) {
35+
template<typename Int>
36+
std::vector<gaussint<Int>> two_squares_all(Int n) {
3337
if(n == 0) {
3438
return {0};
3539
}
3640
auto primes = factorize(n);
37-
std::map<int64_t, int> cnt;
41+
std::map<Int, int> cnt;
3842
for(auto p: primes) {
3943
cnt[p]++;
4044
}
41-
std::vector<gaussint> res = {1};
45+
std::vector<gaussint<Int>> res = {1};
4246
for(auto [p, c]: cnt) {
43-
std::vector<gaussint> nres;
47+
std::vector<gaussint<Int>> nres;
4448
if(p % 4 == 3) {
4549
if(c % 2 == 0) {
46-
auto mul = bpow(gaussint(p), c / 2);
50+
auto mul = bpow(gaussint<Int>(p), c / 2);
4751
for(auto p: res) {
4852
nres.push_back(p * mul);
4953
}
5054
}
5155
} else if(p % 4 == 1) {
52-
gaussint base = two_squares_prime_any(p);
56+
auto base = two_squares_prime_any(p);
5357
for(int i = 0; i <= c; i++) {
5458
auto mul = bpow(base, i) * bpow(conj(base), c - i);
5559
for(auto p: res) {
5660
nres.push_back(p * mul);
5761
}
5862
}
5963
} else if(p % 4 == 2) {
60-
auto mul = bpow(gaussint(1, 1), c);
64+
auto mul = bpow(gaussint<Int>(1, 1), c);
6165
for(auto p: res) {
6266
nres.push_back(p * mul);
6367
}
6468
}
6569
res = nres;
6670
}
67-
std::vector<gaussint> nres;
71+
std::vector<gaussint<Int>> nres;
6872
for(auto p: res) {
6973
while(p.real() < 0 || p.imag() < 0) {
70-
p *= gaussint(0, 1);
74+
p *= gaussint<Int>(0, 1);
7175
}
7276
nres.push_back(p);
7377
if(!p.real() || !p.imag()) {

verify/linalg/prod_dynamic_modint.test.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
using namespace std;
99
using namespace cp_algo::linalg;
1010
using namespace cp_algo::math;
11-
using base = dynamic_modint<>;
11+
using base = dynamic_modint<int64_t>;
1212

1313
const int64_t mod = 998244353;
1414

0 commit comments

Comments
 (0)