
Commit 0f6bae5

Move and rename GranularityType -> Granularity (#1038)
* Make module swap the main QAT flow again

  Summary: Following #987, this commit makes module swap the main QAT flow today. We remove all tensor subclass fake quantize injection logic, since it is not needed in either the short-term or the long-term plan for QAT. In the short term we will continue to use a full module swap flow, and we will only migrate to the long-term flow once there is general distributed support for tensor subclasses and tensor subclass composability provides meaningful benefits.

  Test Plan: python test/quantization/test_qat.py

* Move and rename GranularityType -> Granularity

  Summary: Move GranularityType to quant_primitives.py to be consistent with other similar fields like MappingType and ZeroPointDomain.

  Test Plan: CI
1 parent 107e378 commit 0f6bae5

15 files changed: +143 -111 lines
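At a glance, the downstream migration is a one-line import change. A minimal sketch based on the diffs below, assuming torch 2.4+ (the toy model is illustrative, not from this commit):

```python
import torch

# Old import path, removed by this commit:
#   from torchao.quantization.observer import PerRow, PerTensor
# New import path:
from torchao.quantization.granularity import PerRow, PerTensor
from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight

# Illustrative model; any module containing nn.Linear layers works the same way.
model = torch.nn.Sequential(torch.nn.Linear(128, 128))

# The granularity objects themselves are unchanged; only their import path moved.
quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
# or: granularity=PerRow() for row-wise float8 scaling
```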

test/dtypes/test_affine_quantized_float.py

Lines changed: 8 additions & 2 deletions
@@ -26,11 +26,17 @@
     float8_weight_only,
     quantize_,
 )
-from torchao.quantization.observer import PerRow, PerTensor
+from torchao.quantization.granularity import (
+    PerRow,
+    PerTensor,
+)
 from torchao.quantization.quant_api import (
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.quant_primitives import MappingType, choose_qparams_affine
+from torchao.quantization.quant_primitives import (
+    MappingType,
+    choose_qparams_affine,
+)
 
 random.seed(0)
 torch.manual_seed(0)

test/quantization/test_observer.py

Lines changed: 12 additions & 10 deletions
@@ -9,11 +9,13 @@
 from torch.testing._internal import common_utils
 from torch.testing._internal.common_utils import TestCase
 
-from torchao.quantization.observer import (
-    AffineQuantizedMinMaxObserver,
+from torchao.quantization.granularity import (
     PerAxis,
     PerTensor,
 )
+from torchao.quantization.observer import (
+    AffineQuantizedMinMaxObserver,
+)
 from torchao.quantization.quant_api import (
     insert_observers_,
 )
@@ -42,7 +44,7 @@ def test_min_max_per_tensor_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
-           granularity_type=PerTensor(),
+           granularity=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -54,7 +56,7 @@ def test_min_max_per_channel_affine(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.ASYMMETRIC,
            torch.uint8,
-           granularity_type=PerAxis(axis=0),
+           granularity=PerAxis(axis=0),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -68,7 +70,7 @@ def test_block_size_calc_success(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerTensor(),
+           granularity=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -87,7 +89,7 @@ def test_block_size_calc_success(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerAxis(1),
+           granularity=PerAxis(1),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -102,7 +104,7 @@ def test_block_size_row_errors(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerAxis(0),
+           granularity=PerAxis(0),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -121,7 +123,7 @@ def test_block_size_row_errors(self):
        obs = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerAxis(1),
+           granularity=PerAxis(1),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -149,7 +151,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
        input_observer = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerTensor(),
+           granularity=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
@@ -159,7 +161,7 @@ def test_linear_observer_tensor(self, observe_weight: bool):
        weight_observer = AffineQuantizedMinMaxObserver(
            MappingType.SYMMETRIC,
            torch.float8_e4m3fn,
-           granularity_type=PerTensor(),
+           granularity=PerTensor(),
            eps=torch.finfo(torch.float32).eps,
            scale_dtype=torch.float,
            zero_point_dtype=torch.int,
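Beyond the import split, the only API change in this test file is the keyword rename granularity_type -> granularity on AffineQuantizedMinMaxObserver. A minimal sketch of the new spelling, assuming the observe-then-calculate_qparams flow these tests exercise:

```python
import torch
from torchao.quantization.granularity import PerTensor
from torchao.quantization.observer import AffineQuantizedMinMaxObserver
from torchao.quantization.quant_primitives import MappingType

# Same constructor as in the tests above, with the renamed keyword.
obs = AffineQuantizedMinMaxObserver(
    MappingType.ASYMMETRIC,
    torch.uint8,
    granularity=PerTensor(),  # was: granularity_type=PerTensor()
    eps=torch.finfo(torch.float32).eps,
    scale_dtype=torch.float,
    zero_point_dtype=torch.int,
)

# Observe a tensor, then compute scale/zero_point (the flow used by these tests).
obs(torch.randn(8, 16))
scale, zero_point = obs.calculate_qparams()
```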

torchao/_models/llama/eval.py

Lines changed: 2 additions & 2 deletions
@@ -24,9 +24,9 @@
     float8_dynamic_activation_float8_weight,
     float8_static_activation_float8_weight,
 )
-from torchao.quantization.observer import PerRow, PerTensor
 from torchao._models._eval import TransformerEvalWrapper, InputRecorder
 from torchao._models.llama.model import prepare_inputs_for_model
+from torchao.quantization.granularity import PerRow, PerTensor
 
 from tokenizer import get_tokenizer
 import time
@@ -255,4 +255,4 @@ def run_evaluation(
        args.calibration_limit,
        args.calibration_seq_length,
        args.pad_calibration_inputs,
-    )
\ No newline at end of file
+    )

torchao/_models/llama/generate.py

Lines changed: 1 addition & 1 deletion
@@ -216,7 +216,7 @@ def main(
        float8_weight_only,
        float8_dynamic_activation_float8_weight,
    )
-   from torchao.quantization.observer import PerTensor, PerRow
+   from torchao.quantization.granularity import PerTensor, PerRow
    if "int8wo" in quantization:
        quantize_(model, int8_weight_only())
    if "int8dq" in quantization:

torchao/prototype/awq/api.py

Lines changed: 1 addition & 1 deletion
@@ -1,13 +1,13 @@
 import torch
 import torch.nn.functional as F
 
+from torchao.quantization.granularity import PerGroup
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
     _DTYPE_TO_QVALUE_BOUNDS,
 )
 from torchao.quantization import to_weight_tensor_with_linear_activation_scale_metadata
-from torchao.quantization.observer import PerGroup
 from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import(

torchao/prototype/awq/core.py

Lines changed: 5 additions & 4 deletions
@@ -7,20 +7,21 @@
 from torch.utils._python_dispatch import return_and_correct_aliasing
 from torchao.dtypes.uintx import _DTYPE_TO_BIT_WIDTH, UintxLayoutType
 from torchao.dtypes import to_affine_quantized_intx
+from torchao.quantization.granularity import Granularity
 from torchao.quantization.quant_primitives import (
     MappingType,
     ZeroPointDomain,
 )
 from torchao.quantization.observer import (
-    AffineQuantizedObserverBase, GranularityType
+    AffineQuantizedObserverBase,
 )
 
 
 class AWQObserver(AffineQuantizedObserverBase):
     def __init__(self,
        weight: torch.Tensor,
        bias: torch.Tensor,
-       quantization_granularity: GranularityType,
+       quantization_granularity: Granularity,
        mapping_type: MappingType,
        target_dtype: torch.dtype,
        n_validation_examples: int,
@@ -40,7 +41,7 @@ def __init__(self,
        Args:
            weight: The weight tensor to be observed.
            bias: The bias tensor to be observed.
-           quantization_granularity: Granularity type which specifies how many weights share the same scale/zero point
+           quantization_granularity: Granularity which specifies how many weights share the same scale/zero point
            input_dtype: The data type of the input tensor.
            mapping_type: Always set to asymmetric
            target_dtype: The target data type of the quantized tensor
@@ -153,4 +154,4 @@ def from_float(cls, float_linear: torch.nn.Linear, act_obs: AWQObserver):
        observed_linear = cls(float_linear.in_features, float_linear.out_features, act_obs, False, device=float_linear.weight.device, dtype=float_linear.weight.dtype)
        observed_linear.weight = float_linear.weight
        observed_linear.bias = float_linear.bias
-       return observed_linear
\ No newline at end of file
+       return observed_linear
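For downstream code that annotated against the old name, the rename is mechanical: GranularityType becomes Granularity, with the same subclasses. A sketch under that assumption (describe_granularity is a hypothetical helper, not part of torchao):

```python
from torchao.quantization.granularity import Granularity, PerGroup

# Hypothetical helper: annotations simply track the rename from GranularityType.
def describe_granularity(quantization_granularity: Granularity) -> str:
    # was: quantization_granularity: GranularityType
    return type(quantization_granularity).__name__

print(describe_granularity(PerGroup(group_size=128)))  # PerGroup
```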

torchao/quantization/README.md

Lines changed: 1 addition & 1 deletion
@@ -137,7 +137,7 @@ change_linear_weights_to_int8_dqtensors(model)
 ```python
 # for torch 2.4+
 from torchao.quantization import quantize_, float8_dynamic_activation_float8_weight
-from torchao.quantization.observer import PerTensor
+from torchao.quantization.quant_api import PerTensor
 quantize_(model, float8_dynamic_activation_float8_weight(granularity=PerTensor()))
 ```
 
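Note that the README now imports PerTensor from quant_api rather than from the new granularity module, which presumes quant_api re-exports it. Both spellings below appear in this commit's diffs:

```python
# As written in the updated README (presumably a re-export via quant_api):
from torchao.quantization.quant_api import PerTensor
# The class's new home, used by the other files in this commit:
from torchao.quantization.granularity import PerTensor
```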

torchao/quantization/autoquant.py

Lines changed: 5 additions & 3 deletions
@@ -12,12 +12,14 @@
 from torchao.dtypes import AffineQuantizedTensor, PlainLayoutType, TensorCoreTiledLayoutType, Float8LayoutType
 from torchao.quantization.linear_activation_quantized_tensor import LinearActivationQuantizedTensor
 from torch.utils._python_dispatch import return_and_correct_aliasing
-from .quant_primitives import (
-    safe_int_mm,
+from .granularity import (
+    PerAxis,
+    PerRow,
+    PerTensor,
 )
+from .quant_primitives import safe_int_mm
 from torchao.utils import TORCH_VERSION_AT_LEAST_2_3, TORCH_VERSION_AT_LEAST_2_5
 from torchao.quantization.utils import quantize_activation_per_token_absmax
-from torchao.quantization.observer import PerAxis, PerTensor, PerRow
 from torchao.float8.inference import Float8MMConfig
 
 import torch.nn.functional as F

torchao/quantization/granularity.py

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from dataclasses import dataclass
+
+
+@dataclass(frozen=True)
+class Granularity:
+    """
+    Base class for representing the granularity of quantization.
+
+    This class serves as a parent for specific granularity types used in
+    quantization operations, such as per-tensor or per-axis quantization.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerTensor(Granularity):
+    """
+    Represents per-tensor granularity in quantization.
+
+    This granularity type calculates the quantization parameters
+    based on the entire tensor.
+    """
+    pass
+
+@dataclass(frozen=True)
+class PerAxis(Granularity):
+    """
+    Represents per-axis granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    along a specified axis of the tensor.
+
+    For example, if the input tensor has shape [8, 16] and axis=0, then
+    the quantization parameters are calculated for each row of the tensor,
+    giving a total of 8 quantization parameters.
+
+
+    Attributes:
+        axis (int): The axis along which reduction is performed.
+    """
+    axis: int
+
+@dataclass(frozen=True)
+class PerGroup(Granularity):
+    """
+    Represents per-channel group granularity in quantization.
+
+    This granularity type calculates different quantization parameters
+    for each group of <group_size> elements.
+
+    For example, if the input tensor has shape [8, 16] and the group size
+    is 4, the tensor is reshaped to [64, 4] and quantization parameters
+    are calculated for each group of 4 elements, giving a total of 64
+    quantization parameters.
+
+    Attributes:
+        group_size (int): The size of each quantization group.
+
+    """
+    group_size: int
+
+@dataclass(frozen=True)
+class PerRow(Granularity):
+    """
+    Represents row-wise granularity in quantization.
+
+    This is a special case of per-axis quantization, unique to Float8 matmuls,
+    where the input is quantized with a block_size of (1, ..., input.shape[-1])
+    and the weight is quantized with a block_size of (1, weight.shape[1]).
+    """
+    pass
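The docstrings above describe each granularity in terms of how many quantization parameters a tensor gets. A small sketch making that concrete for the [8, 16] examples, using a hypothetical block_size_for helper (torchao derives block sizes internally; this helper is written here only for illustration):

```python
from torchao.quantization.granularity import (
    Granularity,
    PerAxis,
    PerGroup,
    PerRow,
    PerTensor,
)

def block_size_for(shape: tuple, granularity: Granularity) -> tuple:
    """Hypothetical helper: the block of elements sharing one scale/zero_point."""
    if isinstance(granularity, PerTensor):
        return tuple(shape)  # one block covers the whole tensor: 1 set of qparams
    if isinstance(granularity, PerAxis):
        block = list(shape)
        block[granularity.axis] = 1  # one block per slice along the given axis
        return tuple(block)
    if isinstance(granularity, PerRow):
        return (1,) * (len(shape) - 1) + (shape[-1],)  # (1, ..., shape[-1])
    if isinstance(granularity, PerGroup):
        # Applies to the tensor viewed as [-1, group_size], per the docstring.
        return (1, granularity.group_size)
    raise ValueError(f"unsupported granularity: {granularity}")

shape = (8, 16)
print(block_size_for(shape, PerTensor()))            # (8, 16) -> 1 set of qparams
print(block_size_for(shape, PerAxis(axis=0)))        # (1, 16) -> 8 sets, one per row
print(block_size_for(shape, PerRow()))               # (1, 16)
print(block_size_for(shape, PerGroup(group_size=4))) # (1, 4) -> 64 sets over [64, 4]
```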
