This repository has been archived by the owner on Jul 23, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 2
/
ff_p.cpp
141 lines (111 loc) · 2.58 KB
/
ff_p.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
#include "ff_p.hpp"
#include <climits>
// Modular addition: returns a value congruent to (a + b) modulo MOD.
//
// Only `b` is brought into [0, MOD) first; `a` is used as given. No final
// conditional subtraction of MOD is performed, so the returned value may
// itself be >= MOD (e.g. when a + b == MOD exactly).
uint64_t
ff_p_add(uint64_t a, uint64_t b)
{
  if (b >= MOD) {
    b -= MOD;
  }

  // 2^64 mod MOD, assuming MOD == 2^64 - 2^32 + 1 (MOD is declared in
  // ff_p.hpp, not visible here -- confirm). Each carry out of a 64-bit add
  // drops 2^64, which is compensated by adding this value back.
  const uint64_t epsilon = 0xffffffffULL;

  const uint64_t sum = a + b;
  const uint64_t first_fold = (sum < a) ? epsilon : 0; // carry out of a + b

  const uint64_t partial = sum + first_fold;
  const uint64_t second_fold = (partial < sum) ? epsilon : 0; // carry out of the fold

  return partial + second_fold;
}
// Modular subtraction: returns a value congruent to (a - b) modulo MOD.
//
// Only `b` is brought into [0, MOD) first; `a` is used as given. The result
// is not canonicalized and may be >= MOD when `a` is.
uint64_t
ff_p_sub(uint64_t a, uint64_t b)
{
  if (b >= MOD) {
    b -= MOD;
  }

  // 2^64 mod MOD, assuming MOD == 2^64 - 2^32 + 1 (MOD is declared in
  // ff_p.hpp, not visible here -- confirm). A borrow out of a 64-bit
  // subtraction means 2^64 was implicitly added, so it is compensated by
  // subtracting this value.
  const uint64_t epsilon = 0xffffffffULL;

  const uint64_t diff = a - b;
  const uint64_t first_fix = (a < b) ? epsilon : 0; // borrow out of a - b

  const uint64_t partial = diff - first_fix;
  const uint64_t second_fix = (diff < first_fix) ? epsilon : 0; // borrow out of the fix

  // A second borrow again implicitly adds 2^64, so it must also be
  // compensated by SUBTRACTING epsilon (adding it would move the result
  // 2 * epsilon away from the correct residue). With b < MOD this branch
  // cannot trigger: after a borrow, diff >= 2^64 - (MOD - 1) = 2^32 > epsilon
  // (under the MOD assumption above).
  return partial - second_fix;
}
// Modular multiplication: returns a value congruent to (a * b) modulo MOD.
//
// The full 128-bit product is reduced without division, using
// 2^64 == epsilon and 2^96 == -1 (mod MOD). Only `b` is brought into
// [0, MOD) first; the result is not canonicalized and may be >= MOD.
uint64_t
ff_p_mult(uint64_t a, uint64_t b)
{
  if (b >= MOD) {
    b -= MOD;
  }

  // 2^64 mod MOD, assuming MOD == 2^64 - 2^32 + 1 (MOD is declared in
  // ff_p.hpp, not visible here -- confirm).
  const uint64_t epsilon = 0xffffffffULL;

  // 128-bit product written as  hi_hi * 2^96 + hi_lo * 2^64 + lo.
  const uint64_t lo = a * b;
  const uint64_t hi = sycl::mul_hi(a, b);
  const uint64_t hi_lo = hi & 0x00000000ffffffff;
  const uint64_t hi_hi = hi >> 32;

  // The 2^96 term contributes -hi_hi. On borrow, subtract epsilon
  // (this cannot borrow again because hi_hi < 2^32).
  uint64_t acc = lo - hi_hi;
  acc -= (lo < hi_hi) ? epsilon : 0;

  // The 2^64 term contributes hi_lo * epsilon, formed without a multiply.
  const uint64_t scaled = (hi_lo << 32) - hi_lo;

  // Final add; a carry out is folded back in as epsilon.
  const uint64_t sum = acc + scaled;
  return sum + ((sum < acc) ? epsilon : 0);
}
// Modular exponentiation: a raised to the b-th power via ff_p_mult.
//
// Special cases: b == 0 yields 1 (including for a == 0), b == 1 yields `a`
// unchanged (not reduced), and a literal 0 base yields 0. A base equal to
// MOD is not caught by the zero shortcut and goes through the loop.
uint64_t
ff_p_pow(uint64_t a, const uint64_t b)
{
  if (b == 0) {
    return 1;
  }
  if (b == 1) {
    return a;
  }
  if (a == 0) {
    return 0;
  }

  // Right-to-left binary exponentiation: bit 0 of b seeds the accumulator.
  // Each higher bit squares the running power of `a` and, when set,
  // multiplies that power into the accumulator. The loop ends once no set
  // bits remain above the current position.
  uint64_t acc = (b & 1) ? a : 1;
  uint64_t power = a;
  for (uint64_t bits = b >> 1; bits != 0; bits >>= 1) {
    power = ff_p_mult(power, power);
    if (bits & 1) {
      acc = ff_p_mult(acc, power);
    }
  }
  return acc;
}
// Multiplicative inverse of `a` modulo MOD, computed as a^(MOD - 2)
// (Fermat's little theorem; presumes MOD is prime -- confirm in ff_p.hpp).
//
// Zero (0 or MOD) has no inverse. This function is meant to be called from
// kernel code, where throwing exceptions is not (yet) allowed, so 0 is
// returned in that case instead of signalling an error.
uint64_t
ff_p_inv(uint64_t a)
{
  const uint64_t x = (a >= MOD) ? a - MOD : a;

  if (x == 0) {
    return 0;
  }

  return ff_p_pow(x, MOD - 2);
}
// Modular division: a * b^-1 modulo MOD.
//
// A zero divisor has no inverse. Because this runs inside kernel code, where
// exceptions cannot (yet) be thrown, 0 is returned in that case. A zero
// numerator also short-circuits to 0. A divisor equal to MOD is not caught
// here, but ff_p_inv maps it to 0, so the product is 0 as well.
uint64_t
ff_p_div(uint64_t a, uint64_t b)
{
  if (b == 0 || a == 0) {
    return 0;
  }
  return ff_p_mult(a, ff_p_inv(b));
}