@@ -23,7 +23,8 @@
     "choose_qparams_affine",
     "quantize_affine",
     "dequantize_affine",
+    "fake_quantize_affine",
 ]
 
 
 class MappingType(Enum):
@@ -203,14 +204,34 @@ def _quantize_affine(
     output_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: str = "INT",
+    zero_point_domain: str = ZeroPointDomain.INT.name,
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
     """
+    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
+    return _quantize_affine_no_dtype_cast(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+    ).to(output_dtype)
+
+
+def _quantize_affine_no_dtype_cast(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: str = ZeroPointDomain.INT.name,
+) -> torch.Tensor:
     # TODO: validations
     # TODO: validate scale/zero_point dimensions are compatible with block_size
     assert input.dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported input dtype: {input.dtype}"
-    quant_min, quant_max = _get_and_check_qmin_qmax(output_dtype, quant_min, quant_max)
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
@@ -224,7 +245,7 @@ def _quantize_affine(
     if zero_point_domain == ZeroPointDomain.INT.name:
         quant = torch.clamp(
             torch.round(input * (1.0 / scale)) + zero_point, quant_min, quant_max
-        ).to(output_dtype)
+        )
     else:
         assert zero_point_domain == ZeroPointDomain.FLOAT.name
         mid_point = (quant_max + quant_min + 1) / 2
@@ -233,11 +254,12 @@ def _quantize_affine(
             torch.clamp(
                 torch.round((input - min_val) / scale),
                 quant_min, quant_max)
-        ).to(output_dtype)
+        )
     quant = quant.view(original_shape)
 
     return quant
 
+
 def dequantize_affine(
     input: torch.Tensor,
     block_size: Tuple[int, ...],
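
Note on the refactor above: the rounding/clamping math now lives in `_quantize_affine_no_dtype_cast`, and `_quantize_affine` becomes a thin wrapper that applies the final `.to(output_dtype)` cast. A minimal sketch of the INT-domain math the helper implements, using made-up per-tensor values (not from this PR):

```python
import torch

# Simulate _quantize_affine_no_dtype_cast's ZeroPointDomain.INT branch
# for a per-tensor case: round to the integer grid, shift by zero_point,
# clamp to the quantized range -- but keep the float dtype (no cast).
x = torch.tensor([0.05, -0.12, 0.31], dtype=torch.float32)
scale, zero_point = 0.002, 10          # hypothetical qparams
quant_min, quant_max = -128, 127       # int8 range

q = torch.clamp(torch.round(x * (1.0 / scale)) + zero_point, quant_min, quant_max)
print(q)        # tensor([ 35., -50., 127.])  -- 0.31 saturates at quant_max
print(q.dtype)  # torch.float32, unlike quantize_affine's int8 output
```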
@@ -283,6 +305,7 @@ def dequantize_affine(
         output_dtype=output_dtype,
     )
 
+
 @register_custom_op
 def _dequantize_affine(
     input: torch.Tensor,
@@ -292,7 +315,7 @@ def _dequantize_affine(
     input_dtype: torch.dtype,
     quant_min: Optional[int] = None,
     quant_max: Optional[int] = None,
-    zero_point_domain: str = "INT",
+    zero_point_domain: str = ZeroPointDomain.INT.name,
     output_dtype: torch.dtype = torch.float32,
 ) -> torch.Tensor:
     """op definition that has compatible signatures with custom op library
@@ -303,7 +326,28 @@ def _dequantize_affine(
     assert input.dtype == input_dtype, f"Expected: {input_dtype}, got: {input.dtype}"
     assert output_dtype in [torch.float32, torch.float16, torch.bfloat16], f"Unsupported output dtype: {output_dtype}"
     quant_min, quant_max = _get_and_check_qmin_qmax(input_dtype, quant_min, quant_max)
+    return _dequantize_affine_no_dtype_check(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain,
+        output_dtype,
+    )
+
 
+def _dequantize_affine_no_dtype_check(
+    input: torch.Tensor,
+    block_size: List[int],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_min: int,
+    quant_max: int,
+    zero_point_domain: str = ZeroPointDomain.INT.name,
+    output_dtype: torch.dtype = torch.float32,
+) -> torch.Tensor:
     shape_for_reduction, reduction_dims = _get_reduction_params(block_size, input.size())
     original_shape = input.shape
     input = input.view(shape_for_reduction)
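
The dequantize side mirrors this: `_dequantize_affine_no_dtype_check` carries the math, while `_dequantize_affine` keeps the dtype assertions, so the helper can accept the float-typed output of `_quantize_affine_no_dtype_cast`. A per-tensor sketch of the INT-domain dequantization, continuing the made-up values above:

```python
import torch

# Dequantize the float-typed "integer" values from the previous sketch:
# dq = (q - zero_point) * scale, with no assertion on q's dtype.
q = torch.tensor([35.0, -50.0, 127.0])
scale, zero_point = 0.002, 10

dq = (q - zero_point) * scale
print(dq)  # tensor([ 0.0500, -0.1200,  0.2340]) -- the clamped 0.31 round-trips to 0.234
```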
@@ -335,6 +379,62 @@ def _dequantize_affine(
 
     return dequant.view(original_shape).to(output_dtype)
 
+
+def fake_quantize_affine(
+    input: torch.Tensor,
+    block_size: Tuple[int, ...],
+    scale: torch.Tensor,
+    zero_point: Optional[torch.Tensor],
+    quant_dtype: torch.dtype,
+    quant_min: Optional[int] = None,
+    quant_max: Optional[int] = None,
+    zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+) -> torch.Tensor:
+    """
+    General fake quantize op for quantization-aware training (QAT).
+    This is equivalent to calling `quantize_affine` + `dequantize_affine`
+    but without the dtype casts.
+
+    Args:
+      input (torch.Tensor): original float32, float16 or bfloat16 Tensor
+      block_size (Tuple[int, ...]): granularity of quantization; this means the size of the block of tensor elements that share the same qparam,
+        e.g. when the block size is the same as the input tensor shape, we are using per-tensor quantization
+      scale (torch.Tensor): quantization parameter for affine quantization
+      zero_point (Optional[torch.Tensor]): quantization parameter for affine quantization
+      quant_dtype (torch.dtype): desired quantized dtype for determining and validating quant_min and quant_max values
+      quant_min (Optional[int]): minimum quantized value for output Tensor; if not specified, it will be derived from quant_dtype
+      quant_max (Optional[int]): maximum quantized value for output Tensor; if not specified, it will be derived from quant_dtype
+      zero_point_domain (ZeroPointDomain): the domain that zero_point is in, should be either integer or float
+        if zero_point is in integer domain, zero point is added to the quantized integer value during
+        quantization
+        if zero_point is in floating point domain, zero point is subtracted from the floating point (unquantized)
+        value during quantization
+        default is ZeroPointDomain.INT
+    """
+    input_dtype = input.dtype
+    quant_min, quant_max = _get_and_check_qmin_qmax(quant_dtype, quant_min, quant_max)
+    q = _quantize_affine_no_dtype_cast(
+        input,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+    )
+    dq = _dequantize_affine_no_dtype_check(
+        q,
+        block_size,
+        scale,
+        zero_point,
+        quant_min,
+        quant_max,
+        zero_point_domain.name,
+        output_dtype=input_dtype,
+    )
+    return dq
+
+
 def choose_qparams_affine(
     input: torch.Tensor,
     mapping_type: MappingType,
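
A hedged usage sketch of the new op for reviewers. It assumes `fake_quantize_affine` is importable from `torchao.quantization.quant_primitives` (the module path is not shown in this diff) and uses made-up qparams:

```python
import torch
from torchao.quantization.quant_primitives import fake_quantize_affine

x = torch.randn(4, 8)
block_size = (4, 8)            # block covers the whole tensor -> per-tensor quantization
scale = torch.tensor([0.02])   # hypothetical qparams
zero_point = torch.tensor([0])

# Simulated int8 quantize -> dequantize: values are snapped to the int8
# grid, but the tensor stays float32 throughout -- no int8 cast, which is
# what QAT needs (gradient handling, e.g. a straight-through estimator,
# is layered on top by the QAT flow).
fq = fake_quantize_affine(x, block_size, scale, zero_point, torch.int8)
assert fq.dtype == x.dtype
```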