Expose FakeQuantizeConfigs in QAT quantizers #1214


Merged: 1 commit, Nov 4, 2024
torchao/quantization/qat/linear.py (101 changes: 74 additions & 27 deletions)
@@ -107,12 +107,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.linear(x, w)


class _LegacyQATQuantizer(TwoStepQuantizer):
"""
Base class for sharing common methods across legacy QAT quantizers.
"""
def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return None

def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return None


# =========================================================
# | Linear int8 dynamic activations + int4 weight QAT |
# =========================================================


class Int8DynActInt4WeightQATQuantizer(TwoStepQuantizer):
class Int8DynActInt4WeightQATQuantizer(_LegacyQATQuantizer):
"""
Quantizer for performing QAT on a model, where linear layers have int8
dynamic per token fake quantized activations and int4 fake quantized
@@ -189,6 +200,12 @@ def _convert_qat_linear_8da4w(self, module: torch.nn.Module):
else:
self._convert_qat_linear_8da4w(child)

def get_activation_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return _get_8da4w_activation_config(self.scales_precision)

def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return _get_8da4w_weight_config(self.groupsize, self.scales_precision)
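
# --- Usage sketch, not part of this diff ---
# Minimal example of the new accessors; assumes torchao.quantization.qat
# exposes Int8DynActInt4WeightQATQuantizer (as it does at the time of this PR).
from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer

sketch_quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
act_config = sketch_quantizer.get_activation_fake_quantize_config()  # int8, per-token, asymmetric, dynamic
weight_config = sketch_quantizer.get_weight_fake_quantize_config()  # int4, per-group of 32, symmetric
# --- end sketch ---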


class Int8DynActInt4WeightQATLinear(FakeQuantizedLinear):
"""
@@ -211,22 +228,8 @@ def __init__(
precision: torch.dtype = torch.float32,
scales_precision: torch.dtype = torch.float32,
) -> None:
activation_config = FakeQuantizeConfig(
dtype=torch.int8,
granularity="per_token",
is_symmetric=False,
is_dynamic=True,
scale_precision=scales_precision,
zero_point_precision=scales_precision,
)
weight_config = FakeQuantizeConfig(
dtype=TorchAODType.INT4,
group_size=groupsize,
is_symmetric=True,
is_dynamic=True,
scale_precision=scales_precision,
zero_point_precision=scales_precision,
)
activation_config = _get_8da4w_activation_config(scales_precision)
weight_config = _get_8da4w_weight_config(groupsize, scales_precision)
super().__init__(
in_features,
out_features,
@@ -261,12 +264,43 @@ def disable_8da4w_fake_quant(mod: torch.nn.Module):
mod.disable_fake_quant()


def _get_8da4w_activation_config(qparams_precision: torch.dtype) -> FakeQuantizeConfig:
"""
Return the activation `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
"""
return FakeQuantizeConfig(
dtype=torch.int8,
granularity="per_token",
is_symmetric=False,
is_dynamic=True,
scale_precision=qparams_precision,
zero_point_precision=qparams_precision,
)


def _get_8da4w_weight_config(
group_size: int,
qparams_precision: torch.dtype,
) -> FakeQuantizeConfig:
"""
Return the weight `FakeQuantizeConfig` for `Int8DynActInt4WeightQATQuantizer`.
"""
return FakeQuantizeConfig(
dtype=TorchAODType.INT4,
group_size=group_size,
is_symmetric=True,
is_dynamic=True,
scale_precision=qparams_precision,
zero_point_precision=qparams_precision,
)


# ===================================
# | Linear int4 weight-only QAT |
# ===================================


class Int4WeightOnlyQATQuantizer(TwoStepQuantizer):
class Int4WeightOnlyQATQuantizer(_LegacyQATQuantizer):
"""
Quantizer for performing QAT on a model, where linear layers have
int4 fake quantized grouped per channel weights.
@@ -348,6 +382,9 @@ def _convert_qat_linear_4w(self, module: torch.nn.Module):
else:
self._convert_qat_linear_4w(child)

def get_weight_fake_quantize_config(self) -> Optional[FakeQuantizeConfig]:
return _get_4w_weight_config(self.groupsize, self.scales_precision)
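
# --- Usage sketch, not part of this diff ---
# The weight-only quantizer overrides only the weight accessor; the activation
# accessor falls back to the base class default, so callers should expect None.
from torchao.quantization.qat import Int4WeightOnlyQATQuantizer

sketch_quantizer = Int4WeightOnlyQATQuantizer(groupsize=128)
assert sketch_quantizer.get_activation_fake_quantize_config() is None
weight_config = sketch_quantizer.get_weight_fake_quantize_config()  # uint4, per-group of 128, asymmetric, float zero point domain
# --- end sketch ---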


class Int4WeightOnlyQATLinear(FakeQuantizedLinear):
"""
@@ -376,15 +413,7 @@ def __init__(
if not _check_linear_int4_k(in_features, groupsize, inner_k_tiles):
raise ValueError("Padding for QAT 4w is not supported yet")
self.inner_k_tiles = inner_k_tiles
weight_config = FakeQuantizeConfig(
dtype=torch.uint4,
group_size=groupsize,
is_symmetric=False,
is_dynamic=True,
scale_precision=scales_precision,
zero_point_precision=scales_precision,
zero_point_domain=ZeroPointDomain.FLOAT,
)
weight_config = _get_4w_weight_config(groupsize, scales_precision)
super().__init__(
in_features,
out_features,
@@ -417,3 +446,21 @@ def disable_4w_fake_quant(mod: torch.nn.Module):
"""
if isinstance(mod, Int4WeightOnlyQATLinear):
mod.disable_fake_quant()


def _get_4w_weight_config(
group_size: int,
qparams_precision: torch.dtype,
) -> FakeQuantizeConfig:
"""
Return the weight `FakeQuantizeConfig` for `Int4WeightOnlyQATQuantizer`.
"""
return FakeQuantizeConfig(
dtype=torch.uint4,
group_size=group_size,
is_symmetric=False,
is_dynamic=True,
scale_precision=qparams_precision,
zero_point_precision=qparams_precision,
zero_point_domain=ZeroPointDomain.FLOAT,
)
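
One possible downstream use of the exposed configs (a sketch, not part of this change): pass them to FakeQuantizedLinear to build a standalone layer that fake quantizes the same way the quantizer's prepare step would. The activation_config/weight_config keyword names are assumed from FakeQuantizedLinear's constructor.

import torch

from torchao.quantization.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.qat.linear import FakeQuantizedLinear

# Sketch only: assumes FakeQuantizedLinear accepts activation_config/weight_config kwargs.
quantizer = Int8DynActInt4WeightQATQuantizer(groupsize=32)
fq_linear = FakeQuantizedLinear(
    256,
    512,
    bias=False,
    activation_config=quantizer.get_activation_fake_quantize_config(),
    weight_config=quantizer.get_weight_fake_quantize_config(),
)
x = torch.randn(1, 256)
y = fq_linear(x)  # forward pass with fake quantized activations and weights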