Commit bc509dc

config migration: smoothquant (#1851)

* Update [ghstack-poisoned]
* Update [ghstack-poisoned]

1 parent 2cf8fda commit bc509dc

5 files changed: 87 additions, 77 deletions

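In short, the commit migrates SmoothQuant's `quantize_` entry point from a factory function, `smooth_quant()`, to a `SmoothQuantConfig` dataclass dispatched through torchao's config-based API. A minimal before/after sketch of a call site, assuming a model whose linears have already been replaced by observers (details in the file diffs below):

```python
from torchao.prototype.smoothquant import SmoothQuantConfig, SmoothQuantObservedLinear
from torchao.quantization import quantize_


def quantize_observed(model):
    # Only quantize linears that were observed during calibration.
    is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)

    # Before this commit: quantize_(model, smooth_quant(), is_observed_linear)
    # After this commit: pass a config object; quantize_ dispatches on its type.
    quantize_(model, SmoothQuantConfig(), is_observed_linear)
```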

test/prototype/test_smoothquant.py (3 additions, 3 deletions)

```diff
@@ -5,11 +5,11 @@
 import torch
 
 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import (
@@ -85,7 +85,7 @@ def forward(self, x):
     m(data)
     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     with torch.inference_mode():
         if TORCH_VERSION_AT_LEAST_2_5:
             m = torch.compile(m, fullgraph=True)
@@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype):
 
     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     if TORCH_VERSION_AT_LEAST_2_5:
         # earlier versions are not compatible
         m = torch.compile(m, fullgraph=True)
```
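The tests exercise the migrated call in both eager and compiled modes. A condensed sketch of that pattern; the model `m` and `data` come from the test fixtures and are elided here:

```python
import torch

from torchao.utils import TORCH_VERSION_AT_LEAST_2_5  # assumed import path for the flag


def run_quantized(m, data):
    with torch.inference_mode():
        if TORCH_VERSION_AT_LEAST_2_5:
            # per the test gate, earlier PyTorch versions are not compatible
            m = torch.compile(m, fullgraph=True)
        return m(data)
```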

torchao/prototype/smoothquant/README.md (3 additions, 3 deletions)

````diff
@@ -27,7 +27,7 @@ python example.py -m MODEL_ID --device=<cuda or cpu> --quant-mode=<dynamic or static>
 ## Usage of API
 The following APIs are provided:
 - insert_smooth_quant_observer_
-- smooth_quant
+- SmoothQuantConfig
 - save_smooth_quant_recipe (advanced)
 - load_smooth_quant_recipe (advanced)
 
@@ -37,11 +37,11 @@ insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
 ```
 After insertion, run the model for calibration on a certain dataset or (advanced) load a recipe.
 
-`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
+`SmoothQuantConfig` configures applying SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
 ```python
 from torchao.prototype.smoothquant import SmoothQuantObservedLinear
 is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
+torchao.quantization.quantize_(model, SmoothQuantConfig(), is_observed_linear)
 ```
 `is_observed_linear` is a filter so that we only quantize observed linear layers.
 
````
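Read together, the README's steps compose into the end-to-end flow below. This is a sketch: the toy model and random calibration data are stand-ins for a real network and dataset, not part of the commit.

```python
import torch

from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantObservedLinear,
    insert_smooth_quant_observer_,
)
from torchao.quantization import quantize_

# Toy stand-in for a real model (illustrative only).
model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.ReLU())

# 1. Swap nn.Linear modules for observed linears that record quantization stats.
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")

# 2. Calibrate: run representative data through the observers.
for _ in range(8):
    model(torch.randn(4, 64))

# 3. Quantize only the observed linears, passing the new config object.
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, SmoothQuantConfig(), is_observed_linear)
```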

torchao/prototype/smoothquant/__init__.py (2 additions, 2 deletions)

```diff
@@ -1,15 +1,15 @@
 from .api import (
+    SmoothQuantConfig,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from .core import SmoothQuantObservedLinear
 
 __all__ = [
     "insert_smooth_quant_observer_",
     "load_smooth_quant_recipe",
     "save_smooth_quant_recipe",
-    "smooth_quant",
+    "SmoothQuantConfig",
     "SmoothQuantObservedLinear",
 ]
```
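One consequence of the `__init__.py` change for downstream code: the old name no longer exists in the package namespace.

```python
# Works after this commit:
from torchao.prototype.smoothquant import SmoothQuantConfig

# Would now raise ImportError (the old entry point was removed):
# from torchao.prototype.smoothquant import smooth_quant
```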

torchao/prototype/smoothquant/api.py (77 additions, 67 deletions)

```diff
@@ -1,20 +1,30 @@
+import types
+from dataclasses import dataclass
 from typing import Dict, Optional
 
 import torch
 
+from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
 from torchao.prototype.smoothquant.core import (
     SmoothQuantObservedLinear,
     SmoothQuantObserver,
 )
+from torchao.quantization import quantize_
 from torchao.quantization.linear_activation_quantized_tensor import (
     to_linear_activation_quantized,
 )
 from torchao.quantization.linear_activation_scale import (
     to_weight_tensor_with_linear_activation_scale_metadata,
 )
-from torchao.quantization.quant_api import _replace_with_custom_fn_if_matches_filter
+from torchao.quantization.quant_api import (
+    _linear_extra_repr,
+    _replace_with_custom_fn_if_matches_filter,
+)
 from torchao.quantization.quant_primitives import MappingType
+from torchao.quantization.transform_module import (
+    register_quantize_module_handler,
+)
 from torchao.quantization.utils import _get_per_token_block_size
 from torchao.quantization.weight_tensor_linear_activation_quantization import (
     to_weight_tensor_with_linear_activation_quantization_metadata,
@@ -53,32 +63,6 @@ def replace_with_observer(layer):
     _replace_with_custom_fn_if_matches_filter(model, replace_with_observer, _is_linear)
 
 
-def _observed_linear_subclass_inserter(constructor):
-    """
-    Replaces unquantized observed linear instances with quantized linear instances.
-
-    Args:
-        constructor: the function which applies quantization to the observed linear layer
-    """
-
-    def insert_subclass(observed_linear):
-        # creates the new linear layer using constructor
-        linear = torch.nn.Linear(
-            observed_linear.in_features,
-            observed_linear.out_features,
-            observed_linear.bias is not None,
-            device=observed_linear.weight.device,
-            dtype=observed_linear.weight.dtype,
-        )
-        linear.weight = torch.nn.Parameter(
-            constructor(observed_linear), requires_grad=False
-        )
-        linear.bias = observed_linear.bias
-        return linear
-
-    return insert_subclass
-
-
 def save_smooth_quant_recipe(
     model: torch.nn.Module, save_path: str
 ) -> Dict[str, torch.Tensor]:
@@ -121,7 +105,14 @@ def recurse(module: torch.nn.Module, name: str = ""):
             # act_scales is None for dynamic quantization
             if any(x is None for x in (smoothing_factor, wei_scales)):
                 return module
-            return smooth_quant(smoothing_factor, act_scales, wei_scales)(module)
+            is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
+            wrapper = torch.nn.Sequential(module)
+            quantize_(
+                wrapper,
+                SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
+                is_observed_linear,
+            )
+            return wrapper[0]
 
         mod_new = module
 
@@ -158,54 +149,73 @@ def static_quantize(self, input, scale, zero_point):
         )
 
 
-def smooth_quant(
-    smoothing_factor: Optional[torch.Tensor] = None,
-    act_scales: Optional[torch.Tensor] = None,
-    wei_scales: Optional[torch.Tensor] = None,
-):
+@dataclass
+class SmoothQuantConfig(AOBaseConfig):
     """
-    Quantizes linear layers when passed into quantize_()
+    Configuration for quantizing linear layers when passed into quantize_()
 
     Args:
         smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
         act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
         wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
     """
 
-    def quantize_weight(observed_linear):
-        target_dtype = torch.int8
-        # act_scales is None for dynamic quantization thus not checked
-        if any(x is None for x in (smoothing_factor, wei_scales)):
-            factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
-            weight = observed_linear.obs.weight * factor
-        else:
-            factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
-            weight = observed_linear.weight * factor
-        weight = weight.to(observed_linear.weight.dtype)
-        block_size = (1, weight.size(1))
-        wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
-        qw = to_affine_quantized_intx_static(
-            weight,
-            w_scales,
-            wei_zero_points,
-            block_size,
-            target_dtype,
-        )
+    smoothing_factor: Optional[torch.Tensor] = None
+    act_scales: Optional[torch.Tensor] = None
+    wei_scales: Optional[torch.Tensor] = None
 
-        if x_scale is None:
-            # dynamic quant
-            qw = to_linear_activation_quantized(
-                qw, _ActQuantizer(target_dtype).dynamic_quantize
-            )
-        else:
-            # static quant
-            x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
-            qw = to_weight_tensor_with_linear_activation_quantization_metadata(
-                qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
-            )
 
-        return to_weight_tensor_with_linear_activation_scale_metadata(
-            qw, factor.to(qw.dtype)
+@register_quantize_module_handler(SmoothQuantConfig)
+def _smooth_quant_transform(
+    module: torch.nn.Module,
+    config: SmoothQuantConfig,
+):
+    smoothing_factor = config.smoothing_factor
+    act_scales = config.act_scales
+    wei_scales = config.wei_scales
+    observed_linear = module
+
+    linear = torch.nn.Linear(
+        observed_linear.in_features,
+        observed_linear.out_features,
+        observed_linear.bias is not None,
+        device=observed_linear.weight.device,
+        dtype=observed_linear.weight.dtype,
+    )
+    linear.bias = observed_linear.bias
+
+    target_dtype = torch.int8
+    # act_scales is None for dynamic quantization thus not checked
+    if any(x is None for x in (smoothing_factor, wei_scales)):
+        factor, x_scale, w_scales = observed_linear.obs.calculate_qparams()
+        weight = observed_linear.obs.weight * factor
+    else:
+        factor, x_scale, w_scales = smoothing_factor, act_scales, wei_scales
+        weight = observed_linear.weight * factor
+    weight = weight.to(observed_linear.weight.dtype)
+    block_size = (1, weight.size(1))
+    wei_zero_points = torch.zeros_like(w_scales, dtype=torch.int64)
+    qw = to_affine_quantized_intx_static(
+        weight,
+        w_scales,
+        wei_zero_points,
+        block_size,
+        target_dtype,
+    )
+
+    if x_scale is None:
+        # dynamic quant
+        qw = to_linear_activation_quantized(
+            qw, _ActQuantizer(target_dtype).dynamic_quantize
+        )
+    else:
+        # static quant
+        x_zero_point = torch.zeros_like(x_scale, dtype=torch.int64)
+        qw = to_weight_tensor_with_linear_activation_quantization_metadata(
+            qw, _ActQuantizer(target_dtype).static_quantize, x_scale, x_zero_point
         )
 
-    return _observed_linear_subclass_inserter(quantize_weight)
+    qw = to_weight_tensor_with_linear_activation_scale_metadata(qw, factor.to(qw.dtype))
+    linear.weight = torch.nn.Parameter(qw, requires_grad=False)
+    linear.extra_repr = types.MethodType(_linear_extra_repr, module)
+    return linear
```
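The core mechanism in `api.py` is the pairing of an `AOBaseConfig` subclass with a transform registered through `register_quantize_module_handler`: `quantize_` looks up the handler by the config's type and applies it to each module accepted by the filter. Below is a stripped-down sketch of that registry pattern, an illustration of the idea rather than torchao's actual implementation (the simplified fqn handling in particular is an assumption):

```python
from typing import Callable, Dict, Type

import torch


class AOBaseConfig:
    """Stand-in for torchao.core.config.AOBaseConfig."""


# Registry mapping a config class to the transform that handles it.
_HANDLERS: Dict[Type[AOBaseConfig], Callable] = {}


def register_quantize_module_handler(config_cls: Type[AOBaseConfig]) -> Callable:
    """Decorator: record the decorated function as the handler for config_cls."""

    def decorator(handler: Callable) -> Callable:
        _HANDLERS[config_cls] = handler
        return handler

    return decorator


def quantize_(model: torch.nn.Module, config: AOBaseConfig, filter_fn) -> None:
    """Replace, in place, each child passing filter_fn with handler(child, config).

    Simplified: passes the local child name where torchao passes the
    fully-qualified module name.
    """
    for name, child in model.named_children():
        if filter_fn(child, name):
            setattr(model, name, _HANDLERS[type(config)](child, config))
        else:
            quantize_(child, config, filter_fn)
```

The in-place nature of `quantize_` also explains the `torch.nn.Sequential` wrapper introduced in `load_smooth_quant_recipe`: the transform replaces children of a parent module, so a bare observed linear is given a throwaway parent and the quantized result is read back as `wrapper[0]`.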

torchao/prototype/smoothquant/example.py (2 additions, 2 deletions)

```diff
@@ -9,9 +9,9 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
-    smooth_quant,
 )
 from torchao.quantization import quantize_
 
@@ -145,7 +145,7 @@ def wikitext2_ppl(
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
     print(f"running SmoothQuant with {quant_mode} quantization")
     t0 = time.time()
-    quantize_(model, smooth_quant(), is_observed_linear)
+    quantize_(model, SmoothQuantConfig(), is_observed_linear)
     print(f"time for quantization: {time.time() - t0:.02f} seconds")
     if model_save_path is not None:
         print(f"Saving quantized model to {model_save_path}")
```
