
Commit 4b3a742

up
1 parent 18d63e1 commit 4b3a742

File tree

3 files changed: +76 -94 lines changed


torchao/experimental/packed_linear_int8_dynamic_activation_intx_weight_layout.py

Lines changed: 25 additions & 31 deletions
@@ -6,7 +6,7 @@
 
 import logging
 from enum import Enum, auto
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union
 
 import torch
 from torch.utils._python_dispatch import return_and_correct_aliasing
@@ -53,39 +53,21 @@ def target_from_str(target: str) -> Target:
 
 
 class PackedLinearInt8DynamicActivationIntxWeightLayout(Layout):
-    bit_width: Optional[int]
-    group_size: Optional[int]
-    has_weight_zeros: Optional[bool]
-    has_bias: Optional[bool]
-    # The target platform for the layout, 'native' or 'aten'
-    target: Optional[Target]
-
     def __init__(
         self,
-        bit_width: Optional[int] = None,
-        group_size: Optional[int] = None,
-        has_weight_zeros: Optional[bool] = None,
-        has_bias: Optional[bool] = None,
-        target: Optional[str] = "native",
+        target: Union[str, Target] = "native",
     ):
-        if bit_width is not None:
-            assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
-        if group_size is not None:
-            assert group_size >= 1, f"group_size must be positive, got {group_size}"
-
-        self.bit_width = bit_width
-        self.group_size = group_size
-        self.has_weight_zeros = has_weight_zeros
-        self.has_bias = has_bias
-        self.target = target_from_str(target)
-
-        if not self.has_params_set():
-            assert (
-                self.bit_width is None
-                and self.group_size is None
-                and self.has_weight_zeros is None
-                and self.has_bias is None
-            ), "bit_width, group_size, has_weight_zeros, has_bias must be None if has_params_set is False"
+        if isinstance(target, str):
+            target = target_from_str(target)
+        self.target = target
+
+        self.bit_width: Optional[int] = None
+        self.group_size: Optional[int] = None
+        self.has_weight_zeros: Optional[bool] = None
+        # has_bias is whether the packed weights
+        # have bias packed with them, not whether the
+        # linear operator has bias
+        self.has_bias: Optional[bool] = None
 
     def extra_repr(self):
         return f"group_size={self.group_size}, bit_width={self.bit_width}, has_weight_zeros={self.has_weight_zeros}, has_bias={self.has_bias}, target={self.target}"
@@ -99,6 +81,18 @@ def has_params_set(self) -> bool:
             and (self.target is not None)
         )
 
+    def set_params(
+        self, bit_width: int, group_size: int, has_weight_zeros: bool, has_bias: bool
+    ):
+        assert bit_width >= 1 and bit_width <= 8, "bit_width must be 1 to 8"
+        assert group_size >= 1, f"group_size must be positive, got {group_size}"
+
+        self.bit_width = bit_width
+        self.group_size = group_size
+        self.has_weight_zeros = has_weight_zeros
+        self.has_bias = has_bias
+        assert self.has_params_set()
+
 
 @register_layout(PackedLinearInt8DynamicActivationIntxWeightLayout)
 class PackedLinearInt8DynamicActivationIntxWeightAQTTensorImpl(AQTTensorImpl):
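
The refactor above moves parameter validation out of the constructor: a layout is first created with only a target, and set_params attaches bit_width, group_size, has_weight_zeros, and has_bias later. A minimal sketch of that flow (the import path is assumed from the file path above and is not part of the commit):

```python
# Sketch of the new two-step construction; the import path is an assumption
# based on the changed file's location.
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

# The constructor now only takes a target ("native" or "aten", or a Target enum value).
layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target="native")
assert not layout.has_params_set()

# Quantization parameters are attached later, once they are known for a given module.
layout.set_params(bit_width=4, group_size=128, has_weight_zeros=False, has_bias=False)
assert layout.has_params_set()
print(layout.extra_repr())
```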

torchao/experimental/quant_api.py

Lines changed: 40 additions & 43 deletions
@@ -15,7 +15,6 @@
     quantize_per_channel_group,
 )
 
-from torchao.dtypes import PlainLayout
 from torchao.quantization.granularity import (
     PerGroup,
     PerRow,
@@ -516,7 +515,7 @@ def quantize(self, model: nn.Module) -> nn.Module:
 @dataclass
 class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
     """
-    Configuration for dynamically quantizes activations with 8-bits and weights with a low-bit value for linear layers.
+    Configuration for dynamically quantizing activations with 8-bits and quantizing weights with a low-bit value.
     More specifically, activations are dynamically quantized to 8-bits in a channelwise manner with scales and zeros.
     Weights are quantized with scales and optionally zeros (controlled by has_weight_zeros) in a groupwise or channelwise
     manner using the number of bits specified by weight_dtype.
@@ -527,20 +526,17 @@ class Int8DynamicActivationIntxWeightConfig(AOBaseConfig):
         has_weight_zeros: Whether or not to include zeros in the weight quantization.
         weight_mapping_type: The type of mapping to use for the weight quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
         act_mapping_type: The type of mapping to use for the activation quantization. Must be one of MappingType.ASYMMETRIC or MappingType.SYMMETRIC.
-        layout: The layout to use for the packed weight tensor. Must be PackedLinearInt8DynamicActivationIntxWeightLayout (default) or PlainLayout.
-            The layout does not affect the quantization numerically and both layouts will give the same results. PlainLayout is a generic layout
-            that works on all devices, but it is much slower than PackedLinearInt8DynamicActivationIntxWeightLayout on CPU.
-            PackedLinearInt8DynamicActivationIntxWeightLayout is a specialized layout for CPU performance.
-            When using PackedLinearInt8DynamicActivationIntxWeightLayout,
-            - The weight tensor must have device=CPU
-            - The weight tensor must have dtype=float32 (note that after applying quantization, the weights will no longer be float32)
-            - act_mapping_type must be MappingType.ASYMMETRIC
+        layout: The layout to use for the packed weight tensor. The layout does not affect the quantization numerically and different
+            layouts will give similar results. The following are available layouts:
+            - PackedLinearInt8DynamicActivationIntxWeightLayout: This layout is optimized for CPU performance.
+            - QDQLayout: This layout is designed for export to ExecuTorch
+            - PlainLayout: This layout is a simple python-based layout. It has low performance, but can be used
+              when PackedLinearInt8DynamicActivationIntxWeightLayout is unavailable.
     """
 
     weight_dtype: torch.dtype = torch.int4
     granularity: Union[PerRow, PerGroup] = PerRow()
     has_weight_zeros: bool = False
-    has_bias: bool = False
     weight_mapping_type: MappingType = MappingType.ASYMMETRIC
     act_mapping_type: MappingType = MappingType.ASYMMETRIC
     layout: Layout = PackedLinearInt8DynamicActivationIntxWeightLayout(target="native")
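
The docstring and config fields above drop has_bias, so callers no longer choose bias packing themselves. For illustration, constructing the updated config might look like the sketch below; the import paths are inferred from this diff and the quantize_ integration is assumed rather than verified here.

```python
# Hedged usage sketch of the updated config (import paths are assumptions
# inferred from the changed files; quantize_ acceptance of this config is assumed).
import torch
from torchao.quantization import quantize_
from torchao.quantization.granularity import PerGroup
from torchao.experimental.quant_api import Int8DynamicActivationIntxWeightConfig
from torchao.experimental.packed_linear_int8_dynamic_activation_intx_weight_layout import (
    PackedLinearInt8DynamicActivationIntxWeightLayout,
)

model = torch.nn.Sequential(torch.nn.Linear(4096, 1024, bias=True)).eval()

config = Int8DynamicActivationIntxWeightConfig(
    weight_dtype=torch.int4,
    granularity=PerGroup(128),
    has_weight_zeros=False,
    # No has_bias field anymore; the transform decides whether bias is packed with the weights.
    layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="native"),
)
quantize_(model, config)
```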
@@ -559,27 +555,10 @@ def _int8_dynamic_activation_intx_weigh_transform(
     weight_dtype = config.weight_dtype
     granularity = config.granularity
     has_weight_zeros = config.has_weight_zeros
-    has_bias = config.has_bias
     weight_mapping_type = config.weight_mapping_type
     act_mapping_type = config.act_mapping_type
     layout = config.layout
 
-    def is_torchao_op_skippable(layout):
-        return isinstance(layout, PlainLayout) or (
-            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-            and layout.target == Target.ATEN
-        )
-
-    if not is_torchao_op_skippable(layout):
-        try:
-            torch.ops.torchao._pack_8bit_act_4bit_weight
-        except AttributeError:
-            raise Exception(
-                "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
-                + " You can also set target to 'aten' if you are using ARM CPU."
-                + " Alternatively, use layout=PlainLayout() with int8_dynamic_activation_intx_weight, but note that doing so will result in much slower performance."
-            )
-
     dtype_to_bit_width = {
         torch.int1: 1,
         torch.int2: 2,
@@ -603,7 +582,18 @@ def is_torchao_op_skippable(layout):
     else:
         raise ValueError(f"granularity must be PerGroup or PerRow, got {granularity}")
 
+    tensor_impl_ctr_kwargs = None
     if isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout):
+        # We need to create a new layout object for each module because when
+        # granularity is PerRow, the layout objects cannot share the group_size
+        layout = PackedLinearInt8DynamicActivationIntxWeightLayout(layout.target)
+        layout.set_params(
+            bit_width=bit_width,
+            group_size=group_size,
+            has_weight_zeros=has_weight_zeros,
+            has_bias=False,
+        )
+
         assert (
             weight.device == torch.device("cpu")
         ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires weight.device=CPU"
@@ -613,20 +603,24 @@
         assert (
             act_mapping_type == MappingType.ASYMMETRIC
         ), "PackedLinearInt8DynamicActivationIntxWeightLayout requires act_mapping_type=MappingType.ASYMMETRIC"
-        assert not layout.has_params_set(), "PackedLinearInt8DynamicActivationIntxWeightLayout params should not already be set"
-        layout = PackedLinearInt8DynamicActivationIntxWeightLayout(
-            bit_width=bit_width,
-            group_size=group_size,
-            has_weight_zeros=has_weight_zeros,
-            has_bias=has_bias,
-            target="aten" if layout.target == Target.ATEN else "native",
-        )
 
-        # ATEN KleidiAI kernel
-        # TODO: long term, we want to disfavor this kernel and instead use KleidiAI kernels in torchao
-        # that are available via PackedLinearInt8DynamicActivationIntxWeightLayout(target="native")
-        # where applicable
-        if layout.target == Target.ATEN:
+        tensor_impl_ctr_kwargs = {"bias": bias}
+
+        if layout.target == Target.NATIVE:
+            # Check kernels are installed/loaded
+            try:
+                torch.ops.torchao._pack_8bit_act_4bit_weight
+            except AttributeError:
+                raise Exception(
+                    "TorchAO experimental kernels are not loaded. To install the kernels, run `USE_CPP=1 pip install .` from ao on a machine with an ARM CPU."
+                    + " You can also set target to 'aten' if you are using ARM CPU."
+                )
+        elif layout.target == Target.ATEN:
+            # TODO: long term, we want to disfavor this route for using KleidiAI in torchao.
+            # KleidiAI kernels are accessible via Target.NATIVE if torchao is built
+            # with TORCHAO_BUILD_KLEIDIAI=1. The Target.NATIVE route has the advantage
+            # of automatically dispatching to different kernel libraries based on the CPU
+            # capability and the desired quantization
             assert (
                 TORCH_VERSION_AT_LEAST_2_6
             ), "ATEN target requires torch version > 2.6.0"
@@ -657,7 +651,7 @@ def is_torchao_op_skippable(layout):
         else ZeroPointDomain.NONE,
         _layout=layout,
         use_hqq=False,
-        tensor_impl_ctr_kwargs={"bias": bias} if has_bias else None,
+        tensor_impl_ctr_kwargs=tensor_impl_ctr_kwargs,
     )
 
     # Note that PackedLinearInt8DynamicActivationIntxWeightLayout has dynamic activation quantization fused
@@ -678,7 +672,10 @@ def is_torchao_op_skippable(layout):
     module.weight = torch.nn.Parameter(weight, requires_grad=False)
 
     # If bias was packed with weights, set bias to None on module
-    if has_bias:
+    if (
+        isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
+        and layout.has_bias
+    ):
        module.bias = None
 
     return module
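
With this change, the kernel-availability check runs only for the native target, while the aten path additionally requires torch >= 2.6. A small illustrative probe (not part of the commit) mirroring the check in the diff above:

```python
# Illustrative probe mirroring the diff's availability check for the native target.
import torch

def native_kernels_available() -> bool:
    try:
        # Accessing the op raises AttributeError if the experimental C++ kernels are not loaded.
        torch.ops.torchao._pack_8bit_act_4bit_weight
        return True
    except AttributeError:
        return False

# Note: per the assert in the diff, the "aten" fallback also requires torch >= 2.6.
target = "native" if native_kernels_available() else "aten"
print(f"Using PackedLinearInt8DynamicActivationIntxWeightLayout(target={target!r})")
```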

torchao/experimental/tests/test_int8_dynamic_activation_intx_weight.py

Lines changed: 11 additions & 20 deletions
@@ -63,7 +63,7 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         """
         Checks the accuracy of different layouts by comparing the results to PlainLayout()
         """
-        m = 1
+        m = 3
         n = 1071
         k = 4096
         activations = torch.randn(m, k)
@@ -96,15 +96,12 @@ def test_accuracy(self, layout, weight_dtype, has_weight_zeros, granularity):
         result = quantized_model(activations)
         expected_result = quantized_model_reference(activations)
 
-        if self._use_relaxed_tolerance(layout, weight_dtype, has_weight_zeros):
-            self.assertTrue(
-                torch.nn.functional.mse_loss(result, expected_result) <= 1e-5
-            )
-        else:
-            self.assertTrue(torch.allclose(result, expected_result, atol=1e-4))
+        # When weight_dtype is int4, the quantization error may be larger
+        # because KleidiAI kernels may be used
+        self._assert_close(result, expected_result, strict=(weight_dtype != torch.int4))
 
     def test_accuracy_aten(self):
-        m = 1
+        m = 3
         n = 1024
         k = 4096
         activations = torch.randn(m, k)
@@ -113,7 +110,6 @@ def test_accuracy_aten(self):
         weight_dtype = torch.int4
         granularity = PerGroup(128)
         has_weight_zeros = False
-        has_bias = False  # KleidiAI throws if bias is packed with weights
 
         reference_layout = PlainLayout()
         quantized_model = copy.deepcopy(model)
@@ -123,7 +119,6 @@
                 weight_dtype=weight_dtype,
                 granularity=granularity,
                 has_weight_zeros=has_weight_zeros,
-                has_bias=has_bias,
                 layout=PackedLinearInt8DynamicActivationIntxWeightLayout(target="aten"),
             ),
         )
@@ -143,16 +138,12 @@
         result = quantized_model(activations)
         expected_result = quantized_model_reference(activations)
 
-        self.assertTrue(torch.nn.functional.mse_loss(result, expected_result) <= 1e-8)
+        self._assert_close(result, expected_result, strict=False)
 
-    def _use_relaxed_tolerance(self, layout, weight_dtype, has_weight_zeros):
-        # Use relaxed tolerance in cases where KleidiAI kernels might
-        # be selected.
-        return (
-            isinstance(layout, PackedLinearInt8DynamicActivationIntxWeightLayout)
-            and weight_dtype == torch.int4
-            and not has_weight_zeros
-        )
+    def _assert_close(self, result, expected_result, strict: bool = False):
+        self.assertTrue(torch.nn.functional.mse_loss(result, expected_result) <= 1e-8)
+        if strict:
+            self.assertTrue(torch.allclose(result, expected_result, atol=1e-3))
 
     def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
         self,
@@ -171,7 +162,7 @@ def test_export_compile_aoti_PackedLinearInt8DynamicActivationIntxWeightLayout(
         has_weight_zeros = True
         layers = [
             torch.nn.Linear(k0, k1, bias=False),
-            torch.nn.Linear(k1, k2, bias=False),
+            torch.nn.Linear(k1, k2, bias=True),
             torch.nn.Linear(k2, k3, bias=False),
         ]
         model = torch.nn.Sequential(*layers)
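
The tests now funnel all comparisons through a single helper: an MSE bound against the PlainLayout reference is always enforced, and an elementwise allclose check is added only in strict mode (relaxed for int4, where KleidiAI kernels may be selected). A standalone restatement of that scheme (illustrative only, not part of the commit):

```python
# Illustrative restatement of the test helper's tolerance scheme.
import torch

def assert_close(result: torch.Tensor, expected: torch.Tensor, strict: bool = False) -> None:
    # Always bound the mean squared error against the reference output.
    assert torch.nn.functional.mse_loss(result, expected) <= 1e-8
    # In strict mode, also require elementwise closeness.
    if strict:
        assert torch.allclose(result, expected, atol=1e-3)

x = torch.randn(3, 1071)
assert_close(x, x + 1e-5, strict=True)  # a tiny perturbation passes both checks
```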
