
ZeroPointDomain as an argument #1264

Open
@airMeng

Description


Context

Currently, the ZeroPointDomain is bound to the layout:

zero_point_domain = ZeroPointDomain.FLOAT
# Sparse Marlin only supports symmetric quantization.
# NOTE: If we start having lots of layouts that require different configurations,
# we should consider moving this logic somewhere else.
if isinstance(layout, MarlinSparseLayout):
    mapping_type = MappingType.SYMMETRIC
    preserve_zero = True
    zero_point_domain = ZeroPointDomain.INT
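For reference, the two domains differ in where the zero point is applied during dequantization. A minimal sketch in plain Python (not torchao's actual kernels; the function names are hypothetical, and the FLOAT-domain formula assumes the tinygemm convention of centering around the uint4 mid point):

```python
def dequant_int_domain(q, scale, zero_point):
    # INT domain: the zero point lives in the quantized integer domain
    # and is subtracted before scaling: w = (q - zp) * s
    return (q - zero_point) * scale

def dequant_float_domain(q, scale, zero_point, mid_point=8):
    # FLOAT domain (tinygemm-style, assumed): the zero point lives in the
    # floating point domain and is added after scaling around the uint4
    # mid point (quant_max + quant_min + 1) / 2 = 8: w = (q - 8) * s + zp
    return (q - mid_point) * scale + zero_point
```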

Ideally, we should allow the data types of zero points to be specified as arguments. There are two main benefits:

  1. Memory Efficiency: Integer zero points can significantly reduce the memory footprint with a suitably designed kernel. For example, with group_size=32, using int4 zero points instead of bf16 saves 0.375 bits per element.
  2. Community Compatibility: Integer zero points have a well-established ecosystem of recipes and kernels. Making this option available in TorchAO would allow us to leverage these resources more effectively.
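The 0.375-bit figure follows from amortizing one zero point over each group of 32 elements; a quick check in pure arithmetic:

```python
group_size = 32
bf16_bits_per_elem = 16 / group_size  # one bf16 zero point per group: 0.5 bits/element
int4_bits_per_elem = 4 / group_size   # one int4 zero point per group: 0.125 bits/element
savings = bf16_bits_per_elem - int4_bits_per_elem
print(savings)  # 0.375 bits saved per element
```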

Proposals

Add an optional argument to let users specify the data types of zero points:

diff --git a/torchao/quantization/quant_api.py b/torchao/quantization/quant_api.py
index 476cc229..5cd35648 100644
--- a/torchao/quantization/quant_api.py
+++ b/torchao/quantization/quant_api.py
@@ -568,8 +568,8 @@ def int8_dynamic_activation_int4_weight(


 def int4_weight_only(
-    group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False
-):
+    group_size=128, layout=TensorCoreTiledLayout(inner_k_tiles=8), use_hqq=False,
+    zero_point_dtype=torch.bfloat16, zero_point_domain=ZeroPointDomain.INT):
     """
     Applies uint4 weight-only asymmetric per-group quantization to linear layers, using
     "tensor_core_tiled" layout for speedup with tinygemm kernel
@@ -587,6 +587,8 @@ def int4_weight_only(
          size is more fine grained, choices are [256, 128, 64, 32]
         `layout`: layout type for quantized tensor, default is `TensorCoreTiledLayout(inner_k_tiles=8)`
         `use_hqq`: whether to use hqq or default quantization mode, default is False
+        `zero_point_dtype`: the dtype of zero point, default is torch.bfloat16
+        `zero_point_domain`: the domain of zero point, default is ZeroPointDomain.INT
     """

     def apply_int4_weight_only_quant(weight):
@@ -603,8 +605,6 @@ def int4_weight_only(
         quant_max = 15
         eps = 1e-6
         preserve_zero = False
-        zero_point_dtype = torch.bfloat16
-        zero_point_domain = ZeroPointDomain.FLOAT

         # Sparse Marlin only supports symmetric quantization.
         # NOTE: If we start having lots of layouts that require different configurations,

Meanwhile, we will overload _weight_int4pack_mm to take zero points and scales as separate tensors.
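Today the kernel consumes a single combined scales-and-zeros buffer; the overload would accept the two separately, which is what lets zero points use a narrower dtype than scales. A minimal illustration with plain Python lists standing in for tensors (the interleaved pairing is an assumption based on the tinygemm convention; these helpers are hypothetical, not torchao APIs):

```python
def pack_scales_and_zeros(scales, zeros):
    # Current convention (assumed): each group's (scale, zero_point) pair
    # is stored together, forcing both to share one dtype.
    return [[s, z] for s, z in zip(scales, zeros)]

def unpack_scales_and_zeros(scales_and_zeros):
    # Proposed direction: keep scales and zero points as separate tensors,
    # so zero points can be stored as e.g. int4 while scales stay bf16.
    scales = [pair[0] for pair in scales_and_zeros]
    zeros = [pair[1] for pair in scales_and_zeros]
    return scales, zeros
```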

An example usage (matching the proposed arguments):

import torch
from torchao.quantization.quant_api import (
    quantize_,
    int4_weight_only,
)
from torchao.quantization.quant_primitives import ZeroPointDomain

quantize_(m, int4_weight_only(zero_point_dtype=torch.int8, zero_point_domain=ZeroPointDomain.INT))
