
Commit 6f81d2a

Merge pull request #91 from CortexFoundation/ryt
MRT progress
2 parents 6cbd822 + 23037a9 commit 6f81d2a

20 files changed, +856 −4 lines changed

python/mrt/sym_utils.py

Lines changed: 3 additions & 0 deletions
@@ -471,6 +471,9 @@ def get_entry_id(sym):
     oindex = json.loads(graph.json())['heads'][0][1]
     return oindex
 
+def has_multi_outs(sym):
+    return sym.attr('op_name') in MULTIPYE_OUTS_NODE
+
 def get_node(sym, graph):
     """ Get the symbol from the provided graph which has the same name as the given symbol.
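For reference, a minimal usage sketch of the new helper (SliceChannel is only an illustrative guess at what MULTIPYE_OUTS_NODE lists):

import mxnet as mx
from mrt.sym_utils import has_multi_outs

data = mx.sym.Variable('data')
split = mx.sym.SliceChannel(data, num_outputs=2, name='split')   # a multi-output operator
print(has_multi_outs(split))   # True only if 'SliceChannel' is listed in MULTIPYE_OUTS_NODE
print(has_multi_outs(data))    # False: a plain Variable has op_name 'null'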

python/mrt/yamrt/__init__.py

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
from .fquant import *
from .modelhandler import ModelHandler, MxnetModelHandler
from mrt.transformer import *
from .autoquanter import *

python/mrt/yamrt/autoquanter.py

Lines changed: 65 additions & 0 deletions
@@ -0,0 +1,65 @@
# General
from .modelhandler import ModelHandler

# Mxnet
from mrt.transformer import *
from mrt import tfm_pass as tpass


class AutoQuanter(object):
    def __init__(self, model: ModelHandler):
        self._model = model

    def prepare(self, *args, **kwargs):
        raise NotImplementedError

    def ptq_pre(self, *args, **kwargs):
        raise NotImplementedError

    def ptq_pre_param(self, *args, **kwargs):
        raise NotImplementedError

    def ptq(self, *args, **kwargs):
        raise NotImplementedError

    def ptq_collect(self, *args, **kwargs):
        raise NotImplementedError

# TODO: Add full APIs.


class MxnetAutoQuanter(AutoQuanter):
    def __init__(self, model: ModelHandler):
        super(MxnetAutoQuanter, self).__init__(model)

    def prepare(self, input_shape: dict = None):  # TODO: make configurable like ptq_pre.
        assert input_shape is not None
        self._model.visit_model(tpass.name_duplicate_check)
        if isinstance(input_shape, dict):
            self._model.update_model(tpass.attach_input_shape, input_shape=input_shape)
            self._model.update_model(tpass.fuse_multiple_inputs)
        elif input_shape is not None:
            model_inputs = self._model.visit_model(tpass.model_inputs)
            assert model_inputs == 1, "a single shape tuple requires a model with exactly one input"
            self._model.update_model(tpass.input_name_replace)
            self._model.update_model(tpass.attach_input_shape, {"data": input_shape})
        self._model.visit_model(tpass.infer_shape)

        self._model.update_model(tpass.fuse_multiple_outputs)
        self._model.update_model(tpass.fuse_constant)
        self._model.update_model(tpass.fuse_transpose)
        self._model.update_model(tpass.rewrite)
        self._model.update_model(tpass.fuse_constant)
        self._model.update_model(tpass.params_unique)

    def ptq_pre(self, rule_list):
        self._model.update_model(tpass.ptq_pre, rule_list=rule_list)

    def ptq_pre_param(self, config):
        pass

    def ptq(self):
        raise NotImplementedError

    def ptq_collect(self):
        raise NotImplementedError
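A rough driver sketch for the new front end. The MxnetModelHandler constructor arguments, the checkpoint name, and the empty rule list are assumptions for illustration, since modelhandler.py and the rule-list format are not part of this diff:

import mxnet as mx
from mrt.yamrt import MxnetModelHandler, MxnetAutoQuanter

# Assumed: the handler wraps a loaded symbol plus its parameters.
sym, arg_params, aux_params = mx.model.load_checkpoint('resnet18_v1', 0)
model = MxnetModelHandler(sym, {**arg_params, **aux_params})

quanter = MxnetAutoQuanter(model)
# Name check, shape attachment, fusion and rewrite passes on the wrapped graph.
quanter.prepare(input_shape={'data': (1, 3, 224, 224)})
# Attach quantization wrappers; the rule-list format is whatever tpass.ptq_pre expects.
quanter.ptq_pre(rule_list=[])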

python/mrt/yamrt/fquant/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
from .uniform_affine_quantizer import *
from .proxy import *

python/mrt/yamrt/fquant/common.py

Lines changed: 40 additions & 0 deletions
@@ -0,0 +1,40 @@
import mxnet as mx

QUANT_OP_PREFIX = "MRT_"


class Wrapper(object):
    """Base class holding quantization info and factory helpers for a wrapped symbol."""

    def __init__(self, op: mx.sym.Symbol, config: dict):
        self._ori_op = op
        self._config = config
        self._attr_dict = {}
        self._build_attr_dict()
        self._op = None
        self._param = None

    def _build_attr_dict(self):
        raise NotImplementedError

    def new_op(self):
        self._op = mx.sym.Custom(**self._attr_dict)
        return self._op

    def op(self):
        return self._op

    def attr(self, key: str):
        if key in self._attr_dict:
            return self._attr_dict[key]
        return 'null'

    def key(self):
        return self._attr_dict['name']

    def init_param(self, *args, **kwargs):
        raise NotImplementedError

    def param(self) -> dict:
        return self._param

python/mrt/yamrt/fquant/proxy.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from .common import *
import mxnet as mx


class ProxyWrapper(Wrapper):
    def __init__(self, op, config):
        super(ProxyWrapper, self).__init__(op, config)

    def _build_attr_dict(self):
        # Non-symbol attributes
        self._attr_dict['op_type'] = self._config['q_op_name']
        self._attr_dict['name'] = f"{self._attr_dict['op_type']}_{self._ori_op.attr('name')}"
        # Symbol attributes
        self._attr_dict['data'] = self._ori_op
        self._attr_dict['qbias'] = mx.sym.Variable(**self._ori_op.list_attr(), name=f"{self._attr_dict['name']}_qbias")


class Proxy(mx.operator.CustomOp):
    def __init__(self):
        super(Proxy, self).__init__()

    def forward(self, is_train, req, in_data, out_data, aux):
        # The proxy simply emits qbias in place of the original data.
        self.assign(out_data[0], req[0], in_data[1])

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Route the whole output gradient to qbias; data receives no gradient.
        assert req[0] == req[1]
        self.assign(in_grad[1], req[0], out_grad[0])


@mx.operator.register(QUANT_OP_PREFIX + "Proxy")
class ProxyProp(mx.operator.CustomOpProp):
    def __init__(self):
        super(ProxyProp, self).__init__()

    def list_arguments(self):
        return ['data', 'qbias']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        assert len(in_shape) == 2
        return [*in_shape], [in_shape[0]], []

    def infer_type(self, in_type):
        return [*in_type], [in_type[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return Proxy()
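Because ProxyProp is registered as QUANT_OP_PREFIX + "Proxy", i.e. "MRT_Proxy", the op can be exercised on its own. A small sketch with made-up shapes, assuming the fquant package has been imported so the registration has run:

import mxnet as mx
import mrt.yamrt.fquant   # importing the package registers the custom op

data = mx.sym.Variable('data')
qbias = mx.sym.Variable('qbias')
out = mx.sym.Custom(data=data, qbias=qbias, op_type='MRT_Proxy', name='proxy0')

exe = out.simple_bind(mx.cpu(), data=(2, 3), qbias=(2, 3))
exe.forward(is_train=True, data=mx.nd.zeros((2, 3)), qbias=mx.nd.ones((2, 3)))
print(exe.outputs[0])          # forward simply emits qbias
exe.backward(mx.nd.ones((2, 3)))
print(exe.grad_dict['qbias'])  # the whole output gradient is routed to qbias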
python/mrt/yamrt/fquant/uniform_affine_quantizer.py

Lines changed: 179 additions & 0 deletions
@@ -0,0 +1,179 @@
from .common import *
import mxnet as mx
import mxnet.ndarray as nd
import warnings


def _round_ste(x):
    # Straight-through estimator: forward uses round(x), backward passes the gradient of x.
    return mx.nd.stop_gradient(mx.nd.round(x) - x) + x


def _new_detached_nd(*args):
    res = []
    for item in args:
        res.append(item.detach())
    return res


class UniformAffineQuantizerWrapper(Wrapper):
    _scale_methods = ['max_scale', 'max', 'mse']

    def __init__(self, op, config):
        self.channel_wise = False
        self.scale_method = config['scale_method'] if 'scale_method' in config else self._scale_methods[0]
        super(UniformAffineQuantizerWrapper, self).__init__(op, config)
        self.delta_nd = None
        self.delta_op = None
        self.zero_point_nd = None
        self.zero_point_op = None

    def _build_attr_dict(self):
        assert self._config['q_op_name'] not in self._ori_op.attr('name')
        # Non-symbol attributes
        self._attr_dict['op_type'] = self._config['q_op_name']
        self._attr_dict['name'] = f"{self._attr_dict['op_type']}_{self._ori_op.attr('name')}"
        self._attr_dict['n_bits'] = self._config['n_bits']
        self.channel_wise = self._config['channel_wise']
        # Symbol attributes
        self._attr_dict['data'] = self._ori_op
        if not self.channel_wise:
            self.delta_op = mx.sym.Variable(f"{self._attr_dict['name']}_delta", shape=(1,))
            self.zero_point_op = mx.sym.Variable(f"{self._attr_dict['name']}_zero_point", shape=(1,))
            self._attr_dict['delta'] = self.delta_op
            self._attr_dict['zero_point'] = self.zero_point_op
        elif self.channel_wise:
            # Assume the first dim of the input data is the channel.
            assert len(self._ori_op.infer_shape()[1]) == 1
            ori_op_shape = self._ori_op.infer_shape()[1][0]
            channel_wise_shape = (ori_op_shape[0], *([1] * (len(ori_op_shape) - 1)))
            self.delta_op = mx.sym.Variable(
                f"{self._attr_dict['name']}_delta",
                shape=channel_wise_shape)
            self.zero_point_op = mx.sym.Variable(
                f"{self._attr_dict['name']}_zero_point",
                shape=channel_wise_shape)
            self._attr_dict['delta'] = self.delta_op
            self._attr_dict['zero_point'] = self.zero_point_op
        else:
            raise TypeError

    def init_param(self, data: nd.NDArray):
        pass

    def _init_param_impl(self, input_data: nd.NDArray, channel_wise: bool = False):
        # Work-in-progress port of a PyTorch (BRECQ-style) scale initializer; the
        # 'mse' path still depends on helpers (self.quantize, lp_loss) that have not
        # been ported in this commit.
        delta, zero_point = None, None
        if channel_wise:
            x_clone = input_data.copy().detach()
            n_channels = x_clone.shape[0]
            if len(x_clone.shape) == 4:
                x_max = x_clone.abs().max(axis=-1).max(axis=-1).max(axis=-1)
            else:
                x_max = x_clone.abs().max(axis=-1)
            delta = x_max.copy()
            zero_point = x_max.copy()
            # Determine the scale and zero point channel by channel.
            for c in range(n_channels):
                # (was init_quantization_scale in the PyTorch reference)
                delta[c], zero_point[c] = self._init_param_impl(x_clone[c], channel_wise=False)
            if len(x_clone.shape) == 4:
                delta = delta.reshape(-1, 1, 1, 1)
                zero_point = zero_point.reshape(-1, 1, 1, 1)
            else:
                delta = delta.reshape(-1, 1)
                zero_point = zero_point.reshape(-1, 1)
        else:
            n_bits = int(self._attr_dict['n_bits'])
            n_levels = 2 ** n_bits
            if 'max' in self.scale_method:
                x_min = min(input_data.min().asscalar(), 0)
                x_max = max(input_data.max().asscalar(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (n_bits + 2) / 8
                    x_max = x_max * (n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                # 'sym' is an assumed config flag for symmetric quantization.
                if self._config.get('sym', False):
                    x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax

                delta = float(x_max - x_min) / (n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
                    delta = 1e-8

                zero_point = round(-x_min / delta)
                delta = nd.array([delta], dtype=input_data.dtype)

            elif self.scale_method == 'mse':
                # Always use symmetric quantization in mse mode.
                x_absmax = input_data.abs().max()
                x_min = input_data.min().asscalar()
                best_score = 1000
                for i in range(80):
                    new_max = x_absmax * (1.0 - (i * 0.01))
                    # self.quantize and lp_loss remain to be ported from the PyTorch
                    # reference (L_p-norm minimization as described in LAPQ,
                    # https://arxiv.org/abs/1911.07190).
                    x_q = self.quantize(input_data, new_max)
                    score = lp_loss(input_data, x_q, p=2.4, reduction='all')
                    if score < best_score:
                        best_score = score
                        delta = (2 * new_max) / (2 ** n_bits - 1)
                        zero_point = (new_max / delta).round() if x_min < 0 else 0
                        # Re-calculate the scale delta if the zero point is not 0.
            else:
                raise NotImplementedError
        return delta, zero_point

    # def init_param(self, data:nd.NDArray, scale_method:str='max'):
    #     assert scale_method in _scale_methods
    #     if self.channel_wise:
    #         data_abs = data.abs()
    #         data_max_per_channel =


class UniformAffineQuantizer(mx.operator.CustomOp):
    def __init__(self, n_bits):
        super(UniformAffineQuantizer, self).__init__()
        self.n_bits = n_bits
        self.n_levels = 2 ** self.n_bits

    def forward(self, is_train, req, in_data, out_data, aux):
        conv_weight, delta, zero_point = in_data[0], in_data[1], in_data[2]
        # TODO: the zero point is hard to implement under fully quantized conditions.
        x_int = _round_ste(conv_weight / delta) + zero_point
        x_quant = mx.nd.clip(x_int, 0, self.n_levels - 1)
        x_dequant = (x_quant - zero_point) * delta
        self.assign(out_data[0], req[0], x_dequant)

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Re-run the quantize-dequantize graph on detached copies and let autograd
        # produce the input gradients (similar to checkpointing in PyTorch).
        conv_weight, delta, zero_point = _new_detached_nd(*in_data[:3])
        conv_weight.attach_grad()
        delta.attach_grad()
        zero_point.attach_grad()
        with mx.autograd.record():
            x_int = _round_ste(conv_weight / delta) + zero_point
            x_quant = mx.nd.clip(x_int, 0, self.n_levels - 1)
            x_dequant = (x_quant - zero_point) * delta
        x_dequant.backward(_new_detached_nd(out_grad[0])[0])

        self.assign(in_grad[0], req[0], conv_weight.grad)
        self.assign(in_grad[1], req[1], delta.grad)
        self.assign(in_grad[2], req[2], zero_point.grad)


@mx.operator.register(QUANT_OP_PREFIX + "UniformAffineQuantizer")
class UniformAffineQuantizerProp(mx.operator.CustomOpProp):
    def __init__(self, n_bits):
        super(UniformAffineQuantizerProp, self).__init__()
        n_bits = n_bits if type(n_bits) is int else int(n_bits)

        assert 2 <= n_bits <= 32, 'bitwidth not supported'
        self.n_bits = n_bits

    def list_arguments(self):
        return ['data', 'delta', 'zero_point']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        assert len(in_shape) == 3
        return [*in_shape], [in_shape[0]], []

    def infer_type(self, in_type):
        return [*in_type], [in_type[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return UniformAffineQuantizer(n_bits=self.n_bits)
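A small numeric sketch of the registered "MRT_UniformAffineQuantizer" op: the forward pass computes round_ste(data / delta) + zero_point, clips to [0, 2**n_bits - 1], and dequantizes as (x_quant - zero_point) * delta. The values below are arbitrary illustrations:

import mxnet as mx
import mrt.yamrt.fquant   # importing the package registers the custom op

x = mx.sym.Variable('data')
delta = mx.sym.Variable('delta')
zero_point = mx.sym.Variable('zero_point')
q = mx.sym.Custom(data=x, delta=delta, zero_point=zero_point,
                  op_type='MRT_UniformAffineQuantizer', n_bits=8, name='uaq0')

exe = q.simple_bind(mx.cpu(), data=(4,), delta=(1,), zero_point=(1,))
exe.forward(is_train=False,
            data=mx.nd.array([-0.30, 0.00, 0.73, 1.20]),
            delta=mx.nd.array([0.01]),
            zero_point=mx.nd.array([128]))
print(exe.outputs[0])   # quantize-dequantize ("fake quantized") approximation of data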

python/mrt/yamrt/model/__init__.py

Whitespace-only changes.

python/mrt/yamrt/model/block/__init__.py

Whitespace-only changes.

0 commit comments