from .common import *
import warnings

import mxnet as mx
import mxnet.ndarray as nd

def _round_ste(x):
    # Straight-through estimator: the forward pass produces round(x) while the
    # backward pass lets the gradient flow through as if the op were identity.
    return mx.nd.stop_gradient(mx.nd.round(x) - x) + x

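# A minimal sanity check of the straight-through estimator (illustrative only):
#
#   x = mx.nd.array([0.2, 1.7])
#   x.attach_grad()
#   with mx.autograd.record():
#       y = _round_ste(x)
#   y.backward(mx.nd.ones_like(y))
#   # y is [0., 2.] while x.grad is [1., 1.]: the gradient passes straight through.
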
def _new_detached_nd(*args):
    # Return detached copies of the given NDArrays so they can be re-used as
    # fresh autograd leaves inside backward().
    return [item.detach() for item in args]


class UniformAffineQuantizerWrapper(Wrapper):
    _scale_methods = ['max_scale', 'max', 'mse']

    def __init__(self, op, config):
        self.channel_wise = False
        # Fall back to the first known scale method when none is configured.
        self.scale_method = config.get('scale_method', self._scale_methods[0])
        super(UniformAffineQuantizerWrapper, self).__init__(op, config)
        self.delta_nd = None
        self.delta_op = None
        self.zero_point_nd = None
        self.zero_point_op = None

    def _build_attr_dict(self):
        assert self._config['q_op_name'] not in self._ori_op.attr('name')
        # Non-symbol attributes
        self._attr_dict['op_type'] = self._config['q_op_name']
        self._attr_dict['name'] = f"{self._attr_dict['op_type']}_{self._ori_op.attr('name')}"
        self._attr_dict['n_bits'] = self._config['n_bits']
        self.channel_wise = self._config['channel_wise']
        # Symbol attributes
        self._attr_dict['data'] = self._ori_op
        if not self.channel_wise:
            self.delta_op = mx.sym.Variable(f"{self._attr_dict['name']}_delta", shape=(1,))
            self.zero_point_op = mx.sym.Variable(f"{self._attr_dict['name']}_zero_point", shape=(1,))
        else:
            # Assume the first dim of the input data is the channel dim.
            assert len(self._ori_op.infer_shape()[1]) == 1
            ori_op_shape = self._ori_op.infer_shape()[1][0]
            channel_wise_shape = (ori_op_shape[0], *([1] * (len(ori_op_shape) - 1)))
            self.delta_op = mx.sym.Variable(
                f"{self._attr_dict['name']}_delta",
                shape=channel_wise_shape)
            self.zero_point_op = mx.sym.Variable(
                f"{self._attr_dict['name']}_zero_point",
                shape=channel_wise_shape)
        self._attr_dict['delta'] = self.delta_op
        self._attr_dict['zero_point'] = self.zero_point_op

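    # Note: with channel_wise=True, _build_attr_dict above gives, e.g., a conv
    # weight of shape (64, 3, 3, 3) per-channel delta/zero_point variables of
    # shape (64, 1, 1, 1), which broadcast against the weight inside the custom
    # op's forward pass (the 64x3x3x3 shape is only an illustrative assumption).
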
    def init_param(self, data: nd.NDArray):
        # Calibrate delta / zero_point from the given data and keep the results
        # on the wrapper so they can be bound to the corresponding symbols.
        delta, zero_point = self._init_param_impl(data, channel_wise=self.channel_wise)
        if not isinstance(delta, nd.NDArray):
            # The per-tensor path returns Python scalars; store them as (1,) NDArrays.
            delta = nd.array([delta], ctx=data.context)
            zero_point = nd.array([zero_point], ctx=data.context)
        self.delta_nd = delta
        self.zero_point_nd = zero_point

    def _init_param_impl(self, input_data: nd.NDArray, channel_wise: bool = False):
        delta, zero_point = None, None
        n_bits = self._attr_dict['n_bits']
        n_levels = 2 ** n_bits
        if channel_wise:
            x_clone = input_data.copy().detach()
            n_channels = x_clone.shape[0]
            # Per-channel absolute maximum over all non-channel axes.
            if len(x_clone.shape) == 4:
                x_max = x_clone.abs().max(axis=(1, 2, 3))
            else:
                x_max = x_clone.abs().max(axis=-1)
            delta = x_max.copy()
            zero_point = x_max.copy()
            # Determine the scale and zero point channel by channel.
            for c in range(n_channels):
                delta[c], zero_point[c] = self._init_param_impl(x_clone[c], channel_wise=False)
            # Reshape so delta / zero_point broadcast against the original data.
            if len(x_clone.shape) == 4:
                delta = delta.reshape((-1, 1, 1, 1))
                zero_point = zero_point.reshape((-1, 1, 1, 1))
            else:
                delta = delta.reshape((-1, 1))
                zero_point = zero_point.reshape((-1, 1))
        else:
            if 'max' in self.scale_method:
                x_min = min(input_data.min().asscalar(), 0)
                x_max = max(input_data.max().asscalar(), 0)
                if 'scale' in self.scale_method:
                    x_min = x_min * (n_bits + 2) / 8
                    x_max = x_max * (n_bits + 2) / 8

                x_absmax = max(abs(x_min), x_max)
                if self._config.get('sym', False):
                    # Symmetric quantization if requested in the config (assumed key 'sym').
                    x_min, x_max = -x_absmax if x_min < 0 else 0, x_absmax

                delta = float(x_max - x_min) / (n_levels - 1)
                if delta < 1e-8:
                    warnings.warn('Quantization range close to zero: [{}, {}]'.format(x_min, x_max))
                    delta = 1e-8

                # The per-tensor path returns Python scalars.
                zero_point = round(-x_min / delta)

            elif self.scale_method == 'mse':
                # We always use symmetric quantization in mse mode.
                x_absmax = input_data.abs().max().asscalar()
                x_min = input_data.min().asscalar()
                best_score = 1e10
                for i in range(80):
                    new_max = x_absmax * (1.0 - (i * 0.01))
                    # Fake-quantize with the candidate range [-new_max, new_max]
                    # (mirrors the forward pass of UniformAffineQuantizer below).
                    cand_delta = (2 * new_max) / (n_levels - 1)
                    cand_zero_point = round(new_max / cand_delta) if x_min < 0 else 0
                    x_q = (nd.clip(nd.round(input_data / cand_delta) + cand_zero_point,
                                   0, n_levels - 1) - cand_zero_point) * cand_delta
                    # L_p norm minimization as described in LAPQ (p=2.4),
                    # https://arxiv.org/abs/1911.07190
                    score = ((input_data - x_q).abs() ** 2.4).mean().asscalar()
                    if score < best_score:
                        best_score = score
                        delta = cand_delta
                        zero_point = cand_zero_point
            else:
                raise NotImplementedError
        return delta, zero_point


class UniformAffineQuantizer(mx.operator.CustomOp):
    def __init__(self, n_bits):
        super(UniformAffineQuantizer, self).__init__()
        self.n_bits = n_bits
        self.n_levels = 2 ** self.n_bits

    def forward(self, is_train, req, in_data, out_data, aux):
        conv_weight, delta, zero_point = in_data[0], in_data[1], in_data[2]
        # TODO: a non-zero zero_point is hard to realize in fully quantized inference.
        x_int = _round_ste(conv_weight / delta) + zero_point
        x_quant = mx.nd.clip(x_int, 0, self.n_levels - 1)
        x_dequant = (x_quant - zero_point) * delta
        self.assign(out_data[0], req[0], x_dequant)

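    # Worked example (illustrative numbers): with n_bits=2 (n_levels=4),
    # delta=0.5 and zero_point=2, a weight value of 0.7 maps to
    # x_int = round(0.7 / 0.5) + 2 = 3, stays inside the clip range [0, 3],
    # and dequantizes to (3 - 2) * 0.5 = 0.5.
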
    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        # Recompute the forward pass on detached copies under autograd and let
        # MXNet derive the gradients (similar to gradient checkpointing in PyTorch).
        conv_weight, delta, zero_point = _new_detached_nd(*in_data[:3])
        conv_weight.attach_grad()
        delta.attach_grad()
        zero_point.attach_grad()
        with mx.autograd.record():
            x_int = _round_ste(conv_weight / delta) + zero_point
            x_quant = mx.nd.clip(x_int, 0, self.n_levels - 1)
            x_dequant = (x_quant - zero_point) * delta
        x_dequant.backward(_new_detached_nd(out_grad[0])[0])

        self.assign(in_grad[0], req[0], conv_weight.grad)
        self.assign(in_grad[1], req[1], delta.grad)
        self.assign(in_grad[2], req[2], zero_point.grad)


@mx.operator.register(QUANT_OP_PREFIX + "UniformAffineQuantizer")
class UniformAffineQuantizerProp(mx.operator.CustomOpProp):
    def __init__(self, n_bits):
        super(UniformAffineQuantizerProp, self).__init__()
        # Custom op keyword arguments arrive as strings, so cast defensively.
        n_bits = int(n_bits)

        assert 2 <= n_bits <= 32, 'bitwidth not supported'
        self.n_bits = n_bits

    def list_arguments(self):
        return ['data', 'delta', 'zero_point']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        # Expect data, delta, zero_point; the output shape matches the data shape.
        assert len(in_shape) == 3
        return [*in_shape], [in_shape[0]], []

    def infer_type(self, in_type):
        return [*in_type], [in_type[0]], []

    def create_operator(self, ctx, shapes, dtypes):
        return UniformAffineQuantizer(n_bits=self.n_bits)

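# A minimal usage sketch (hypothetical symbol names; assumes QUANT_OP_PREFIX is
# provided by .common as used above): bind the registered op on a weight symbol.
#
#   weight = mx.sym.Variable('conv0_weight')
#   delta = mx.sym.Variable('conv0_weight_delta', shape=(1,))
#   zero_point = mx.sym.Variable('conv0_weight_zero_point', shape=(1,))
#   q_weight = mx.sym.Custom(weight, delta, zero_point,
#                            op_type=QUANT_OP_PREFIX + 'UniformAffineQuantizer',
#                            n_bits=8)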