Context
Currently, the ZeroPointDomain is bound to the layout (see ao/torchao/quantization/quant_api.py, lines 607 to 615 at commit 2ba1a61).
Ideally, we should allow the data type and domain of zero points to be specified as arguments. There are two main benefits:
- Memory Efficiency: Integer zero points can significantly reduce the memory footprint with a suitably designed kernel. For example, with group_size=32, using int4 zero points instead of bf16 saves 0.375 bits per element.
- Community Compatibility: Integer zero points have a well-established ecosystem of recipes and kernels. Making this option available in TorchAO would allow us to leverage these resources more effectively.
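To make the memory figure above concrete, here is a quick back-of-the-envelope check (plain Python, no TorchAO dependency; the helper name is ours, for illustration): one zero point is stored per group, so its per-element cost is its bit width divided by the group size.

```python
def zero_point_overhead_bits(zp_bits: int, group_size: int) -> float:
    """Per-element overhead (in bits) of storing one zero point per group."""
    return zp_bits / group_size

# With group_size=32: bf16 costs 16/32 = 0.5 bits/element,
# int4 costs 4/32 = 0.125 bits/element.
saved = zero_point_overhead_bits(16, 32) - zero_point_overhead_bits(4, 32)
print(saved)  # 0.375 bits per element
```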
Proposals
Add optional arguments to let users specify the data type and domain of zero points:
diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 476cc229..5cd35648 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -568,8 +568,8 @@ def int8_dynamic_activation_int4_weight(
 def int4_weight_only(
-    group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
-):
+    group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False,
+    zero_point_dtype=torch.bfloat16, zero_point_domain=ZeroPointDomain.INT):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -587,6 +587,8 @@ def int4_weight_only(
         size is more fine grained, choices are [256, 128, 64, 32]
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
+        `zero_point_dtype`: the dtype of the zero point, default is torch.bfloat16
+        `zero_point_domain`: the domain of the zero point, default is ZeroPointDomain.INT
     """

     def apply_int4_weight_only_quant(weight):
@@ -603,8 +605,6 @@ def int4_weight_only(
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT
         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,
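For readers unfamiliar with the distinction the diff touches, a minimal sketch of the two zero-point domains (our own simplified scalar helpers, not TorchAO's actual kernels; the FLOAT convention shown is the tinygemm-style one, assuming mid_point = (quant_max + quant_min + 1) / 2):

```python
def dequant_int_domain(q: int, scale: float, zp: int) -> float:
    # INT domain: the zero point is an integer in [quant_min, quant_max],
    # subtracted from the quantized value before scaling.
    return (q - zp) * scale

def dequant_float_domain(q: int, scale: float, zp: float,
                         quant_min: int = 0, quant_max: int = 15) -> float:
    # FLOAT domain: the zero point is a float added after scaling around
    # the mid-point of the quantized range (tinygemm-style convention).
    mid_point = (quant_max + quant_min + 1) / 2
    return (q - mid_point) * scale + zp

# Both conventions can represent the same value, e.g. q=10, scale=0.5:
print(dequant_int_domain(10, 0.5, 8))      # 1.0
print(dequant_float_domain(10, 0.5, 0.0))  # 1.0
```

An int zero point packs into a few bits per group, while the float zero point must be stored at (at least) bf16 precision, which is where the memory saving comes from.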
Meanwhile, we will overload _weight_int4pack_mm to take zero points and scales as separate tensors.
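As a reference for what such an overload would compute, here is a pure-Python sketch over nested lists (the function name and layout are ours for illustration; the real kernel operates on packed int4 tensors with hardware-specific tiling):

```python
def int4_mm_ref(x, w_q, scales, zps, group_size):
    """Reference: y = x @ dequant(w_q).T with one (scale, zp) per group.

    x: [M][K] floats; w_q: [N][K] ints in [0, 15];
    scales, zps: [N][K // group_size] (per-output-channel, per-group).
    """
    M, K, N = len(x), len(x[0]), len(w_q)
    y = [[0.0] * N for _ in range(M)]
    for m in range(M):
        for n in range(N):
            acc = 0.0
            for k in range(K):
                g = k // group_size
                # INT-domain dequantization with separate scale/zp tensors.
                w = (w_q[n][k] - zps[n][g]) * scales[n][g]
                acc += x[m][k] * w
            y[m][n] = acc
    return y
```

For example, with x = [[1.0, 1.0]], w_q = [[9, 7]], scales = [[0.5]], zps = [[8]] and group_size=2, the weights dequantize to 0.5 and -0.5, so the output is [[0.0]].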
An example usage:
from torchao.quantization.quant_api import (
    quantize_,
    int8_dynamic_activation_int8_weight,
    int4_weight_only,
    int8_weight_only,
)
from torchao.quantization.quant_primitives import ZeroPointDomain

quantize_(m, int4_weight_only(zero_point_domain=ZeroPointDomain.INT))