@@ -142,7 +142,10 @@ def disable_fake_quant(self):
142
142
def forward (self , x : torch .Tensor ) -> torch .Tensor :
143
143
# activations: int8 dynamic asymmetric quant
144
144
if self ._fake_quant_enabled :
145
- (act_scales , act_zp ) = _choose_qparams_per_token_asymmetric (
145
+ (
146
+ act_scales ,
147
+ act_zp
148
+ ) = torch .ops .quantized_decomposed ._choose_qparams_per_token_asymmetric_impl (
146
149
x , torch .int8 , # dtype not used
147
150
)
148
151
(act_qmin , act_qmax ) = self ._get_qmin_qmax (8 )
@@ -269,49 +272,3 @@ def fake_quantize_per_token(
269
272
return _GenericFakeQuantize .apply (
270
273
input , scales , zero_points , quant_min , quant_max ,
271
274
)
272
-
273
- # TODO: This is copied from torch/ao/quantization/fx/_decomposed.py.
274
- # The version in pytorch does not have backward support yet so we add
275
- # it here for now until https://github.com/pytorch/pytorch/pull/123452
276
- # is landed.
277
- def _choose_qparams_per_token_asymmetric (
278
- input : torch .Tensor ,
279
- dtype : torch .dtype ,
280
- ) -> Tuple [torch .Tensor , torch .Tensor ]:
281
- """Choose quantization parameters for per token quantization. This means for a N dimension Tensor
282
- (M1, M2, ...Mn, N), we calculate scales/zero_points for each N elements and quantize
283
- every N elements with the same quantization parameter. The dimension for scales/zero_points
284
- will be (M1 * M2 ... * Mn)
285
-
286
- Args:
287
- input (torch.Tensor): original float32/float16 Tensor
288
- dtype (torch.dtype): dtype (e.g. torch.uint8) for input Tensor
289
-
290
- Returns:
291
- scales and zero_points, both float32 Tensors
292
- """
293
- # Based on https://github.com/google/XNNPACK/blob/df156f0cf3db5a4576cc711123eeb54915f82ffc/src/xnnpack/quantization.h#L18
294
- qmin , qmax = - 128 , 127
295
- min_val = torch .amin (input , dim = - 1 , keepdim = True )
296
- max_val = torch .amax (input , dim = - 1 , keepdim = True )
297
- min_val_neg = torch .min (min_val , torch .zeros_like (min_val ))
298
- max_val_pos = torch .max (max_val , torch .zeros_like (max_val ))
299
- eps = torch .finfo (torch .float32 ).eps # use xnnpack eps?
300
-
301
- # scale
302
- scale = (max_val_pos - min_val_neg ) / float (qmax - qmin )
303
- scale = scale .clamp (min = eps )
304
-
305
- # zero point
306
- descaled_min = min_val_neg / scale
307
- descaled_max = max_val_pos / scale
308
- zero_point_from_min_error = qmin + descaled_min
309
- zero_point_from_max_error = qmax + descaled_max
310
- zero_point = torch .where (
311
- zero_point_from_min_error + zero_point_from_max_error > 0 ,
312
- qmin - descaled_min ,
313
- qmax - descaled_max ,
314
- )
315
- zero_point = torch .clamp (zero_point , qmin , qmax ).round ()
316
-
317
- return scale .to (torch .float32 ), zero_point .to (torch .float32 )
0 commit comments