Commit 7f16377

remove the set_inductor_config argument of quantize_.
Summary:

Test Plan:

```
pytest test/quantization/test_quant_api.py -s -x -k test_workflow_e2e_numerics
```

Reviewers:

Subscribers:

Tasks:

Tags:

ghstack-source-id: e6694cc
ghstack-comment-id: 2712016215
Pull Request resolved: #1865
1 parent 53be2a4 commit 7f16377
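
This commit drops the `set_inductor_config` keyword from `quantize_` and moves it onto each workflow config, so callers update as below. A minimal before/after sketch, assuming a torchao build that includes this change; the toy model is illustrative only:

```python
import torch.nn as nn

from torchao.quantization import int8_weight_only, quantize_

# Toy model, illustrative only.
model = nn.Sequential(nn.Linear(32, 32))

# Before this commit, the flag was a keyword argument of quantize_ itself:
# quantize_(model, int8_weight_only(), set_inductor_config=False)

# After this commit, the flag lives on each workflow config:
quantize_(model, int8_weight_only(set_inductor_config=False))
```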

File tree: 6 files changed, +70 -31 lines changed

test/integration/test_integration.py

Lines changed: 8 additions & 7 deletions

```diff
@@ -113,7 +113,7 @@

 def _int8wo_api(mod):
     if TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int8_weight_only(), set_inductor_config=False)
+        quantize_(mod, int8_weight_only(set_inductor_config=False))
         if not TORCH_VERSION_AT_LEAST_2_5 or (
             not TORCH_VERSION_AT_LEAST_2_6 and torch._inductor.config.freezing
         ):
@@ -124,7 +124,7 @@ def _int8wo_api(mod):

 def _int8wo_groupwise_api(mod):
     group_size = 32
-    quantize_(mod, int8_weight_only(group_size=group_size), set_inductor_config=False)
+    quantize_(mod, int8_weight_only(group_size=group_size, set_inductor_config=False))


 def _int8da_int8w_api(
@@ -136,8 +136,8 @@ def _int8da_int8w_api(
             mod,
             int8_dynamic_activation_int8_weight(
                 act_mapping_type=act_mapping_type,
+                set_inductor_config=False,
             ),
-            set_inductor_config=False,
         )
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
@@ -152,20 +152,21 @@ def _int4wo_api(mod, use_hqq=False):
     ):
         quantize_(
             mod,
-            int4_weight_only(layout=Int4CPULayout(), use_hqq=use_hqq),
-            set_inductor_config=False,
+            int4_weight_only(
+                layout=Int4CPULayout(), use_hqq=use_hqq, set_inductor_config=False
+            ),
         )
         unwrap_tensor_subclass(mod)
     elif TORCH_VERSION_AT_LEAST_2_4:
-        quantize_(mod, int4_weight_only(), set_inductor_config=False)
+        quantize_(mod, int4_weight_only(set_inductor_config=False))
         if not TORCH_VERSION_AT_LEAST_2_5:
             unwrap_tensor_subclass(mod)
     else:
         change_linear_weights_to_int4_woqtensors(mod)


 def _int8da_int4w_api(mod):
-    quantize_(mod, int8_dynamic_activation_int4_weight(), set_inductor_config=False)
+    quantize_(mod, int8_dynamic_activation_int4_weight(set_inductor_config=False))
     if not TORCH_VERSION_AT_LEAST_2_5:
         unwrap_tensor_subclass(mod)

```

test/prototype/test_quantized_training.py

Lines changed: 5 additions & 10 deletions

```diff
@@ -46,7 +46,6 @@ def _reset():
     torch._dynamo.reset()


-# we always use `quantize_(set_inductor_config=False)` to reduce compile time in CI.
 class TestQuantizedTraining(TestCase):
     @parametrize("device", _DEVICES)
     def test_int8_stochastic_rounding(self, device):
@@ -81,7 +80,6 @@ def test_int8_weight_only_correctness(self, leading_dims, bias, device):
         quantize_(
             linear_int8,
             int8_weight_only_quantized_training(),
-            set_inductor_config=False,
         )
         linear_fp32.weight.data = linear_int8.weight.data.dequantize()

@@ -108,7 +106,6 @@ def test_int8_weight_only_compile(self, leading_dims, bias, device):
         quantize_(
             linear_eager,
             int8_weight_only_quantized_training(),
-            set_inductor_config=False,
         )
         linear_compiled = copy.deepcopy(linear_eager)
         linear_compiled.compile()
@@ -145,9 +142,7 @@ def test_int8_weight_only_training(self, compile, device):
             nn.Linear(embed_dim * 2, n_classes),
         ).to(device)
         model_int8 = copy.deepcopy(model_fp32)
-        quantize_(
-            model_int8, int8_weight_only_quantized_training(), set_inductor_config=False
-        )
+        quantize_(model_int8, int8_weight_only_quantized_training())

         if compile:
             model_fp32.compile()
@@ -195,7 +190,7 @@ def test_int8_mixed_precision_training(self, compile, config, module_swap):
         linear_int8mp = copy.deepcopy(linear)
         config.module_swap = module_swap
         apply_func = int8_mixed_precision_training(config)
-        quantize_(linear_int8mp, apply_func, set_inductor_config=False)
+        quantize_(linear_int8mp, apply_func)

         if compile:
             linear.compile()
@@ -255,7 +250,7 @@ def forward(self, x):
             nn.Linear(embed_dim, embed_dim),
         ).to(device)
         model = copy.deepcopy(model_ref)
-        quantize_(model, bitnet_training(), set_inductor_config=False)
+        quantize_(model, bitnet_training())

         # change model_ref to use BitLinear
         model_ref[0].__class__ = BitLinear
@@ -346,8 +341,8 @@ def _run_subtest(self, args):
         base_model = Transformer(model_args).cuda()
         fsdp_model = copy.deepcopy(base_model)

-        quantize_(base_model.layers, quantize_fn, set_inductor_config=False)
-        quantize_(fsdp_model.layers, quantize_fn, set_inductor_config=False)
+        quantize_(base_model.layers, quantize_fn)
+        quantize_(fsdp_model.layers, quantize_fn)

         for layer in fsdp_model.layers:
             fully_shard(layer, mp_policy=mp_policy)
```

torchao/prototype/awq/api.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -3,6 +3,7 @@

 import torch

+import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import (
     TensorCoreTiledLayout,
@@ -101,11 +102,13 @@ class AWQUIntXConfig(AOBaseConfig):
         quant_dtype: The data type of the quantized weights. Currently only torch.uint4 is intended to be used but can be used with torch.uint1 -> torch.uint8
         group_size: Quantization granularity. Use -1 for channel wise quantization
         weight_quant_fn: The quantization function to be used, which takes in the weight and returns the quantized weight. If None, then affine uint4 quantization is used
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
     """

     quant_dtype: torch.dtype = torch.uint4
     group_size: int = 64
     use_hqq: bool = False
+    set_inductor_config: bool = True


 # for bc
@@ -120,6 +123,8 @@ def _awq_uintx_transform(
     quant_dtype = config.quant_dtype
     group_size = config.group_size
     use_hqq = config.use_hqq
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
     observed_linear = module

     assert (
```
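
This file and the two prototype files below follow the same three-part pattern as this hunk: a new docstring entry, a new dataclass field defaulting to True, and a check at the top of the transform handler. A condensed sketch of that pattern; `MyConfig` and `_my_transform` are hypothetical stand-ins, not names from the diff:

```python
from dataclasses import dataclass

import torch
import torchao
from torchao.core.config import AOBaseConfig


@dataclass
class MyConfig(AOBaseConfig):
    """
    Args:
        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
    """

    set_inductor_config: bool = True


def _my_transform(module: torch.nn.Module, config: MyConfig) -> torch.nn.Module:
    # After this commit the flag is consumed inside the per-workflow
    # handler rather than inside quantize_ itself.
    if config.set_inductor_config:
        torchao.quantization.utils.recommended_inductor_config_setter()
    # ... quantize `module` according to `config` ...
    return module
```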

torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -2,6 +2,7 @@

 import torch

+import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.quantization.quant_primitives import (
     MappingType,
@@ -18,6 +19,7 @@ class IntNWeightOnlyConfig(AOBaseConfig):
     Args:
         `group_size`: parameter for quantization, controls the granularity of quantization, smaller size is more fine grained, choices are [512, 256, 128, 64, 32]
         `n`: number of bits to quantize to, choices are [8, 6, 5, 4, 3, 2]
+        `set_inductor_config`: if True, adjusts `torchinductor` settings to recommended values.
     Usage:
         from torchao.quantization import quantize_
         quantize_(model, intN_weight_only(n=your_bit_choice, group_size=group_size), optional_filter_func_for_desired_layers_to_quantize)
@@ -26,6 +28,7 @@ class IntNWeightOnlyConfig(AOBaseConfig):
     group_size: int = 32
     n: int = 8
     symmetric: bool = False
+    set_inductor_config: bool = True


 # for bc
@@ -41,6 +44,8 @@ def _intN_weight_only_transform(
     n = config.n
     symmetric = config.symmetric
     weight = module.weight
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()

     # for asymmetric quantization
     def apply_intN_weight_only_quant_asym(weight):
```

torchao/prototype/smoothquant/api.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -4,6 +4,7 @@

 import torch

+import torchao
 from torchao.core.config import AOBaseConfig
 from torchao.dtypes import to_affine_quantized_intx, to_affine_quantized_intx_static
 from torchao.prototype.smoothquant.core import (
@@ -158,11 +159,13 @@ class SmoothQuantConfig(AOBaseConfig):
         smoothing_factor: The smoothing factor for the layer. Acquired from the layer's observer if None.
         act_scales: The activation scales for the layer. Acquired from the layer's observer if None.
         wei_scales: The weight scales for the layer. Acquired from the layer's observer if None.
+        set_inductor_config: if True, adjusts `torchinductor` settings to recommended values.
     """

     smoothing_factor: Optional[torch.Tensor] = None
     act_scales: Optional[torch.Tensor] = None
     wei_scales: Optional[torch.Tensor] = None
+    set_inductor_config: bool = True


 @register_quantize_module_handler(SmoothQuantConfig)
@@ -173,6 +176,8 @@ def _smooth_quant_transform(
     smoothing_factor = config.smoothing_factor
     act_scales = config.act_scales
     wei_scales = config.wei_scales
+    if config.set_inductor_config:
+        torchao.quantization.utils.recommended_inductor_config_setter()
     observed_linear = module

     linear = torch.nn.Linear(
```
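
Callers that previously opted out at the quantize_ call site now do so on the config instance. A hedged usage sketch, assuming `SmoothQuantConfig` is imported from the module shown in the diff above:

```python
from torchao.prototype.smoothquant.api import SmoothQuantConfig

# The observer-derived arguments keep their None defaults from the diff;
# only the new flag is set here.
config = SmoothQuantConfig(set_inductor_config=False)
assert config.set_inductor_config is False

# Passing this config to quantize_ now skips
# recommended_inductor_config_setter() inside _smooth_quant_transform.
```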
