@@ -23,11 +23,11 @@
MappingType,
quantize_,
)
-from torchao.quantization.granularity import PerGroup
from torchao.quantization.granularity import PerAxis, PerGroup
from torchao.quantization.qat import IntxFakeQuantizeConfig, QATConfig
from torchao.quantization.quantize_.common import PackingFormat
from torchao.quantization.utils import compute_error
-from torchao.utils import torch_version_at_least
from torchao.utils import torch_version_at_least, unwrap_tensor_subclass


@unittest.skipIf(not torch_version_at_least("2.7.0"), "Need pytorch 2.7+")
@@ -156,7 +156,7 @@ def test_export_int8_dyn_act_intx_weight_config(self):
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
-weight_granularity=PerGroup(64),
weight_granularity=PerAxis(0),
weight_mapping_type=MappingType.SYMMETRIC,
packing_format=PackingFormat.UNPACKED_TO_INT8,
version=2,
@@ -169,17 +169,52 @@
exported_results = exported.module()(activations)
self.assertTrue(torch.allclose(eager_results, exported_results))

-expected_lines = [
-"torch.ops.torchao.choose_qparams_affine.default",
-"torch.ops.torchao.quantize_affine.default",
-"torch.ops.torchao.dequantize_affine.default",
-"torch.ops.torchao.dequantize_affine.default",
-"torch.ops.aten.linear.default",
expected_counts = {
"torch.ops.torchao.choose_qparams_affine.default": 1,
"torch.ops.torchao.quantize_affine.default": 1,
"torch.ops.torchao.dequantize_affine.default": 2,
"torch.ops.aten.linear.default": 1,
"torch.ops.aten.reshape.default": 0,
}
for line, count in expected_counts.items():
FileCheck().check_count(line, count, exactly=True).run(
exported.graph_module.code
)

def test_export_int8_dyn_act_intx_weight_config_with_unwrap(self):
layers = [
torch.nn.Linear(512, 256, bias=False),
]
-for line in expected_lines:
-count = 1
-if line == "torch.ops.torchao.dequantize_affine.default":
-count = 2
model = torch.nn.Sequential(*layers)
activations = torch.randn(1, 512, dtype=torch.float32)

quantize_(
model,
Int8DynamicActivationIntxWeightConfig(
weight_dtype=torch.int4,
weight_granularity=PerGroup(64),
weight_mapping_type=MappingType.SYMMETRIC,
packing_format=PackingFormat.UNPACKED_TO_INT8,
version=2,
),
)
eager_results = model(activations)

unwrap_tensor_subclass(model)

exported = torch.export.export(model, (activations,))

exported_results = exported.module()(activations)
self.assertTrue(torch.allclose(eager_results, exported_results))

expected_counts = {
"torch.ops.torchao.choose_qparams_affine.default": 1,
"torch.ops.torchao.quantize_affine.default": 1,
"torch.ops.torchao.dequantize_affine.default": 2,
"torch.ops.aten.linear.default": 1,
"torch.ops.aten.reshape.default": 0,
}
for line, count in expected_counts.items():
FileCheck().check_count(line, count, exactly=True).run(
exported.graph_module.code
)
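
Note on the test change above: the list of expected op names is replaced by a dict of exact counts, which also makes it possible to assert that an op is absent (a count of 0 for torch.ops.aten.reshape.default). dequantize_affine is expected twice, consistent with one dequant for the dynamically quantized activation and one for the weight. A minimal sketch of the same FileCheck pattern; the helper name assert_op_counts is hypothetical:

from torch.testing import FileCheck

def assert_op_counts(graph_code: str, expected_counts: dict) -> None:
    # check_count(..., exactly=True) fails if the op name appears more or fewer
    # times than expected; a count of 0 asserts the op never appears at all.
    for op_name, count in expected_counts.items():
        FileCheck().check_count(op_name, count, exactly=True).run(graph_code)

# Usage mirrors the tests above, e.g.:
# assert_op_counts(exported.graph_module.code, {"torch.ops.aten.reshape.default": 0})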
torchao/quantization/quant_api.py (2 changes: 1 addition & 1 deletion)
@@ -833,7 +833,7 @@ def _int8_dynamic_activation_intx_weight_quantize_tensor(weight, bias, config):
block_size,
weight_dtype,
mapping_type=weight_mapping_type,
-apply_int8_act_asym_per_token_quant=True,
activation_quantization="int8_asym_per_token",
)
if weight_scale_dtype is not None and weight_scale_dtype != weight.dtype:
_adjust_scale_dtype_in_intx_unpacked_tensor(
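
At the config level, the boolean apply_int8_act_asym_per_token_quant flag is replaced by a named activation_quantization scheme, passed in the change above as the string value of the new enum. A hedged sketch of the corresponding tensor-level call; the weight shape and group size are illustrative, and the keyword names follow the ones visible elsewhere in this diff:

import torch
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
    IntxUnpackedToInt8Tensor,
)

weight = torch.randn(256, 512)
qweight = IntxUnpackedToInt8Tensor.from_hp(
    hp_tensor=weight,
    block_size=(1, 64),  # groupwise along the input dimension
    target_dtype=torch.int4,
    # Before: apply_int8_act_asym_per_token_quant=True
    # After: a named scheme; because the enum subclasses str, the plain string
    # value compares equal to the enum member downstream.
    activation_quantization="int8_asym_per_token",
)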
@@ -15,6 +15,7 @@
from torchao.quantization.quant_primitives import _DTYPE_TO_BIT_WIDTH
from torchao.quantization.quantize_.workflows.intx.intx_unpacked_to_int8_tensor import (
IntxUnpackedToInt8Tensor,
IntxUnpackedToInt8TensorActivationQuantization,
)
from torchao.utils import (
TorchAOBaseTensor,
@@ -144,7 +145,10 @@ def from_intx_unpacked_to_int8_tensor(
compute_target = ComputeTarget[compute_target.upper()]

# Extract data from IntxUnpackedToInt8Tensor
-assert tensor.apply_int8_act_asym_per_token_quant
assert (
tensor.activation_quantization
== IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
)
qdata, scale, zero_point = tensor.qdata, tensor.scale, tensor.zero_point
bit_width = _DTYPE_TO_BIT_WIDTH[tensor.target_dtype]
dtype = tensor.dtype
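
The converter now asserts on the enum member instead of the old boolean attribute. Because the enum subclasses str (see its definition later in this diff), the assert also holds when activation_quantization was populated from the plain string used in quant_api.py. A standalone illustration, re-declaring the enum locally for the sketch:

import enum

class IntxUnpackedToInt8TensorActivationQuantization(str, enum.Enum):
    INT8_ASYM_PER_TOKEN = "int8_asym_per_token"

# A str-mixin enum member compares equal to its underlying string value, so a
# tensor carrying either representation satisfies the equality check above.
assert (
    IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
    == "int8_asym_per_token"
)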
@@ -5,7 +5,8 @@
# LICENSE file in the root directory of this source tree.


-from typing import List, Tuple
import enum
from typing import List, Optional, Tuple

import torch
from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -32,6 +33,14 @@
_FLOAT_TYPES: List[torch.dtype] = [torch.float16, torch.bfloat16, torch.float32]


class IntxUnpackedToInt8TensorActivationQuantization(str, enum.Enum):
"""
This applies int8 asymmetric activation quantization per token.
"""

INT8_ASYM_PER_TOKEN = "int8_asym_per_token"


class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
"""
intx quantization with unpacked format. Subbyte quantized data is represented as int8.
@@ -55,15 +64,15 @@ class IntxUnpackedToInt8Tensor(TorchAOBaseTensor):
target_dtype: this determines the quant_min/quant_max of the qdata (can be torch.int1, ..., torch.int8)
block_size: the block size for quantization, representing the granularity, for example groupwise quantization will have block_size (1, group_size)
dtype: the dtype of the dequantized Tensor
-apply_int8_act_asym_per_token_quant: bool, whether to apply activation quantization to the dequantized Tensor during linear. Use False for weight-only quantization
activation_quantization: Optional[IntxUnpackedToInt8TensorActivationQuantization] = None, kind of activation quantization to apply. Default is None, which means weight-only quantization
"""

tensor_data_names = ["qdata", "scale", "zero_point"]
tensor_attribute_names = [
"target_dtype",
"block_size",
"dtype",
"apply_int8_act_asym_per_token_quant",
"activation_quantization",
]

def __new__(
@@ -74,7 +83,7 @@
target_dtype,
block_size,
dtype,
-apply_int8_act_asym_per_token_quant,
activation_quantization,
):
kwargs = {}
kwargs["device"] = qdata.device
@@ -91,7 +100,7 @@ def __init__(
target_dtype,
block_size,
dtype,
-apply_int8_act_asym_per_token_quant,
activation_quantization,
):
super().__init__()
assert qdata.dtype == torch.int8, (
@@ -113,8 +122,14 @@ def __init__(
for i in range(len(block_size)):
assert qdata.shape[i] % block_size[i] == 0
n_blocks.append(qdata.shape[i] // block_size[i])
-scale = scale.reshape(*n_blocks)
-zero_point = zero_point.reshape(*n_blocks)

# Assert shapes
assert scale.shape == tuple(n_blocks), (
f"Expected scale to have shape {n_blocks} (inferred from block_size={block_size}), but got {scale.shape}"
)
assert zero_point.shape == tuple(n_blocks), (
f"Expected zero_point to have shape {n_blocks} (inferred from block_size={block_size}), but got {zero_point.shape}"
)

assert dtype in _FLOAT_TYPES, (
f"dtype must be one of {_FLOAT_TYPES}, but got {dtype}"
@@ -126,10 +141,10 @@

self.target_dtype = target_dtype
self.block_size = block_size
-self.apply_int8_act_asym_per_token_quant = apply_int8_act_asym_per_token_quant
self.activation_quantization = activation_quantization

def _quantization_type(self):
-return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, apply_int8_act_asym_per_token_quant={self.apply_int8_act_asym_per_token_quant}"
return f"target_dtype={self.target_dtype}, block_size={self.block_size}, shape={self.shape}, dtype={self.dtype}, device={self.device}, activation_quantization={self.activation_quantization}"

def _has_float_zero_point(self) -> bool:
return self.zero_point.dtype in _FLOAT_TYPES
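
A small worked example of the shapes the new __init__ asserts require, assuming a 256x512 quantized weight with groupwise block_size (1, 64); the numbers are illustrative, and from_hp below now performs the matching reshape before construction:

import torch

qdata = torch.zeros(256, 512, dtype=torch.int8)  # illustrative quantized weight
block_size = (1, 64)  # one scale/zero_point per group of 64 columns
n_blocks = tuple(qdata.shape[i] // block_size[i] for i in range(len(block_size)))
print(n_blocks)  # (256, 8)
scale = torch.ones(n_blocks)  # shape (256, 8) passes the assert
zero_point = torch.zeros(n_blocks)  # shape (256, 8) passes the assert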
@@ -148,7 +163,7 @@ def to(self, *args, **kwargs):
self.target_dtype,
self.block_size,
dtype,
-self.apply_int8_act_asym_per_token_quant,
self.activation_quantization,
)

@classmethod
@@ -159,7 +174,9 @@ def from_hp(
target_dtype: torch.dtype,
*,
mapping_type: MappingType = MappingType.SYMMETRIC,
-apply_int8_act_asym_per_token_quant: bool = False,
activation_quantization: Optional[
IntxUnpackedToInt8TensorActivationQuantization
] = None,
):
"""
Create an IntxUnpackedToInt8Tensor from a high-precision tensor
@@ -183,14 +200,24 @@
quant_min=qmin,
quant_max=qmax,
)

# Reshape scale and zero_point to be compatible with block_size
# This is asserted in IntxUnpackedToInt8Tensor's __init__
n_blocks = []
for i in range(len(block_size)):
assert qdata.shape[i] % block_size[i] == 0
n_blocks.append(qdata.shape[i] // block_size[i])
scale = scale.reshape(*n_blocks)
zero_point = zero_point.reshape(*n_blocks)

return IntxUnpackedToInt8Tensor(
qdata=qdata,
scale=scale,
zero_point=zero_point,
target_dtype=target_dtype,
block_size=block_size,
dtype=hp_tensor.dtype,
-apply_int8_act_asym_per_token_quant=apply_int8_act_asym_per_token_quant,
activation_quantization=activation_quantization,
)

def dequantize(self):
@@ -207,6 +234,42 @@ def dequantize(self):
)


def _apply_int8_act_asym_per_token_quant_dequant(hp_tensor):
target_dtype = torch.int8
mapping_type = MappingType.ASYMMETRIC
block_size = _get_per_token_block_size(hp_tensor)
qmin, qmax = _DTYPE_TO_QVALUE_BOUNDS[target_dtype]
scale, zero_point = choose_qparams_affine(
hp_tensor,
mapping_type,
block_size,
target_dtype=target_dtype,
quant_min=qmin,
quant_max=qmax,
zero_point_dtype=torch.int8,
)
qdata = quantize_affine(
hp_tensor,
block_size,
scale,
zero_point,
output_dtype=torch.int8,
quant_min=qmin,
quant_max=qmax,
)
dequantized_affine = dequantize_affine(
qdata,
block_size,
scale,
zero_point,
torch.int8,
qmin,
qmax,
output_dtype=hp_tensor.dtype,
)
return dequantized_affine


implements = IntxUnpackedToInt8Tensor.implements
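
For reference, "per token" in the _apply_int8_act_asym_per_token_quant_dequant helper above means the quantization block spans the full last dimension and is 1 elsewhere, so every token gets its own int8 scale and zero_point. A small illustration; the local helper mirrors what _get_per_token_block_size is assumed to return:

import torch

def per_token_block_size(x: torch.Tensor):
    # Assumed behavior of torchao's _get_per_token_block_size: 1 for every leading
    # dimension, full size for the last (feature) dimension.
    return [1] * (x.dim() - 1) + [x.shape[-1]]

activations = torch.randn(2, 7, 512)  # (batch, seq_len, hidden)
print(per_token_block_size(activations))  # [1, 1, 512] -> one qparam pair per token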


@@ -220,13 +283,16 @@ def _(func, types, args, kwargs):
assert isinstance(weight_tensor, IntxUnpackedToInt8Tensor)

# Apply dynamic activation quant
-if weight_tensor.apply_int8_act_asym_per_token_quant:
-input_tensor = IntxUnpackedToInt8Tensor.from_hp(
-hp_tensor=input_tensor,
-block_size=_get_per_token_block_size(input_tensor),
-target_dtype=torch.int8,
-mapping_type=MappingType.ASYMMETRIC,
-).dequantize()
if weight_tensor.activation_quantization is not None:
if (
weight_tensor.activation_quantization
== IntxUnpackedToInt8TensorActivationQuantization.INT8_ASYM_PER_TOKEN
):
input_tensor = _apply_int8_act_asym_per_token_quant_dequant(input_tensor)
else:
raise NotImplementedError(
f"Unsupported activation quantization: {weight_tensor.activation_quantization}"
)

weight_tensor = weight_tensor.dequantize()
return torch.nn.functional.linear(input_tensor, weight_tensor, bias)
@@ -293,12 +359,14 @@ def _(func, types, args, kwargs):
self.target_dtype,
new_block_size,
self.dtype,
-self.apply_int8_act_asym_per_token_quant,
self.activation_quantization,
)
return return_and_correct_aliasing(func, args, kwargs, new)


IntxUnpackedToInt8Tensor.__module__ = "torchao.quantization"

# Allow a model with IntxUnpackedToInt8Tensor weights to be loaded with `weights_only=True`
-torch.serialization.add_safe_globals([IntxUnpackedToInt8Tensor])
torch.serialization.add_safe_globals(
[IntxUnpackedToInt8Tensor, IntxUnpackedToInt8TensorActivationQuantization]
)
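
Registering the enum alongside the tensor subclass keeps weights_only=True checkpoint loading working now that activation_quantization can hold an enum value. A hedged end-to-end sketch of that intent; the import paths are assumed, the config values mirror the tests in this PR, and the file name is illustrative:

import torch
from torchao.quantization import (
    Int8DynamicActivationIntxWeightConfig,
    MappingType,
    quantize_,
)
from torchao.quantization.granularity import PerGroup
from torchao.quantization.quantize_.common import PackingFormat

model = torch.nn.Sequential(torch.nn.Linear(512, 256, bias=False))
quantize_(
    model,
    Int8DynamicActivationIntxWeightConfig(
        weight_dtype=torch.int4,
        weight_granularity=PerGroup(64),
        weight_mapping_type=MappingType.SYMMETRIC,
        packing_format=PackingFormat.UNPACKED_TO_INT8,
        version=2,
    ),
)
torch.save(model.state_dict(), "checkpoint.pt")
# With the add_safe_globals registration above, the quantized weights (and the
# activation-quantization enum they carry) can be deserialized safely.
state_dict = torch.load("checkpoint.pt", weights_only=True)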