s1_quantzation.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# This file implements S1* quantization. To implement S1 quantization instead,
# comment out the lines marked in _quantize_func.forward below.
class _quantize_func(torch.autograd.Function):

    @staticmethod
    def forward(ctx, input, step_size, half_lvls):
        # ctx is a context object that can be used to stash information
        # for the backward computation
        ctx.step_size = step_size
        ctx.half_lvls = half_lvls
        # clamp to the representable range, then express in units of the step size
        output = F.hardtanh(input,
                            min_val=-ctx.half_lvls * ctx.step_size.item(),
                            max_val=ctx.half_lvls * ctx.step_size.item())
        output = output / ctx.step_size
        # S1*-specific rounding adjustments; presumably the lines the header
        # note says to comment out for plain S1 quantization.
        output[(output >= -2) & (output < 1.5)] += 0.5
        output = torch.floor(output)
        output[(output >= 2) & (output % 2 == 0)] += 1
        output[(output <= -2) & (output % 2 != 0)] += 1
        return output

    @staticmethod
    def backward(ctx, grad_output):
        # straight-through estimator: pass the gradient through, scaled by the step size
        grad_input = grad_output.clone() / ctx.step_size
        return grad_input, None, None


quantize = _quantize_func.apply
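
# Worked examples (illustrative, assuming step_size = torch.tensor([1.0]) and
# half_lvls = 127; these values are not part of the original file):
#    0.6 ->  1.  (shifted to 1.1, floored to 1)
#    4.2 ->  5.  (floored to 4, then +1 because 4 >= 2 and even)
#   -2.5 -> -2.  (floored to -3, then +1 because -3 <= -2 and odd)
# i.e. levels of magnitude >= 2 land on odd positive / even negative integers.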


class quan_Conv2d(nn.Conv2d):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride=1,
                 padding=0,
                 dilation=1,
                 groups=1,
                 bias=True):
        super(quan_Conv2d, self).__init__(in_channels,
                                          out_channels,
                                          kernel_size,
                                          stride=stride,
                                          padding=padding,
                                          dilation=dilation,
                                          groups=groups,
                                          bias=bias)
        self.N_bits = 8
        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2
        # initialize the step size
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()
        # flag to run inference with the stored quantized weight instead of
        # quantizing self.weight on the fly
        self.inf_with_weight = False  # disabled by default
        # per-bit place values in two's-complement order, shape [N_bits, 1]:
        # [2^(N_bits-1), ..., 2, 1], with the MSB made negative below
        self.b_w = nn.Parameter(2**torch.arange(start=self.N_bits - 1,
                                                end=-1,
                                                step=-1).unsqueeze(-1).float(),
                                requires_grad=False)
        self.b_w[0] = -self.b_w[0]  # in-place change of the MSB to negative

    def forward(self, input):
        if self.inf_with_weight:
            # self.weight already holds integer levels; rescale by the step size
            return F.conv2d(input, self.weight * self.step_size, self.bias,
                            self.stride, self.padding, self.dilation,
                            self.groups)
        else:
            self.__reset_stepsize__()
            weight_quan = quantize(self.weight, self.step_size,
                                   self.half_lvls) * self.step_size
            return F.conv2d(input, weight_quan, self.bias, self.stride,
                            self.padding, self.dilation, self.groups)

    def __reset_stepsize__(self):
        # choose the step size so the largest |weight| maps to the top level
        with torch.no_grad():
            self.step_size.data = self.weight.abs().max() / self.half_lvls

    def __reset_weight__(self):
        '''
        Reconstruct the weight stored in self.weight, replacing the original
        floating-point values with the quantized fixed-point representation.
        '''
        # replace the weight with the quantized version
        with torch.no_grad():
            self.weight.data = quantize(self.weight, self.step_size,
                                        self.half_lvls)
        # enable the flag, so subsequent forwards no longer re-quantize the weight
        self.inf_with_weight = True
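
# Note (added, hedged): b_w is not consumed anywhere in this file; it exposes
# the two's-complement place values of each quantized weight, presumably for
# external code that manipulates individual weight bits. A level w with bit
# vector b (MSB first) satisfies w = sum_i b[i] * b_w[i]; e.g. with N_bits = 8,
# the pattern 11111101 gives -128 + 64 + 32 + 16 + 8 + 4 + 1 = -3.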


class quan_Linear(nn.Linear):

    def __init__(self, in_features, out_features, bias=True):
        super(quan_Linear, self).__init__(in_features, out_features, bias=bias)
        self.N_bits = 8
        self.full_lvls = 2**self.N_bits
        self.half_lvls = (self.full_lvls - 2) / 2
        # initialize the step size
        self.step_size = nn.Parameter(torch.Tensor([1]), requires_grad=True)
        self.__reset_stepsize__()
        # flag to run inference with the stored quantized weight instead of
        # quantizing self.weight on the fly
        self.inf_with_weight = False  # disabled by default
        # per-bit place values in two's-complement order, shape [N_bits, 1]
        self.b_w = nn.Parameter(2**torch.arange(start=self.N_bits - 1,
                                                end=-1,
                                                step=-1).unsqueeze(-1).float(),
                                requires_grad=False)
        self.b_w[0] = -self.b_w[0]  # in-place change of the MSB to negative

    def forward(self, input):
        if self.inf_with_weight:
            # self.weight already holds integer levels; rescale by the step size
            return F.linear(input, self.weight * self.step_size, self.bias)
        else:
            self.__reset_stepsize__()
            weight_quan = quantize(self.weight, self.step_size,
                                   self.half_lvls) * self.step_size
            return F.linear(input, weight_quan, self.bias)

    def __reset_stepsize__(self):
        # choose the step size so the largest |weight| maps to the top level
        with torch.no_grad():
            self.step_size.data = self.weight.abs().max() / self.half_lvls

    def __reset_weight__(self):
        '''
        Reconstruct the weight stored in self.weight, replacing the original
        floating-point values with the quantized fixed-point representation.
        '''
        # replace the weight with the quantized version
        with torch.no_grad():
            self.weight.data = quantize(self.weight, self.step_size,
                                        self.half_lvls)
        # enable the flag, so subsequent forwards no longer re-quantize the weight
        self.inf_with_weight = True
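

if __name__ == '__main__':
    # Minimal usage sketch (illustrative; the shapes, sizes, and seed below
    # are arbitrary assumptions, not part of the original module).
    torch.manual_seed(0)
    conv = quan_Conv2d(3, 8, kernel_size=3, padding=1)
    fc = quan_Linear(8, 4)

    x = torch.randn(2, 3, 16, 16)
    y = conv(x)  # forward pass with on-the-fly weight quantization
    print(y.shape)  # torch.Size([2, 8, 16, 16])

    # Freeze the integer weight levels; later forwards just rescale them.
    conv.__reset_weight__()
    y_frozen = conv(x)
    print(torch.allclose(y, y_frozen))  # expected: True

    z = fc(y_frozen.mean(dim=(2, 3)))  # pool [2, 8, 16, 16] -> [2, 8] -> [2, 4]
    print(z.shape)  # torch.Size([2, 4])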