44# This source code is licensed under the license found in the
55# LICENSE file in the root directory of this source tree.
66
7- from enum import Enum
7+ from enum import Enum , auto
88from typing import List , Optional , Tuple , Dict
99import torch
1010
1111from torchao .kernel .intmm import int_scaled_matmul
1212from torchao .kernel .intmm import safe_int_mm
13- from torchao .utils import TORCH_VERSION_AFTER_2_3
13+ from torchao .utils import (
14+ TORCH_VERSION_AFTER_2_3 ,
15+ TORCH_VERSION_AFTER_2_5 ,
16+ )
17+ from torchao .utils import _register_custom_op
1418
1519
1620__all__ = [
@@ -34,17 +38,17 @@ class MappingType(Enum):
3438 based on this mapping
3539 e.g. scale = (10.2 - (-3.5)) / (7 - (-8))
3640 """
37- SYMMETRIC = 0
38- ASYMMETRIC = 1
41+ SYMMETRIC = auto ()
42+ ASYMMETRIC = auto ()
3943
4044class ZeroPointDomain (Enum ):
4145 """Enum that indicates whether zero_point is in the integer domain or the floating point domain
4246
4347 integer domain: quantized_val = (float_val / scale) (integer) + zero_point (integer)
4448 float domain: quantized_val = (float_val - (zero_point (float) - scale * mid_point)) / scale
4549 """
46- INT = 0
47- FLOAT = 1
50+ INT = auto ()
51+ FLOAT = auto ()
4852
4953"""
5054Map from dtype to the bound value of integers
@@ -69,6 +73,10 @@ class ZeroPointDomain(Enum):
6973 })
7074
7175
76+ quant_lib = torch .library .Library ("quant" , "FRAGMENT" )
77+
78+ register_custom_op = _register_custom_op (quant_lib )
79+
7280# TODO: decide whether we want to allow custom quant_min/quant_max here
7381def _get_and_check_qmin_qmax (dtype , quant_min , quant_max ):
7482 """Get quant_min and quant_max args based on dtype and also
@@ -140,7 +148,7 @@ def quantize_affine(
140148 quant_min : Optional [int ] = None ,
141149 quant_max : Optional [int ] = None ,
142150 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
143- ):
151+ ) -> torch . Tensor :
144152 """
145153 Args:
146154 input (torch.Tensor): original float32, float16 or bfloat16 Tensor
@@ -174,6 +182,31 @@ def quantize_affine(
174182 Output:
175183 quantized tensor with requested dtype
176184 """
185+ return _quantize_affine (
186+ input ,
187+ block_size ,
188+ scale ,
189+ zero_point ,
190+ output_dtype ,
191+ quant_min ,
192+ quant_max ,
193+ zero_point_domain .name ,
194+ )
195+
196+
197+ @register_custom_op
198+ def _quantize_affine (
199+ input : torch .Tensor ,
200+ block_size : List [int ],
201+ scale : torch .Tensor ,
202+ zero_point : Optional [torch .Tensor ],
203+ output_dtype : torch .dtype ,
204+ quant_min : Optional [int ] = None ,
205+ quant_max : Optional [int ] = None ,
206+ zero_point_domain : str = "INT" ,
207+ ) -> torch .Tensor :
208+ """op definition whose signature is compatible with the custom op library
209+ """
177210 # TODO: validations
178211 # TODO: validate scale/zero_point dimensions are compatible with block_size
179212 assert input .dtype in [torch .float32 , torch .float16 , torch .bfloat16 ], f"Unsupported input dtype: { input .dtype } "
@@ -188,12 +221,12 @@ def quantize_affine(
188221 if zero_point is not None :
189222 zero_point = zero_point .view (shape_after_reduction )
190223
191- if zero_point_domain == ZeroPointDomain .INT :
224+ if zero_point_domain == ZeroPointDomain .INT . name :
192225 quant = torch .clamp (
193226 torch .round (input * (1.0 / scale )) + zero_point , quant_min , quant_max
194227 ).to (output_dtype )
195228 else :
196- assert zero_point_domain == ZeroPointDomain .FLOAT
229+ assert zero_point_domain == ZeroPointDomain .FLOAT . name
197230 mid_point = (quant_max + quant_min + 1 ) / 2
198231 min_val = zero_point - scale * mid_point
199232 quant = (
@@ -216,7 +249,7 @@ def dequantize_affine(
216249 zero_point_domain : ZeroPointDomain = ZeroPointDomain .INT ,
217250 * ,
218251 output_dtype : torch .dtype = torch .float32 ,
219- ):
252+ ) -> torch . Tensor :
220253 """
221254 Args:
222255 input (torch.Tensor): quantized tensor, should match the dtype `dtype` argument
@@ -238,6 +271,32 @@ def dequantize_affine(
238271 Output:
239272 dequantized Tensor, with requested dtype or fp32
240273 """
274+ return _dequantize_affine (
275+ input ,
276+ block_size ,
277+ scale ,
278+ zero_point ,
279+ input_dtype ,
280+ quant_min ,
281+ quant_max ,
282+ zero_point_domain .name ,
283+ output_dtype = output_dtype ,
284+ )
285+
286+ @register_custom_op
287+ def _dequantize_affine (
288+ input : torch .Tensor ,
289+ block_size : List [int ],
290+ scale : torch .Tensor ,
291+ zero_point : Optional [torch .Tensor ],
292+ input_dtype : torch .dtype ,
293+ quant_min : Optional [int ] = None ,
294+ quant_max : Optional [int ] = None ,
295+ zero_point_domain : str = "INT" ,
296+ output_dtype : torch .dtype = torch .float32 ,
297+ ) -> torch .Tensor :
298+ """op definition whose signature is compatible with the custom op library
299+ """
241300
242301 # TODO: validations
243302 # TODO: validate scale/zero_point dimensions are compatible with block_size
@@ -255,16 +314,16 @@ def dequantize_affine(
255314 if zero_point is not None :
256315 zero_point = zero_point .view (shape_after_reduction )
257316
258- if zero_point_domain == ZeroPointDomain .INT :
317+ if zero_point_domain == ZeroPointDomain .INT . name :
259318 # Force a copy to avoid input modification due
260319 # to upcoming in-place operations.
261320 dequant = input .to (torch .int32 , copy = True )
262321 if zero_point is not None :
263- dequant -= zero_point .to (torch .int32 )
322+ dequant = dequant - zero_point .to (torch .int32 )
264323 dequant = dequant .to (output_dtype )
265- dequant *= scale
324+ dequant = dequant * scale
266325 else :
267- assert zero_point_domain == ZeroPointDomain .FLOAT , f"Unexpected zero point domain: { zero_point_domain } "
326+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , f"Unexpected zero point domain: { zero_point_domain } "
268327 mid_point = (quant_max + quant_min + 1 ) / 2
269328 # This should allocate new memory and avoid input modification
270329 dequant = input - mid_point
@@ -320,8 +379,38 @@ def choose_qparams_affine(
320379 Output:
321380 Tuple of scales and zero_points Tensor with requested dtype
322381 """
382+ return _choose_qparams_affine (
383+ input ,
384+ mapping_type .name ,
385+ block_size ,
386+ target_dtype ,
387+ quant_min ,
388+ quant_max ,
389+ eps ,
390+ scale_dtype ,
391+ zero_point_dtype ,
392+ preserve_zero ,
393+ zero_point_domain .name
394+ )
395+
396+ @register_custom_op
397+ def _choose_qparams_affine (
398+ input : torch .Tensor ,
399+ mapping_type : str ,
400+ block_size : List [int ],
401+ target_dtype : torch .dtype ,
402+ quant_min : Optional [int ] = None ,
403+ quant_max : Optional [int ] = None ,
404+ eps : Optional [float ] = None ,
405+ scale_dtype : Optional [torch .dtype ] = None ,
406+ zero_point_dtype : Optional [torch .dtype ] = None ,
407+ preserve_zero : bool = True ,
408+ zero_point_domain : str = "INT" ,
409+ ) -> Tuple [torch .Tensor , torch .Tensor ]:
410+ """op definition whose signature is compatible with the custom op library
411+ """
323412 quant_min , quant_max = _get_and_check_qmin_qmax (target_dtype , quant_min , quant_max )
324- assert mapping_type in [MappingType .SYMMETRIC , MappingType .ASYMMETRIC ], f"Unsupported mapping type: { mapping_type } "
413+ assert mapping_type in [MappingType .SYMMETRIC . name , MappingType .ASYMMETRIC . name ], f"Unsupported mapping type: { mapping_type } "
325414
326415 if scale_dtype is None :
327416 scale_dtype = input .dtype
@@ -342,21 +431,22 @@ def choose_qparams_affine(
342431 min_val_neg = min_val
343432 max_val_pos = max_val
344433
345- if mapping_type == MappingType .SYMMETRIC :
434+ if mapping_type == MappingType .SYMMETRIC . name :
346435 max_val_pos = torch .max (- min_val_neg , max_val_pos )
347436 scale = max_val_pos / (float (quant_max - quant_min ) / 2 )
348437 if not preserve_zero :
349438 raise ValueError ("preserve_zero == False is not supported for symmetric quantization" )
350- if zero_point_domain != ZeroPointDomain .INT :
439+ if zero_point_domain != ZeroPointDomain .INT . name :
351440 raise ValueError ("zero_point_domain != ZeroPointDomain.INT is not supported for symmetric quantization" )
352441 zero_point = torch .full_like (scale , int ((quant_max + quant_min + 1 ) / 2 ))
353442 else :
443+ assert mapping_type == MappingType .ASYMMETRIC .name
354444 scale = (max_val_pos - min_val_neg ) / float (quant_max - quant_min )
355445 if preserve_zero :
356446 zero_point = quant_min - torch .round (min_val_neg / scale )
357447 zero_point = torch .clamp (zero_point , quant_min , quant_max )
358448 else :
359- assert zero_point_domain == ZeroPointDomain .FLOAT , "if not preserve_zero, zero_point must be in FLOAT domain"
449+ assert zero_point_domain == ZeroPointDomain .FLOAT . name , "if not preserve_zero, zero_point must be in FLOAT domain"
360450 mid_point = (quant_max + quant_min + 1 ) / 2
361451 zero_point = min_val_neg + scale * mid_point
362452
0 commit comments