
Commit 6f3d127

Update
[ghstack-poisoned]
1 parent 0ecb02d commit 6f3d127

File tree: 5 files changed (+11, −20 lines)

test/prototype/test_smoothquant.py

Lines changed: 3 additions & 3 deletions

@@ -5,11 +5,11 @@
 import torch
 
 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from torchao.quantization import quantize_
 from torchao.quantization.utils import (
@@ -85,7 +85,7 @@ def forward(self, x):
     m(data)
     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     with torch.inference_mode():
         if TORCH_VERSION_AT_LEAST_2_5:
             m = torch.compile(m, fullgraph=True)
@@ -173,7 +173,7 @@ def test_save_load_recipe(alpha, quant_mode, device, idtype):
 
     # quantize
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-    quantize_(m, smooth_quant(), is_observed_linear)
+    quantize_(m, SmoothQuantConfig(), is_observed_linear)
     if TORCH_VERSION_AT_LEAST_2_5:
         # earlier versions are not compatible
         m = torch.compile(m, fullgraph=True)
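Because the commit also deletes the `smooth_quant = SmoothQuantConfig` backward-compatibility alias in `api.py` (see that file's diff below), every call site follows the same mechanical migration as these tests. A minimal sketch, assuming a module `m` that has already been calibrated through `insert_smooth_quant_observer_`:

```python
from torchao.prototype.smoothquant import SmoothQuantConfig, SmoothQuantObservedLinear
from torchao.quantization import quantize_

# Only replace the linears that the observer pass wrapped.
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)

# Before this commit: quantize_(m, smooth_quant(), is_observed_linear)
quantize_(m, SmoothQuantConfig(), is_observed_linear)
```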

torchao/prototype/smoothquant/README.md

Lines changed: 3 additions & 3 deletions

@@ -27,7 +27,7 @@ python example.py -m MODLE_ID --device=<cuda or cpu> --quant-mode=<dynamic or st
 ## Usage of API
 The following APIs are provided:
 - insert_smooth_quant_observer_
-- smooth_quant
+- SmoothQuantConfig
 - save_smooth_quant_recipe (advanced)
 - load_smooth_quant_recipe (advanced)
 
@@ -37,11 +37,11 @@ insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")
 ```
 After insertion, run the model for calibration on a certain dataset or (advanced) load a recipe.
 
-`smooth_quant` applies SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
+`SmoothQuantConfig` configures applying SmoothQuant to each linear layer of the model. Use it by calling `torchao.quantization.quantize_`. For example:
 ```python
 from torchao.prototype.smoothquant import SmoothQuantObservedLinear
 is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
-torchao.quantization.quantize_(model, smooth_quant(), is_observed_linear)
+torchao.quantization.quantize_(model, SmoothQuantConfig(), is_observed_linear)
 ```
 `is_observed_linear` is a filter so that we only quantize observed linear layers.
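Read end to end, the README's flow with the renamed config looks like the sketch below. The toy two-layer `model` and the random calibration batches are hypothetical stand-ins; `alpha=0.5` and `quant_mode="dynamic"` are taken from the README snippet in the hunk above:

```python
import torch

from torchao.prototype.smoothquant import (
    SmoothQuantConfig,
    SmoothQuantObservedLinear,
    insert_smooth_quant_observer_,
)
from torchao.quantization import quantize_

model = torch.nn.Sequential(torch.nn.Linear(64, 64), torch.nn.Linear(64, 64))

# 1. Swap nn.Linear for observed linears that record activation statistics.
insert_smooth_quant_observer_(model, alpha=0.5, quant_mode="dynamic")

# 2. Calibrate: run representative data through the model.
for _ in range(8):
    model(torch.randn(16, 64))

# 3. Quantize only the observed linears, using the new config class.
is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
quantize_(model, SmoothQuantConfig(), is_observed_linear)
```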

torchao/prototype/smoothquant/__init__.py

Lines changed: 2 additions & 2 deletions

@@ -1,15 +1,15 @@
 from .api import (
+    SmoothQuantConfig,
     insert_smooth_quant_observer_,
     load_smooth_quant_recipe,
     save_smooth_quant_recipe,
-    smooth_quant,
 )
 from .core import SmoothQuantObservedLinear
 
 __all__ = [
     "insert_smooth_quant_observer_",
     "load_smooth_quant_recipe",
     "save_smooth_quant_recipe",
-    "smooth_quant",
+    "SmoothQuantConfig",
     "SmoothQuantObservedLinear",
 ]
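A quick way for downstream code to verify the rename is that the new export imports while the removed alias raises; a small sketch:

```python
from torchao.prototype.smoothquant import SmoothQuantConfig  # new public name

try:
    from torchao.prototype.smoothquant import smooth_quant  # removed by this commit
except ImportError:
    print("smooth_quant alias is gone; use SmoothQuantConfig")
```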

torchao/prototype/smoothquant/api.py

Lines changed: 1 addition & 10 deletions

@@ -109,7 +109,7 @@ def recurse(module: torch.nn.Module, name: str = ""):
         wrapper = torch.nn.Sequential(module)
         quantize_(
             wrapper,
-            smooth_quant(smoothing_factor, act_scales, wei_scales),
+            SmoothQuantConfig(smoothing_factor, act_scales, wei_scales),
             is_observed_linear,
         )
         return wrapper[0]
@@ -165,10 +165,6 @@ class SmoothQuantConfig(AOBaseConfig):
     wei_scales: Optional[torch.Tensor] = None
 
 
-# for bc
-smooth_quant = SmoothQuantConfig
-
-
 @register_quantize_module_handler(SmoothQuantConfig)
 def _smooth_quant_transform(
     module: torch.nn.Module,
@@ -177,7 +173,6 @@ def _smooth_quant_transform(
     smoothing_factor = config.smoothing_factor
     act_scales = config.act_scales
     wei_scales = config.wei_scales
-    # weight = module.weight
     observed_linear = module
 
     linear = torch.nn.Linear(
@@ -187,11 +182,7 @@ def _smooth_quant_transform(
         device=observed_linear.weight.device,
         dtype=observed_linear.weight.dtype,
     )
-    # linear.weight = torch.nn.Parameter(
-    #     constructor(observed_linear), requires_grad=False
-    # )
     linear.bias = observed_linear.bias
-    # return linear
 
     target_dtype = torch.int8
     # act_scales is None for dynamic quantization thus not checked
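The first hunk shows how `load_smooth_quant_recipe`'s `recurse` helper now constructs the config with explicit scales. For the advanced static path, building the config directly looks roughly like this; the field names and their positional order come from the hunks above, while the scale tensors and their shapes are illustrative placeholders:

```python
import torch

from torchao.prototype.smoothquant import SmoothQuantConfig

# Hypothetical precomputed scales, e.g. restored from a recipe saved with
# save_smooth_quant_recipe; shapes are placeholders for a Linear(64, 64).
act_scales = torch.ones(64)
wei_scales = torch.ones(64)

# smoothing_factor is the SmoothQuant alpha; 0.5 matches the README example.
config = SmoothQuantConfig(
    smoothing_factor=0.5,
    act_scales=act_scales,
    wei_scales=wei_scales,
)
```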

torchao/prototype/smoothquant/example.py

Lines changed: 2 additions & 2 deletions

@@ -9,9 +9,9 @@
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 from torchao.prototype.smoothquant import (
+    SmoothQuantConfig,
     SmoothQuantObservedLinear,
     insert_smooth_quant_observer_,
-    smooth_quant,
 )
 from torchao.quantization import quantize_
 
@@ -145,7 +145,7 @@ def wikitext2_ppl(
     is_observed_linear = lambda m, fqn: isinstance(m, SmoothQuantObservedLinear)
     print(f"running SmoothQuant with {quant_mode} quantization")
     t0 = time.time()
-    quantize_(model, smooth_quant(), is_observed_linear)
+    quantize_(model, SmoothQuantConfig(), is_observed_linear)
     print(f"time for quantization: {time.time() - t0:.02f} seconds")
     if model_save_path is not None:
         print(f"Saving quantized model to {model_save_path}")
