
migrate prototype/awq to configs #1853


Merged: 16 commits, Mar 8, 2025
150 changes: 79 additions & 71 deletions torchao/prototype/awq/api.py
@@ -1,18 +1,28 @@
+import types
+from dataclasses import dataclass
+
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
     to_affine_quantized_intx,
 )
 from torchao.dtypes.uintx.uintx_layout import _DTYPE_TO_BIT_WIDTH, UintxLayout
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
 from torchao.quantization.granularity import PerGroup
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import (
     _DTYPE_TO_QVALUE_BOUNDS,
     MappingType,
     ZeroPointDomain,
 )
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 
 from .core import (
     AWQObservedLinear,
@@ -82,88 +92,86 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
+@dataclass
+class AWQUIntXConfig(AOBaseConfig):
     """
-    Replaces unquantized AWQObservedLinear instances with quantized linear instances.
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
-        constructor: the function which applies quantization to the AWQObservedLinear layer
+        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
+        group_size: Quantization granularity. Use -1 for channel wise quantization
+        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
     """
 
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias != None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
-
-def awq_uintx(
-    quant_dtype: torch.dtype = torch.uint4,
-    group_size: int = 64,
-    use_hqq: bool = False,
-):
-    """
-    Quantizes linear layers when passed into quantize_()
-
-    Args:
-        quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
-        group_size: Quantization granularity. Use -1 for channel wise quantization
-        weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
-    """
+    quant_dtype: torch.dtype = torch.uint4
+    group_size: int = 64
+    use_hqq: bool = False
+
+
+# for bc
+awq_uintx = AWQUIntXConfig
+
+
+@register_quantize_module_handler(AWQUIntXConfig)
+def _awq_uintx_transform(
+    module: torch.nn.Module,
+    config: AWQUIntXConfig,
+) -> torch.nn.Module:
+    quant_dtype = config.quant_dtype
+    group_size = config.group_size
+    use_hqq = config.use_hqq
+    observed_linear = module
 
     assert (
         quant_dtype in _DTYPE_TO_BIT_WIDTH or quant_dtype == torch.uint8
     ), "Invalid quant_dtype. Please use torch.uint1 .. torch.uint8"
 
-    def weight_quant_func(observed_linear):
-        equalization_scale = observed_linear.act_obs.calculate_qparams()
-        # AQT config
-        if quant_dtype == torch.uint4:
-            target_dtype = torch.int32
-            eps = 1e-6
-            preserve_zero = False
-            zero_point_dtype = torch.bfloat16
-            zero_point_domain = ZeroPointDomain.FLOAT
-            _layout = TensorCoreTiledLayout(inner_k_tiles=8)
-        else:
-            target_dtype = torch.uint8
-            eps = torch.finfo(torch.float32).eps
-            preserve_zero = True
-            zero_point_dtype = torch.int64
-            zero_point_domain = ZeroPointDomain.INT
-            _layout = UintxLayout(quant_dtype)
-
-        mapping_type = MappingType.ASYMMETRIC
-        block_size = (1, group_size)
-        quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
-        quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
-        qw = to_affine_quantized_intx(
-            observed_linear.weight * equalization_scale,
-            mapping_type,
-            block_size,
-            target_dtype,
-            quant_min,
-            quant_max,
-            eps,
-            zero_point_dtype=zero_point_dtype,
-            preserve_zero=preserve_zero,
-            zero_point_domain=zero_point_domain,
-            _layout=_layout,
-            use_hqq=use_hqq,
-        )
-
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, equalization_scale
-        )
-
-    return _observed_linear_subclass_inserter(weight_quant_func)
+    equalization_scale = observed_linear.act_obs.calculate_qparams()
+    # AQT config
+    if quant_dtype == torch.uint4:
+        target_dtype = torch.int32
+        eps = 1e-6
+        preserve_zero = False
+        zero_point_dtype = torch.bfloat16
+        zero_point_domain = ZeroPointDomain.FLOAT
+        _layout = TensorCoreTiledLayout(inner_k_tiles=8)
+    else:
+        target_dtype = torch.uint8
+        eps = torch.finfo(torch.float32).eps
+        preserve_zero = True
+        zero_point_dtype = torch.int64
+        zero_point_domain = ZeroPointDomain.INT
+        _layout = UintxLayout(quant_dtype)
+
+    mapping_type = MappingType.ASYMMETRIC
+    block_size = (1, group_size)
+    quant_min = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][0]
+    quant_max = _DTYPE_TO_QVALUE_BOUNDS[quant_dtype][1]
+    qw = to_affine_quantized_intx(
+        observed_linear.weight * equalization_scale,
+        mapping_type,
+        block_size,
+        target_dtype,
+        quant_min,
+        quant_max,
+        eps,
+        zero_point_dtype=zero_point_dtype,
+        preserve_zero=preserve_zero,
+        zero_point_domain=zero_point_domain,
+        _layout=_layout,
+        use_hqq=use_hqq,
+    )
+
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, equalization_scale)
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias != None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    linear.bias = observed_linear.bias
+    return linear
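
After this change, AWQ quantization is driven by a config object passed to `quantize_()` rather than by a function that returns a closure. A minimal end-to-end sketch, with hedges: the observer-insertion step and the calibration loop are illustrative, and `is_observed_linear` is a hypothetical filter built from the `AWQObservedLinear` class this file imports.

```python
import torch
from torchao.quantization import quantize_
from torchao.prototype.awq.api import AWQUIntXConfig
from torchao.prototype.awq.core import AWQObservedLinear

# hypothetical toy model and calibration data, for illustration only
model = torch.nn.Sequential(torch.nn.Linear(512, 512)).to(torch.bfloat16)
calib_batches = [torch.randn(8, 512, dtype=torch.bfloat16) for _ in range(4)]

# step 1 (elided): swap nn.Linear for AWQObservedLinear using this file's
# observer-insertion helper, which calls _replace_with_custom_fn_if_matches_filter

# step 2: run calibration data so the activation observers gather statistics
for batch in calib_batches:
    model(batch)

# step 3: pass a config *instance*; quantize_ dispatches on its type to
# _awq_uintx_transform via @register_quantize_module_handler
is_observed_linear = lambda m, fqn: isinstance(m, AWQObservedLinear)
quantize_(model, AWQUIntXConfig(quant_dtype=torch.uint4, group_size=64), is_observed_linear)
```

Because of the `awq_uintx = AWQUIntXConfig` alias, call sites written as `quantize_(model, awq_uintx(quant_dtype=torch.uint4, group_size=64), ...)` keep working: calling the alias now constructs the dataclass with the same keyword names instead of returning a closure.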
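
`register_quantize_module_handler` is what ties the dataclass to `_awq_uintx_transform`. A minimal sketch of that dispatch pattern, assuming a plain dict registry; the real implementation in `torchao.quantization.transform_module` differs in detail:

```python
from typing import Callable, Dict, Type

import torch

# registry mapping a config type to the module transform that handles it
_CONFIG_HANDLERS: Dict[Type, Callable] = {}

def register_quantize_module_handler(config_type: Type) -> Callable:
    """Decorator: file the transform function under its config type."""
    def decorator(handler: Callable) -> Callable:
        _CONFIG_HANDLERS[config_type] = handler
        return handler
    return decorator

def quantize_(model: torch.nn.Module, config, filter_fn=None) -> None:
    """Look up the handler for type(config) and apply it to matching children."""
    handler = _CONFIG_HANDLERS[type(config)]
    for name, child in model.named_children():
        if filter_fn is None or filter_fn(child, name):
            setattr(model, name, handler(child, config))
        else:
            quantize_(child, config, filter_fn)  # recurse into non-matching submodules
```

Dispatching on `type(config)` is also why the backward-compat alias is essentially free: any object constructed from `AWQUIntXConfig`, whether spelled `AWQUIntXConfig(...)` or `awq_uintx(...)`, reaches the same registered handler.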
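
Inside the transform, the weight is quantized as `weight * equalization_scale`, and the result is wrapped with `to_weight_tensor_with_linear_activation_scale_metadata`, which (as its name suggests) records the scale so the incoming activation can be divided by it in the linear op. In exact arithmetic that round trip changes nothing; what changes is the quantization error, since scaling shifts the grid toward channels the observer found salient. A tiny numeric check of the identity, pure PyTorch and illustrative only:

```python
import torch

torch.manual_seed(0)
x = torch.randn(4, 8)    # activations, (batch, in_features)
w = torch.randn(16, 8)   # linear weight, (out_features, in_features)
s = torch.rand(8) + 0.5  # per-input-channel equalization scale

# scale weights up and activations down by the same per-channel factor:
# (x / s) @ (w * s)^T == x @ w^T in exact arithmetic
y_ref = x @ w.t()
y_awq = (x / s) @ (w * s).t()
assert torch.allclose(y_ref, y_awq, atol=1e-5)
```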