Skip to content

Commit 1aa4d4b

Browse files
committed
debug underflow quantization.
1 parent 664fd5f commit 1aa4d4b

File tree

7 files changed

+309
-184
lines changed

7 files changed

+309
-184
lines changed

qtorch/quant/quant_cpu/bit_helper.cpp

+30-14
Original file line numberDiff line numberDiff line change
@@ -2,31 +2,47 @@
22

33
unsigned int clip_exponent(int exp_bits, int man_bits,
44
unsigned int old_num,
5-
unsigned int quantized_num) {
6-
if (quantized_num == 0) return quantized_num;
5+
unsigned int quantized_num)
6+
{
7+
if (quantized_num == 0)
8+
return quantized_num;
79

810
int quantized_exponent_store = quantized_num << 1 >> 1 >> 23; // 1 sign bit, 23 mantissa bits
9-
int min_exponent_store = -((1 << (exp_bits-1))-1) + 127;
10-
int max_exponent_store = ((1 << (exp_bits-1))-1) + 127; // excluding the exponent for infinity
11-
if (quantized_exponent_store > max_exponent_store) {
12-
unsigned int max_man = (unsigned int ) -1 << 9 >> 9 >> (23-man_bits) << (23-man_bits); // 1 sign bit, 8 exponent bits, 1 virtual bit
13-
unsigned int max_num = ((unsigned int) max_exponent_store << 23) | max_man;
11+
int min_exponent_store = -((1 << (exp_bits - 1)) - 1) + 127;
12+
int max_exponent_store = ((1 << (exp_bits - 1)) - 1) + 127; // excluding the exponent for infinity
13+
if (quantized_exponent_store > max_exponent_store)
14+
{
15+
unsigned int max_man = (unsigned int)-1 << 9 >> 9 >> (23 - man_bits) << (23 - man_bits); // 1 sign bit, 8 exponent bits, 1 virtual bit
16+
unsigned int max_num = ((unsigned int)max_exponent_store << 23) | max_man;
1417
unsigned int old_sign = old_num >> 31 << 31;
1518
quantized_num = old_sign | max_num;
16-
} else if (quantized_exponent_store < min_exponent_store) {
17-
unsigned int min_num = ((unsigned int) min_exponent_store << 23);
18-
unsigned int old_sign = old_num >> 31 << 31;
19-
quantized_num = old_sign | min_num;
19+
}
20+
else if (quantized_exponent_store < min_exponent_store)
21+
{
22+
unsigned int min_num = ((unsigned int)min_exponent_store << 23);
23+
unsigned int middle_num = ((unsigned int)(min_exponent_store - 1) << 23);
24+
unsigned int unsigned_quantized_num = quantized_num << 1 >> 1;
25+
if (unsigned_quantized_num > middle_num)
26+
{
27+
unsigned int old_sign = old_num >> 31 << 31;
28+
quantized_num = old_sign | min_num;
29+
}
30+
else
31+
{
32+
quantized_num = 0;
33+
}
2034
}
2135
return quantized_num;
2236
}
2337

2438
unsigned int clip_max_exponent(int man_bits,
2539
unsigned int max_exponent,
26-
unsigned int quantized_num) {
40+
unsigned int quantized_num)
41+
{
2742
unsigned int quantized_exponent = quantized_num << 1 >> 24 << 23; // 1 sign bit, 23 mantissa bits
28-
if (quantized_exponent > max_exponent) {
29-
unsigned int max_man = (unsigned int ) -1 << 9 >> 9 >> (23-man_bits) << (23-man_bits); // 1 sign bit, 8 exponent bits
43+
if (quantized_exponent > max_exponent)
44+
{
45+
unsigned int max_man = (unsigned int)-1 << 9 >> 9 >> (23 - man_bits) << (23 - man_bits); // 1 sign bit, 8 exponent bits
3046
unsigned int max_num = max_exponent | max_man;
3147
unsigned int old_sign = quantized_num >> 31 << 31;
3248
quantized_num = old_sign | max_num;

qtorch/quant/quant_cpu/quant_cpu.cpp

+96-50
Original file line numberDiff line numberDiff line change
@@ -6,42 +6,61 @@
66

77
using namespace at;
88

9-
enum Mode {rNearest, rStochastic};
9+
enum Mode
10+
{
11+
rNearest,
12+
rStochastic
13+
};
1014

1115
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
1216
#define CHECK_CPU(x) AT_CHECK(!x.type().is_cuda(), #x " must be a CPU tensor")
13-
#define CHECK_INPUT(x) CHECK_CPU(x); CHECK_CONTIGUOUS(x);
17+
#define CHECK_INPUT(x) \
18+
CHECK_CPU(x); \
19+
CHECK_CONTIGUOUS(x);
1420

15-
#define RFLOAT_TO_BITS(x) (*reinterpret_cast<unsigned int*>(x))
16-
#define RBITS_TO_FLOAT(x) (*reinterpret_cast<float*>(x))
17-
#define FLOAT_TO_BITS(f, i) assert(sizeof f == sizeof i); std::memcpy(&i, &f, sizeof i)
18-
#define BITS_TO_FLOAT(i, f) assert(sizeof f == sizeof i); std::memcpy(&f, &i, sizeof f)
21+
#define RFLOAT_TO_BITS(x) (*reinterpret_cast<unsigned int *>(x))
22+
#define RBITS_TO_FLOAT(x) (*reinterpret_cast<float *>(x))
23+
#define FLOAT_TO_BITS(f, i) \
24+
assert(sizeof f == sizeof i); \
25+
std::memcpy(&i, &f, sizeof i)
26+
#define BITS_TO_FLOAT(i, f) \
27+
assert(sizeof f == sizeof i); \
28+
std::memcpy(&f, &i, sizeof f)
1929

2030
std::random_device rd;
2131
std::mt19937 gen(rd());
2232
std::uniform_int_distribution<> dis(0);
2333

2434
template <typename T>
25-
T clamp_helper(T a, T min, T max) {
26-
if (a > max) return max;
27-
else if (a < min) return min;
28-
else return a;
35+
T clamp_helper(T a, T min, T max)
36+
{
37+
if (a > max)
38+
return max;
39+
else if (a < min)
40+
return min;
41+
else
42+
return a;
2943
}
3044

3145
template <typename T>
32-
T clamp_mask_helper(T a, T min, T max, uint8_t* mask) {
33-
if (a > max) {
46+
T clamp_mask_helper(T a, T min, T max, uint8_t *mask)
47+
{
48+
if (a > max)
49+
{
3450
*mask = 1;
3551
return max;
3652
}
37-
else if (a < min) {
53+
else if (a < min)
54+
{
3855
*mask = 1;
3956
return min;
4057
}
41-
else return a;
58+
else
59+
return a;
4260
}
4361

44-
std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl, int fl, bool symmetric) {
62+
std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl, int fl, bool symmetric)
63+
{
4564
CHECK_INPUT(a);
4665
auto r = rand_like(a);
4766
auto a_array = a.data<float>();
@@ -54,14 +73,16 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl
5473
int sigma = -fl;
5574
float t_min, t_max;
5675
fixed_min_max(wl, fl, symmetric, &t_min, &t_max);
57-
for (int64_t i=0; i < size; i++) {
76+
for (int64_t i = 0; i < size; i++)
77+
{
5878
o_array[i] = round(a_array[i], r_array[i], sigma);
59-
o_array[i] = clamp_mask_helper<float>(o_array[i], t_min, t_max, m_array+i);
79+
o_array[i] = clamp_mask_helper<float>(o_array[i], t_min, t_max, m_array + i);
6080
}
6181
return std::make_tuple(o, m);
6282
}
6383

64-
std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask(Tensor a, int wl, int fl, bool symmetric) {
84+
std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask(Tensor a, int wl, int fl, bool symmetric)
85+
{
6586
CHECK_INPUT(a);
6687
auto a_array = a.data<float>();
6788
auto o = zeros_like(a);
@@ -72,14 +93,16 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask(Tensor a, int wl, i
7293
int sigma = -fl;
7394
float t_min, t_max;
7495
fixed_min_max(wl, fl, symmetric, &t_min, &t_max);
75-
for (int64_t i=0; i < size; i++) {
96+
for (int64_t i = 0; i < size; i++)
97+
{
7698
o_array[i] = round(a_array[i], 0.5, sigma);
77-
o_array[i] = clamp_mask_helper<float>(o_array[i], t_min, t_max, m_array+i);
99+
o_array[i] = clamp_mask_helper<float>(o_array[i], t_min, t_max, m_array + i);
78100
}
79101
return std::make_tuple(o, m);
80102
}
81103

82-
Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool clamp, bool symmetric) {
104+
Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool clamp, bool symmetric)
105+
{
83106
CHECK_INPUT(a);
84107
auto r = rand_like(a);
85108
auto a_array = a.data<float>();
@@ -90,16 +113,19 @@ Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool clamp, boo
90113
int sigma = -fl;
91114
float t_min, t_max;
92115
fixed_min_max(wl, fl, symmetric, &t_min, &t_max);
93-
for (int64_t i=0; i < size; i++) {
116+
for (int64_t i = 0; i < size; i++)
117+
{
94118
o_array[i] = round(a_array[i], r_array[i], sigma);
95-
if (clamp) {
119+
if (clamp)
120+
{
96121
o_array[i] = clamp_helper(o_array[i], t_min, t_max);
97122
}
98123
}
99124
return o;
100125
}
101126

102-
Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool clamp, bool symmetric) {
127+
Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool clamp, bool symmetric)
128+
{
103129
CHECK_INPUT(a);
104130
auto a_array = a.data<float>();
105131
Tensor o = zeros_like(a);
@@ -108,31 +134,39 @@ Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool clamp, bool s
108134
int sigma = -fl;
109135
float t_min, t_max;
110136
fixed_min_max(wl, fl, symmetric, &t_min, &t_max);
111-
for (int64_t i=0; i < size; i++) {
137+
for (int64_t i = 0; i < size; i++)
138+
{
112139
o_array[i] = round(a_array[i], 0.5, sigma);
113-
if (clamp) {
140+
if (clamp)
141+
{
114142
o_array[i] = clamp_helper(o_array[i], t_min, t_max);
115143
}
116144
}
117145
return o;
118146
}
119147

120-
unsigned int round_bitwise(unsigned int target, int man_bits, Mode rounding){
121-
unsigned int mask = (1 << (23-man_bits)) - 1;
148+
unsigned int round_bitwise(unsigned int target, int man_bits, Mode rounding)
149+
{
150+
unsigned int mask = (1 << (23 - man_bits)) - 1;
122151
unsigned int rand_prob;
123-
if (rounding == rStochastic) {
152+
if (rounding == rStochastic)
153+
{
124154
rand_prob = (dis(gen)) & mask;
125-
} else {
126-
rand_prob = 1 << (23-man_bits-1);
127155
}
128-
unsigned int add_r = target+rand_prob;
156+
else
157+
{
158+
rand_prob = 1 << (23 - man_bits - 1);
159+
}
160+
unsigned int add_r = target + rand_prob;
129161
unsigned int quantized = add_r & ~mask;
130162
return quantized;
131163
}
132164

133-
void block_quantize_helper(float* input, float* output, float* max_elem,
134-
int wl, int size, Mode rounding) {
135-
for (int64_t i=0; i < size; i++) {
165+
void block_quantize_helper(float *input, float *output, float *max_elem,
166+
int wl, int size, Mode rounding)
167+
{
168+
for (int64_t i = 0; i < size; i++)
169+
{
136170

137171
unsigned int max_num;
138172
FLOAT_TO_BITS(max_elem[i], max_num);
@@ -141,31 +175,37 @@ void block_quantize_helper(float* input, float* output, float* max_elem,
141175
BITS_TO_FLOAT(max_exp, base_float);
142176
base_float *= 6;
143177

144-
float target_rebase = input[i]+base_float;
178+
float target_rebase = input[i] + base_float;
145179
unsigned int target_bits;
146180
FLOAT_TO_BITS(target_rebase, target_bits);
147181
unsigned int quantized_bits = round_bitwise(target_bits, wl, rounding); // -1 sign, -1 virtual, +2 base
148182
float quantized_rebase;
149183
BITS_TO_FLOAT(quantized_bits, quantized_rebase);
150-
float quantized = quantized_rebase-base_float;
184+
float quantized = quantized_rebase - base_float;
151185

152186
unsigned int quantize_bits;
153187
FLOAT_TO_BITS(quantized, quantize_bits);
154-
unsigned int clip_quantize = clip_max_exponent(wl-2, max_exp, quantize_bits);
188+
unsigned int clip_quantize = clip_max_exponent(wl - 2, max_exp, quantize_bits);
155189
BITS_TO_FLOAT(clip_quantize, quantized);
156190

157191
output[i] = quantized;
158192
}
159193
}
160194

161-
Tensor get_max_entry(Tensor a, int dim) {
195+
Tensor get_max_entry(Tensor a, int dim)
196+
{
162197
Tensor max_entry;
163-
if (dim == -1) {
198+
if (dim == -1)
199+
{
164200
max_entry = at::max(at::abs(a)).expand_as(a).contiguous();
165-
} else if (dim == 0) {
201+
}
202+
else if (dim == 0)
203+
{
166204
Tensor input_view = a.view({a.size(0), -1});
167205
max_entry = std::get<0>(input_view.max(1, true)).abs().expand_as(input_view).view_as(a).contiguous();
168-
} else {
206+
}
207+
else
208+
{
169209
Tensor input_transpose = a.transpose(0, dim);
170210
Tensor input_view = input_transpose.contiguous().view({input_transpose.size(0), -1});
171211
Tensor max_transpose = std::get<0>(input_view.max(1, true)).abs().expand_as(input_view).view_as(input_transpose);
@@ -174,7 +214,8 @@ Tensor get_max_entry(Tensor a, int dim) {
174214
return max_entry;
175215
}
176216

177-
Tensor block_quantize_nearest(Tensor a, int wl, int dim) {
217+
Tensor block_quantize_nearest(Tensor a, int wl, int dim)
218+
{
178219
CHECK_INPUT(a);
179220
auto a_array = a.data<float>();
180221
Tensor o = zeros_like(a);
@@ -188,7 +229,8 @@ Tensor block_quantize_nearest(Tensor a, int wl, int dim) {
188229
return o;
189230
}
190231

191-
Tensor block_quantize_stochastic(Tensor a, int wl, int dim) {
232+
Tensor block_quantize_stochastic(Tensor a, int wl, int dim)
233+
{
192234
CHECK_INPUT(a);
193235
auto a_array = a.data<float>();
194236
Tensor o = zeros_like(a);
@@ -203,15 +245,16 @@ Tensor block_quantize_stochastic(Tensor a, int wl, int dim) {
203245
return o;
204246
}
205247

206-
207-
Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits) {
248+
Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits)
249+
{
208250
// use external random number right now
209251
auto a_array = a.data<float>();
210252
auto o = zeros_like(a);
211253
auto o_array = o.data<float>();
212254
int size = a.numel();
213255

214-
for (int64_t i=0; i < size; i++) {
256+
for (int64_t i = 0; i < size; i++)
257+
{
215258
unsigned int target;
216259
FLOAT_TO_BITS(a_array[i], target);
217260
unsigned int quantize_bits = round_bitwise(target, man_bits, rStochastic);
@@ -223,13 +266,15 @@ Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits) {
223266
return o;
224267
}
225268

226-
Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits) {
269+
Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits)
270+
{
227271
auto a_array = a.data<float>();
228272
auto o = zeros_like(a);
229273
auto o_array = o.data<float>();
230274
int size = a.numel();
231275

232-
for (int64_t i=0; i < size; i++) {
276+
for (int64_t i = 0; i < size; i++)
277+
{
233278
unsigned int target;
234279
FLOAT_TO_BITS(a_array[i], target);
235280
unsigned int quantize_bits = round_bitwise(target, man_bits, rNearest);
@@ -241,7 +286,8 @@ Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits) {
241286
return o;
242287
}
243288

244-
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
289+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
290+
{
245291
m.def("fixed_point_quantize_stochastic_mask", &fixed_point_quantize_stochastic_mask, "Fixed Point Number Stochastic Quantization with Mask (CPU)");
246292
m.def("fixed_point_quantize_stochastic", &fixed_point_quantize_stochastic, "Fixed Point Number Stochastic Quantization (CPU)");
247293
m.def("block_quantize_stochastic", &block_quantize_stochastic, "Block Floating Point Number Stochastic Quantization (CPU)");

qtorch/quant/quant_cpu/quant_cpu.h

+4-4
Original file line numberDiff line numberDiff line change
@@ -7,12 +7,12 @@ unsigned int clip_max_exponent(int man_bits,
77
unsigned int max_exponent,
88
unsigned int quantized_num);
99

10-
11-
template <typename T> T clamp_helper(T a, T min, T max);
10+
template <typename T>
11+
T clamp_helper(T a, T min, T max);
1212

1313
template <typename T>
14-
T clamp_mask_helper(T a, T min, T max, uint8_t* mask);
14+
T clamp_mask_helper(T a, T min, T max, uint8_t *mask);
1515

16-
void fixed_min_max(int wl, int fl, bool symmetric, float* t_min, float* t_max);
16+
void fixed_min_max(int wl, int fl, bool symmetric, float *t_min, float *t_max);
1717

1818
float round(float a, float r, int sigma);

0 commit comments

Comments
 (0)