
Commit 48bc81c

Add generic fake quantized linear for QAT (#1020)
* Make module swap the main QAT flow again

  Summary: Following #987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic since this is not needed in both the long term and the short term plans for QAT. In the short term, we will continue to use a full module swap flow, and only migrate to the long term flow once there is general distributed support for tensor subclasses and when tensor subclass composability provides meaningful benefits.

  Test Plan: python test/quantization/test_qat.py

* Add generic fake quantized linear for QAT

  Summary: This commit adds a generic fake quantized linear module to replace the uses of the existing more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

  ```
  from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
  from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

  activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
  weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
  fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
  ```

  The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we would have to create a new linear class every time we wish to experiment with different fake quantization settings, e.g. different group size or different bit width. Now we can express this easily using a single linear module.

  Test Plan:
  python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
  python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
  python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
  python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
  python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
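To make the intended workflow concrete, here is a rough usage sketch (not part of this commit's diff): drop `FakeQuantizedLinear` into a toy model and train as usual. The module paths and config arguments come from the commit message above; the toy model and training step are invented for illustration, and a PyTorch build that exposes `torch.int4` (as in the example above) plus straight-through gradients through the fake quantizers are assumed.

```python
# Hypothetical usage sketch (toy model invented for illustration).
import torch
import torch.nn as nn
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# int8 per-token asymmetric activations + int4 per-group symmetric weights,
# mirroring the 8da4w example from the commit message
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)

model = nn.Sequential(
    FakeQuantizedLinear(16, 32, False, activation_config, weight_config),
    nn.ReLU(),
    nn.Linear(32, 8),
)

# Fake quantization only simulates lower precision, so an ordinary training step applies
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
x = torch.randn(4, 16)
loss = model(x).sum()
loss.backward()
optimizer.step()
```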
1 parent c82492d commit 48bc81c

File tree: 10 files changed, +903 −211 lines


test/integration/test_integration.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -328,6 +328,7 @@ def test_swap(self):
         assert torch.allclose(y_ref, y)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_weight_t_and_non_t_numerics_match(self):
         # verify that numerics match whether weight is stored
         # in transposed format (for cuBLAS) vs non-transposed format
@@ -1126,6 +1127,7 @@ def test_shape_logger(self):
 class SmoothquantIntegrationTest(unittest.TestCase):
     @torch.no_grad()
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_non_dynamically_quantizable_linear(self):
         if torch.cuda.is_available() and torch.cuda.get_device_capability() < (8, 0):
             self.skipTest("test requires SM capability of at least (8, 0).")
```
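For reference, `TORCH_VERSION_AT_LEAST_2_3` is a module-level boolean used to gate these tests on the installed PyTorch version (typically imported from `torchao.utils`). A minimal sketch of the pattern, with the flag recomputed locally purely for illustration, might look like this:

```python
# Sketch of the version-gating pattern used in the diff above (illustrative only;
# in the real tests the flag is imported, not recomputed).
import unittest

import torch
from packaging.version import parse

# True when the installed torch is at least 2.3 (e.g. "2.4.0+cu121" -> (2, 4, 0))
TORCH_VERSION_AT_LEAST_2_3 = parse(torch.__version__).release >= (2, 3)


class ExampleIntegrationTest(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
    def test_needs_recent_torch(self):
        # Only runs on CUDA machines with torch >= 2.3
        self.assertTrue(TORCH_VERSION_AT_LEAST_2_3)
```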

test/quantization/test_qat.py

Lines changed: 293 additions & 43 deletions
Large diffs are not rendered by default.

torchao/quantization/README.md

Lines changed: 2 additions & 3 deletions
````diff
@@ -136,8 +136,7 @@ change_linear_weights_to_int8_dqtensors(model)
 
 ```python
 # for torch 2.4+
-from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.quant_api import PerTensor
+from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight, PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
 
@@ -321,7 +320,7 @@ This API works today but has not been extensively tested and benchmarked yet. Ha
 
 ```python
 # for torch 2.5+
-from torchao.quantization.quant_api import quantize_, PerRow, float8_dynamic_activation_float8_weight
+from torchao.quantization import quantize_, PerRow, float8_dynamic_activation_float8_weight
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerRow()))
 ```
 
````

torchao/quantization/__init__.py

Lines changed: 3 additions & 2 deletions
```diff
@@ -12,11 +12,12 @@
 from .weight_only import *  # noqa: F403
 from .unified import *
 from .autoquant import *
-from .linear_activation_quantized_tensor import (  # noqat: F403
+from .granularity import *
+from .linear_activation_quantized_tensor import (
     LinearActivationQuantizedTensor,
     to_linear_activation_quantized,
 )
-from .linear_activation_scale import (  # noqat: F403
+from .linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
 
```
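The practical effect of re-exporting `granularity` here is that granularity types become importable from the package root, which is what the updated README snippets above rely on. A tiny check, assuming torchao is installed with this change:

```python
# After this change, granularity types are importable from the package root
# rather than only from torchao.quantization.granularity / quant_api.
from torchao.quantization import PerRow, PerTensor

print(PerTensor(), PerRow())
```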

torchao/quantization/granularity.py

Lines changed: 19 additions & 5 deletions
```diff
@@ -22,7 +22,7 @@ class PerTensor(Granularity):
     """
     Represents per-tensor granularity in quantization.
 
-    This granularity type calcualtes the quantization parameters
+    This granularity type calculates the quantization parameters
     based off the entire tensor.
     """
     pass
@@ -32,26 +32,24 @@ class PerAxis(Granularity):
     """
     Represents per-axis granularity in quantization.
 
-    This granularity type calcualtes different quantization parameters
+    This granularity type calculates different quantization parameters
     along a specified axis of the tensor.
 
     For example if the input tensor is shape [8, 16] and axis=0, then
     the quantization parameters are calculated for each row of the tensor.
     Giving a total of 8 quantization parameters.
 
-
     Attributes:
         axis (int): The axis along which reduction is performed.
     """
     axis: int
 
 @dataclass(frozen=True)
-
 class PerGroup(Granularity):
     """
     Represents per-channel group granularity in quantization.
 
-    This granularity type calcualtes different quantization parameters
+    This granularity type calculates different quantization parameters
     for each group of <group_size> elements.
 
     For example if the input tensor is shape [8, 16], and the group size is 4, then
@@ -74,3 +72,19 @@ class PerRow(Granularity):
     is quantized with a block_size of (1, weight.shape[1]).
     """
     pass
+
+class PerToken(Granularity):
+    """
+    Represents per-token granularity in quantization.
+
+    This granularity type calculates a different set of quantization parameters
+    for each token, which is represented as the last dimension of the tensor.
+
+    For example, if the input tensor has shape [2, 3, 4], then there are 6 tokens
+    with 4 elements each, and we will calculate 6 sets of quantization parameters,
+    one for each token.
+
+    If the input tensor has only two dimensions, e.g. [8, 16], then this is
+    equivalent to `PerAxis(axis=0)`, which yields 8 sets of quantization parameters.
+    """
+    pass
```
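To make the new `PerToken` docstring concrete, the following standalone sketch (illustration only, not code from this commit; simple symmetric int8 scaling assumed) shows how a `[2, 3, 4]` tensor yields 6 per-token scales, versus 8 per-row scales for `PerAxis(axis=0)` on an `[8, 16]` tensor:

```python
# Illustration of per-token vs per-axis granularity (not part of this diff).
import torch

qmax = 127  # symmetric int8 range

# Per token: one scale per slice along the last dimension -> 2 * 3 = 6 scales
x = torch.randn(2, 3, 4)
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / qmax
print(per_token_scale.shape)  # torch.Size([2, 3, 1])

# PerAxis(axis=0) on a 2-D weight: one scale per row -> 8 scales,
# matching the "equivalent to PerAxis(axis=0)" note in the PerToken docstring
w = torch.randn(8, 16)
per_row_scale = w.abs().amax(dim=1, keepdim=True) / qmax
print(per_row_scale.shape)  # torch.Size([8, 1])
```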

torchao/quantization/prototype/qat/api.py

Lines changed: 214 additions & 1 deletion
```diff
@@ -4,11 +4,224 @@
 # This source code is licensed under the license found in the
 # LICENSE file in the root directory of this source tree.
 
-from typing import Any, List
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, List, Optional, Union
 
 import torch
 
+from torchao.quantization.granularity import (
+    Granularity,
+    PerAxis,
+    PerGroup,
+    PerToken,
+)
 from torchao.quantization.unified import TwoStepQuantizer
+from torchao.quantization.quant_primitives import (
+    _SUB_BYTE_INT_BOUNDS,
+    _SUB_BYTE_UINT_BOUNDS,
+    MappingType,
+    TorchAODType,
+    ZeroPointDomain,
+)
+
+
+@dataclass
+class FakeQuantizeConfig:
+    """
+    Config for how to fake quantize weights or activations.
+
+    args:
+        dtype: dtype to simulate during fake quantization, e.g. torch.int8.
+            For PyTorch versions older than 2.6, you may use `TorchAODType` to represent
+            torch.int1 to torch.int7 instead, e.g. TorchAODType.INT4.
+        granularity: granularity of scales and zero points, e.g. PerGroup(32).
+            We also support the following strings:
+                1) 'per_token': equivalent to PerToken()
+                2) 'per_channel': equivalent to PerAxis(0)
+                3) 'per_group': equivalent to PerGroup(group_size), must be combined
+                    with separate `group_size` kwarg. Alternatively, just set the
+                    `group_size` kwarg and leave this field empty.
+        mapping_type: whether to use symmetric (default) or asymmetric quantization
+            Alternatively, set `is_symmetric` (bool) and leave this field empty.
+        scale_precision: scale dtype (default torch.fp32)
+        zero_point_precision: zero point dtype (default torch.int32)
+        zero_point_domain: whether zero point is in integer (default) or float domain
+        is_dynamic: whether to use dynamic (default) or static scale and zero points
+        range_learning: whether to learn scale and zero points during training (coming soon)
+
+    kwargs (optional):
+        group_size: size of each group in per group fake quantization,
+            can be set instead of `granularity`
+        is_symmetric: whether to use symmetric or asymmetric quantization,
+            can be set instead of `mapping_type`
+
+    Example usage::
+
+        # Per token asymmetric quantization
+        FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
+        FakeQuantizeConfig(torch.int8, PerToken(), MappingType.ASYMMETRIC)
+
+        # Per channel symmetric quantization
+        FakeQuantizeConfig(torch.int4, "per_channel")
+        FakeQuantizeConfig(torch.int4, "per_channel", is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, PerAxis(0), MappingType.SYMMETRIC)
+
+        # Per group symmetric quantization
+        FakeQuantizeConfig(torch.int4, group_size=32)
+        FakeQuantizeConfig(torch.int4, group_size=32, is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
+        FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
+    """
+    dtype: Union[torch.dtype, TorchAODType]
+    granularity: Granularity
+    mapping_type: MappingType
+    scale_precision: torch.dtype
+    zero_point_precision: torch.dtype
+    zero_point_domain: ZeroPointDomain
+    is_dynamic: bool = True
+    range_learning: bool = False
+
+    def __init__(
+        self,
+        dtype: Union[torch.dtype, TorchAODType],
+        granularity: Union[Granularity, str, None] = None,
+        mapping_type: Optional[MappingType] = None,
+        scale_precision: torch.dtype = torch.float32,
+        zero_point_precision: torch.dtype = torch.int32,
+        zero_point_domain: ZeroPointDomain = ZeroPointDomain.INT,
+        is_dynamic: bool = True,
+        range_learning: bool = False,
+        *,
+        group_size: Optional[int] = None,
+        is_symmetric: Optional[bool] = None,
+    ):
+        self.dtype = dtype
+        self.granularity = self._get_granularity(granularity, group_size)
+        self.mapping_type = self._get_mapping_type(mapping_type, is_symmetric)
+        self.scale_precision = scale_precision
+        self.zero_point_precision = zero_point_precision
+        self.zero_point_domain = zero_point_domain
+        self.is_dynamic = is_dynamic
+        self.range_learning = range_learning
+
+        # Validate dtype
+        all_dtypes = [torch.int8, torch.uint8]
+        all_dtypes.extend(list(_SUB_BYTE_INT_BOUNDS.keys()))
+        all_dtypes.extend(list(_SUB_BYTE_UINT_BOUNDS.keys()))
+        if dtype not in all_dtypes:
+            raise ValueError("Unsupported dtype '%s', choose from %s" % (dtype, all_dtypes))
+
+    def _get_granularity(
+        self,
+        granularity: Union[Granularity, str, None],
+        group_size: Optional[int],
+    ) -> Granularity:
+        """
+        Parse the `Granularity` represented in the args.
+
+        Granularity can be specified in one of three ways:
+            1) `Granularity` object: one of PerToken(), PerAxis(), and PerGroup(group_size)
+            2) str: one of 'per_token', 'per_channel', and 'per_group'
+            3) None: `group_size` must be set instead, represents per group granularity
+        """
+        # If group_size is set, then granularity must be either "per_group" or None
+        if group_size is not None and granularity != "per_group" and granularity is not None:
+            raise ValueError("`group_size` conflicts with granularity '%s'" % granularity)
+
+        # Case 1: Granularity object
+        if isinstance(granularity, Granularity):
+            if not isinstance(granularity, (PerToken, PerAxis, PerGroup)):
+                raise ValueError("Granularity '%s' is not supported" % granularity)
+            if isinstance(granularity, PerAxis) and granularity.axis != 0:
+                raise ValueError("Only axis=0 is supported for PerAxis granularity")
+            return granularity
+
+        # Case 2: str granularity
+        if granularity == "per_token":
+            return PerToken()
+        elif granularity == "per_channel":
+            return PerAxis(axis=0)
+        elif granularity == "per_group":
+            if group_size is None:
+                raise ValueError("Granularity was 'per_group' but no `group_size` was set")
+            return PerGroup(group_size)
+        elif isinstance(granularity, str):
+            raise ValueError(
+                "Unexpected granularity: '%s', must be one of %s" %
+                (granularity, ["per_token", "per_channel", "per_group"])
+            )
+
+        # Case 3: None granularity + group_size was specified
+        if granularity is not None:
+            raise ValueError(
+                "Granularity '%s' has unexpected type %s" % (granularity, type(granularity))
+            )
+        if group_size is None:
+            raise ValueError("At least one of `granularity` or `group_size` must be set")
+        return PerGroup(group_size)
+
+    def _get_mapping_type(
+        self,
+        mapping_type: Optional[MappingType],
+        is_symmetric: Optional[bool],
+    ) -> MappingType:
+        """
+        Parse the `MappingType` represented in the args.
+
+        Mapping type can be specified in one of two ways:
+            1): `MappingType` object: one of SYMMETRIC or ASYMMETRIC
+            2): is_symmetric bool
+        """
+        if mapping_type is not None and is_symmetric is not None:
+            raise ValueError("Cannot set both `mapping_type` and `is_symmetric`")
+
+        # Case 0: Default to symmetric
+        if mapping_type is None and is_symmetric is None:
+            return MappingType.SYMMETRIC
+
+        # Case 1: MappingType object
+        if mapping_type is not None:
+            if mapping_type not in [MappingType.SYMMETRIC, MappingType.ASYMMETRIC]:
+                raise ValueError("MappingType '%s' is not supported" % mapping_type)
+            return mapping_type
+
+        # Case 2: is_symmetric flag
+        assert is_symmetric is not None
+        if is_symmetric:
+            return MappingType.SYMMETRIC
+        else:
+            return MappingType.ASYMMETRIC
+
+    @property
+    def group_size(self) -> int:
+        """
+        If this is per group granularity, return the group size.
+        Otherwise, throw an error.
+        """
+        if isinstance(self.granularity, PerGroup):
+            return self.granularity.group_size
+        else:
+            raise ValueError("`group_size` is undefined for %s granularity" % self.granularity)
+
+    @property
+    def is_symmetric(self) -> bool:
+        """
+        Return True if mapping type is symmetric, else False (asymmetric).
+        """
+        return self.mapping_type == MappingType.SYMMETRIC
+
+    def __setattr__(self, name: str, value: Any):
+        """
+        Support setting `group_size` and `is_symmetric`.
+        """
+        if name == "group_size":
+            super().__setattr__("granularity", PerGroup(value))
+        elif name == "is_symmetric":
+            mapping_type = MappingType.SYMMETRIC if value else MappingType.ASYMMETRIC
+            super().__setattr__("mapping_type", mapping_type)
+        else:
+            super().__setattr__(name, value)
 
 
 class ComposableQATQuantizer(TwoStepQuantizer):
```
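Given the `FakeQuantizeConfig` added above, the alternative spellings in its docstring should resolve to the same underlying fields. The sketch below is written against the code in this diff (illustrative only, assuming a PyTorch build that exposes `torch.int4`, as the docstring examples do) and exercises the `group_size` / `is_symmetric` conveniences:

```python
# Sanity-check sketch for FakeQuantizeConfig (not part of this commit).
import torch
from torchao.quantization.granularity import PerGroup, PerToken
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.quant_primitives import MappingType

# Three spellings of "int4, per-group of 32, symmetric" from the docstring
a = FakeQuantizeConfig(torch.int4, group_size=32)
b = FakeQuantizeConfig(torch.int4, "per_group", group_size=32, is_symmetric=True)
c = FakeQuantizeConfig(torch.int4, PerGroup(32), MappingType.SYMMETRIC)
assert a.granularity == b.granularity == c.granularity == PerGroup(32)
assert a.group_size == 32 and a.is_symmetric

# Per-token asymmetric activations, as in the commit message example
act = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
assert isinstance(act.granularity, PerToken) and not act.is_symmetric

# The convenience setters route through __setattr__ to the underlying fields
act.group_size = 16      # granularity becomes PerGroup(16)
act.is_symmetric = True  # mapping_type becomes MappingType.SYMMETRIC
assert act.granularity == PerGroup(16) and act.mapping_type == MappingType.SYMMETRIC
```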
