Skip to content

Commit 5d34f5a

Browse files
authored
Merge pull request #32 from Tiiiger/0.2.0
upgrade to 0.2.0
2 parents 01d353e + c2812ad commit 5d34f5a

File tree

11 files changed

+122
-136
lines changed

11 files changed

+122
-136
lines changed

README.md

+18-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,22 @@
11
# QPyTorch
22
[![Downloads](https://pepy.tech/badge/qtorch)](https://pepy.tech/project/qtorch) [![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT)
33

4+
#### News:
5+
- Updated to version 0.2.0:
6+
- **Bug fixed**: previously in our floating point quantization, numbers that are closer to 0 than the smallest
7+
representable positive number were rounded to the smallest representable positive number. Now we round to 0 or the smallest
8+
representable number based on which one is nearer
9+
- **Different Behavior**: To be consistent with PyTorch [Issue #17443](https://github.com/pytorch/pytorch/pull/17443),
10+
we round to the nearest even now.
11+
- We migrate to PyTorch 1.5.0. There are several changes in the C++ API of PyTorch.
12+
This new version is not backward-compatible with older PyTorch.
13+
- *Note*: if you are using CUDA 10.1, please install CUDA 10.1 Update 1 (or later version). There is a bug in
14+
the first version of CUDA 10.1 which leads to compilation error.
15+
- *Note*: previous users, please remove the cache in the pytorch extension directory.
16+
For example, you can run this command `rm -rf /tmp/torch_extensions/quant_cpu /tmp/torch_extensions/quant_cuda` if
17+
you are using the default directory for pytorch extensions.
18+
19+
420
QPyTorch is a low-precision arithmetic simulation package in
521
PyTorch. It is designed to support research on low-precision machine
622
learning, especially research in low-precision training.
@@ -30,8 +46,9 @@ and QPyTorch's simulation of half-precision numbers.
3046
requirements:
3147

3248
- Python >= 3.6
33-
- PyTorch >= 1.0
49+
- PyTorch >= 1.5.0
3450
- GCC >= 4.9 on linux
51+
- CUDA >= 10.1 on linux
3552

3653
Install other requirements by:
3754
```bash

examples/SWALP/README.md

+4-4
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,10 @@ bash example.sh
5353
```
5454

5555
## Results
56-
| Datset | Model | SGD-FP | SWA-FP | SGD-LP | SWALP |
57-
|----------|--------------|------------|------------|------------|------------|
58-
| CIFAR10 | VGG16 | 6.81±0.09 | 6.51±0.14 | 7.61±0.15 | 6.70±0.12 |
59-
| CIFAR100 | VGG16 | 27.23±0.17 | 25.93±0.21 | 29.59±0.32 | 26.65±0.29 |
56+
| Dataset  | Model | SGD-FP     | SWA-FP     | SGD-LP     | SWALP      |
57+
| -------- | ----- | ---------- | ---------- | ---------- | ---------- |
58+
| CIFAR10 | VGG16 | 6.81±0.09 | 6.51±0.14 | 7.61±0.15 | 6.70±0.12 |
59+
| CIFAR100 | VGG16 | 27.23±0.17 | 25.93±0.21 | 29.59±0.32 | 26.65±0.29 |
6060

6161
## References
6262
This repo is modified from the PyTorch repo of [SWALP](https://github.com/stevenygd/SWALP)

qtorch/quant/quant_cpu/quant_cpu.cpp

+24-24
Original file line numberDiff line numberDiff line change
@@ -12,8 +12,8 @@ enum Mode
1212
rStochastic
1313
};
1414

15-
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
16-
#define CHECK_CPU(x) AT_CHECK(!x.type().is_cuda(), #x " must be a CPU tensor")
15+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
16+
#define CHECK_CPU(x) TORCH_CHECK(!x.is_cuda(), #x " must be a CPU tensor")
1717
#define CHECK_INPUT(x) \
1818
CHECK_CPU(x); \
1919
CHECK_CONTIGUOUS(x);
@@ -63,12 +63,12 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl
6363
{
6464
CHECK_INPUT(a);
6565
auto r = rand_like(a);
66-
auto a_array = a.data<float>();
67-
auto r_array = r.data<float>();
66+
auto a_array = a.data_ptr<float>();
67+
auto r_array = r.data_ptr<float>();
6868
auto o = zeros_like(a);
69-
auto o_array = o.data<float>();
69+
auto o_array = o.data_ptr<float>();
7070
auto m = zeros_like(a, torch::CPU(kByte));
71-
auto m_array = m.data<uint8_t>();
71+
auto m_array = m.data_ptr<uint8_t>();
7272
int64_t size = a.numel();
7373
int sigma = -fl;
7474
float t_min, t_max;
@@ -84,11 +84,11 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask(Tensor a, int wl
8484
std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask(Tensor a, int wl, int fl, bool symmetric)
8585
{
8686
CHECK_INPUT(a);
87-
auto a_array = a.data<float>();
87+
auto a_array = a.data_ptr<float>();
8888
auto o = zeros_like(a);
89-
auto o_array = o.data<float>();
89+
auto o_array = o.data_ptr<float>();
9090
auto m = zeros_like(a, torch::CPU(kByte));
91-
auto m_array = m.data<uint8_t>();
91+
auto m_array = m.data_ptr<uint8_t>();
9292
int64_t size = a.numel();
9393
int sigma = -fl;
9494
float t_min, t_max;
@@ -105,10 +105,10 @@ Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool clamp, boo
105105
{
106106
CHECK_INPUT(a);
107107
auto r = rand_like(a);
108-
auto a_array = a.data<float>();
109-
auto r_array = r.data<float>();
108+
auto a_array = a.data_ptr<float>();
109+
auto r_array = r.data_ptr<float>();
110110
Tensor o = zeros_like(a);
111-
auto o_array = o.data<float>();
111+
auto o_array = o.data_ptr<float>();
112112
int64_t size = a.numel();
113113
int sigma = -fl;
114114
float t_min, t_max;
@@ -127,9 +127,9 @@ Tensor fixed_point_quantize_stochastic(Tensor a, int wl, int fl, bool clamp, boo
127127
Tensor fixed_point_quantize_nearest(Tensor a, int wl, int fl, bool clamp, bool symmetric)
128128
{
129129
CHECK_INPUT(a);
130-
auto a_array = a.data<float>();
130+
auto a_array = a.data_ptr<float>();
131131
Tensor o = zeros_like(a);
132-
auto o_array = o.data<float>();
132+
auto o_array = o.data_ptr<float>();
133133
int64_t size = a.numel();
134134
int sigma = -fl;
135135
float t_min, t_max;
@@ -217,29 +217,29 @@ Tensor get_max_entry(Tensor a, int dim)
217217
Tensor block_quantize_nearest(Tensor a, int wl, int dim)
218218
{
219219
CHECK_INPUT(a);
220-
auto a_array = a.data<float>();
220+
auto a_array = a.data_ptr<float>();
221221
Tensor o = zeros_like(a);
222-
auto o_array = o.data<float>();
222+
auto o_array = o.data_ptr<float>();
223223
int64_t size = a.numel();
224224

225225
// get maximum number and base
226226
Tensor max_entry = get_max_entry(a, dim);
227-
auto max_elem = max_entry.data<float>();
227+
auto max_elem = max_entry.data_ptr<float>();
228228
block_quantize_helper(a_array, o_array, max_elem, wl, size, rNearest);
229229
return o;
230230
}
231231

232232
Tensor block_quantize_stochastic(Tensor a, int wl, int dim)
233233
{
234234
CHECK_INPUT(a);
235-
auto a_array = a.data<float>();
235+
auto a_array = a.data_ptr<float>();
236236
Tensor o = zeros_like(a);
237-
auto o_array = o.data<float>();
237+
auto o_array = o.data_ptr<float>();
238238
int64_t size = a.numel();
239239

240240
// get maximum number and base
241241
Tensor max_entry = get_max_entry(a, dim);
242-
auto max_elem = max_entry.data<float>();
242+
auto max_elem = max_entry.data_ptr<float>();
243243
// std::srand(time(0));
244244
block_quantize_helper(a_array, o_array, max_elem, wl, size, rStochastic);
245245
return o;
@@ -248,9 +248,9 @@ Tensor block_quantize_stochastic(Tensor a, int wl, int dim)
248248
Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits)
249249
{
250250
// use external random number right now
251-
auto a_array = a.data<float>();
251+
auto a_array = a.data_ptr<float>();
252252
auto o = zeros_like(a);
253-
auto o_array = o.data<float>();
253+
auto o_array = o.data_ptr<float>();
254254
int size = a.numel();
255255

256256
for (int64_t i = 0; i < size; i++)
@@ -268,9 +268,9 @@ Tensor float_quantize_stochastic(Tensor a, int man_bits, int exp_bits)
268268

269269
Tensor float_quantize_nearest(Tensor a, int man_bits, int exp_bits)
270270
{
271-
auto a_array = a.data<float>();
271+
auto a_array = a.data_ptr<float>();
272272
auto o = zeros_like(a);
273-
auto o_array = o.data<float>();
273+
auto o_array = o.data_ptr<float>();
274274
int size = a.numel();
275275

276276
for (int64_t i = 0; i < size; i++)

qtorch/quant/quant_cpu/sim_helper.cpp

+10-6
Original file line numberDiff line numberDiff line change
@@ -2,16 +2,20 @@
22
#include <math.h>
33
#include <stdint.h>
44

5-
void fixed_min_max(int wl, int fl, bool symmetric, float* t_min, float* t_max) {
5+
void fixed_min_max(int wl, int fl, bool symmetric, float *t_min, float *t_max)
6+
{
67
int sigma = -fl;
7-
*t_min = -ldexp(1.0, wl-fl-1);
8-
*t_max = -*t_min-ldexp(1.0, sigma);
9-
if (symmetric) *t_min = *t_min+ldexp(1.0, sigma);
8+
*t_min = -ldexp(1.0, wl - fl - 1);
9+
*t_max = -*t_min - ldexp(1.0, sigma);
10+
if (symmetric)
11+
*t_min = *t_min + ldexp(1.0, sigma);
1012
}
1113

12-
float round(float a, float r, int sigma) {
14+
float round(float a, float r, int sigma)
15+
{
1316
a = ldexp(a, -sigma);
14-
a = floor(a+r);
17+
a = nearbyint(a + r - 0.5);
18+
// a = floor(a + r);
1519
a = ldexp(a, sigma);
1620
return a;
1721
}

qtorch/quant/quant_cuda/quant.cu

+31-31
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,11 @@ Tensor block_quantize_stochastic_cuda(Tensor a, int wl, int dim) {
3636
int blockSize = 1024;
3737
int blockNums = (size + blockSize - 1) / blockSize;
3838

39-
block_kernel_stochastic<<<blockNums, blockSize>>>(a.data<float>(),
40-
rand_ints.data<int>(),
41-
o.data<float>(),
39+
block_kernel_stochastic<<<blockNums, blockSize>>>(a.data_ptr<float>(),
40+
rand_ints.data_ptr<int>(),
41+
o.data_ptr<float>(),
4242
size,
43-
max_entry.data<float>(),
43+
max_entry.data_ptr<float>(),
4444
wl);
4545
return o;
4646
}
@@ -53,10 +53,10 @@ Tensor block_quantize_nearest_cuda(Tensor a, int wl, int dim) {
5353
int blockSize = 1024;
5454
int blockNums = (size + blockSize - 1) / blockSize;
5555

56-
block_kernel_nearest<<<blockNums, blockSize>>>(a.data<float>(),
57-
o.data<float>(),
56+
block_kernel_nearest<<<blockNums, blockSize>>>(a.data_ptr<float>(),
57+
o.data_ptr<float>(),
5858
size,
59-
max_entry.data<float>(),
59+
max_entry.data_ptr<float>(),
6060
wl);
6161
return o;
6262
}
@@ -70,11 +70,11 @@ Tensor block_quantize_sim_stochastic_cuda(Tensor a, int wl) {
7070
int blockSize = 1024;
7171
int blockNums = (size + blockSize - 1) / blockSize;
7272

73-
block_kernel_sim_stochastic<<<blockNums, blockSize>>>(a.data<float>(),
74-
rand_probs.data<float>(),
75-
o.data<float>(),
73+
block_kernel_sim_stochastic<<<blockNums, blockSize>>>(a.data_ptr<float>(),
74+
rand_probs.data_ptr<float>(),
75+
o.data_ptr<float>(),
7676
size,
77-
max_entry.data<float>(),
77+
max_entry.data_ptr<float>(),
7878
wl);
7979
return o;
8080
}
@@ -88,10 +88,10 @@ Tensor block_quantize_sim_nearest_cuda(Tensor a, int wl) {
8888
int blockSize = 1024;
8989
int blockNums = (size + blockSize - 1) / blockSize;
9090

91-
block_kernel_sim_nearest<<<blockNums, blockSize>>>(a.data<float>(),
92-
o.data<float>(),
91+
block_kernel_sim_nearest<<<blockNums, blockSize>>>(a.data_ptr<float>(),
92+
o.data_ptr<float>(),
9393
size,
94-
max_entry.data<float>(),
94+
max_entry.data_ptr<float>(),
9595
wl);
9696
return o;
9797
}
@@ -104,9 +104,9 @@ Tensor float_quantize_stochastic_cuda(Tensor a, int man_bits, int exp_bits) {
104104
int blockSize = 1024;
105105
int blockNums = (size + blockSize - 1) / blockSize;
106106

107-
float_kernel_stochastic<<<blockNums, blockSize>>>(a.data<float>(),
108-
rand_ints.data<int>(),
109-
o.data<float>(),
107+
float_kernel_stochastic<<<blockNums, blockSize>>>(a.data_ptr<float>(),
108+
rand_ints.data_ptr<int>(),
109+
o.data_ptr<float>(),
110110
size,
111111
man_bits,
112112
exp_bits);
@@ -120,8 +120,8 @@ Tensor float_quantize_nearest_cuda(Tensor a, int man_bits, int exp_bits) {
120120
int blockSize = 1024;
121121
int blockNums = (size + blockSize - 1) / blockSize;
122122

123-
float_kernel_nearest<<<blockNums, blockSize>>>(a.data<float>(),
124-
o.data<float>(),
123+
float_kernel_nearest<<<blockNums, blockSize>>>(a.data_ptr<float>(),
124+
o.data_ptr<float>(),
125125
size,
126126
man_bits,
127127
exp_bits);
@@ -146,9 +146,9 @@ Tensor fixed_point_quantize_stochastic_cuda(Tensor a, int wl, int fl, bool use_c
146146
int blockSize = 1024;
147147
int blockNums = (size + blockSize - 1) / blockSize;
148148

149-
fixed_point_quantize_kernel_stochastic<<<blockNums, blockSize>>>(a.data<float>(),
150-
rand_probs.data<float>(),
151-
o.data<float>(),
149+
fixed_point_quantize_kernel_stochastic<<<blockNums, blockSize>>>(a.data_ptr<float>(),
150+
rand_probs.data_ptr<float>(),
151+
o.data_ptr<float>(),
152152
size,
153153
sigma,
154154
use_clamp,
@@ -167,8 +167,8 @@ Tensor fixed_point_quantize_nearest_cuda(Tensor a, int wl, int fl, bool use_clam
167167
int blockSize = 1024;
168168
int blockNums = (size + blockSize - 1) / blockSize;
169169

170-
fixed_point_quantize_kernel_nearest<<<blockNums, blockSize>>>(a.data<float>(),
171-
o.data<float>(),
170+
fixed_point_quantize_kernel_nearest<<<blockNums, blockSize>>>(a.data_ptr<float>(),
171+
o.data_ptr<float>(),
172172
size,
173173
sigma,
174174
use_clamp,
@@ -189,10 +189,10 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_stochastic_mask_cuda(Tensor a, i
189189
int blockSize = 1024;
190190
int blockNums = (size + blockSize - 1) / blockSize;
191191

192-
fixed_point_quantize_kernel_mask_stochastic<<<blockNums, blockSize>>>(a.data<float>(),
193-
rand_probs.data<float>(),
194-
o.data<float>(),
195-
m.data<uint8_t>(),
192+
fixed_point_quantize_kernel_mask_stochastic<<<blockNums, blockSize>>>(a.data_ptr<float>(),
193+
rand_probs.data_ptr<float>(),
194+
o.data_ptr<float>(),
195+
m.data_ptr<uint8_t>(),
196196
size,
197197
sigma,
198198
t_min,
@@ -211,9 +211,9 @@ std::tuple<Tensor, Tensor> fixed_point_quantize_nearest_mask_cuda(Tensor a, int
211211
int blockSize = 1024;
212212
int blockNums = (size + blockSize - 1) / blockSize;
213213

214-
fixed_point_quantize_kernel_mask_nearest<<<blockNums, blockSize>>>(a.data<float>(),
215-
o.data<float>(),
216-
m.data<uint8_t>(),
214+
fixed_point_quantize_kernel_mask_nearest<<<blockNums, blockSize>>>(a.data_ptr<float>(),
215+
o.data_ptr<float>(),
216+
m.data_ptr<uint8_t>(),
217217
size,
218218
sigma,
219219
t_min,

qtorch/quant/quant_cuda/quant_cuda.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44

55
using namespace at;
66

7-
#define CHECK_CUDA(x) AT_CHECK(x.type().is_cuda(), #x " must be a CUDA tensor")
8-
#define CHECK_CONTIGUOUS(x) AT_CHECK(x.is_contiguous(), #x " must be contiguous")
7+
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
8+
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
99
#define CHECK_INPUT(x) \
1010
CHECK_CUDA(x); \
1111
CHECK_CONTIGUOUS(x)

qtorch/quant/quant_cuda/sim_helper.cu

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22
#include <cmath>
33

44
__device__ __forceinline__ float round_helper(float a, float r) {
5-
return floor(a+r);
5+
// return floor(a+r);
6+
return nearbyint(a+r-0.5);
67
}
78

89
__device__ __forceinline__ float round(float a, float r, int sigma) {

0 commit comments

Comments
 (0)