Commit 90502df

update

1 parent 06b211c  commit 90502df

2 files changed: +24 -12 lines changed

torchao/quantization/__init__.py  (3 additions, 0 deletions)

@@ -69,6 +69,7 @@
     float8_static_activation_float8_weight,
     float8_weight_only,
     fpx_weight_only,
+    fqn_matches_fqn_config,
     gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -221,4 +222,6 @@
     "Int4WeightOnlyGPTQQuantizer",
     "MultiTensor",
     "MultiTensorInputRecorder",
+    # helper functions
+    "fqn_matches_fqn_config",
 ]
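This commit promotes `fqn_matches_fqn_config` to a public export. Below is a minimal usage sketch, not part of the commit: it assumes `FqnToConfig` is importable from `torchao.quantization.quant_api` (where the diff below edits it) and that it is a dataclass taking its mapping through the `fqn_to_config` field referenced in `quantize_`; the FQN strings are made up.

# Illustrative sketch only; assumptions noted in the lead-in above.
from torchao.quantization import Int4WeightOnlyConfig, fqn_matches_fqn_config
from torchao.quantization.quant_api import FqnToConfig

config = FqnToConfig(
    fqn_to_config={
        # hypothetical module FQN mapped to a real torchao config
        "layers.0.self_attn.q_proj": Int4WeightOnlyConfig(),
    }
)

# The newly exported helper answers: does this FQN match an entry in the config?
print(fqn_matches_fqn_config("layers.0.self_attn.q_proj", config))  # expected True
print(fqn_matches_fqn_config("layers.0.mlp.up_proj", config))       # expected False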

torchao/quantization/quant_api.py  (21 additions, 12 deletions)

@@ -160,6 +160,7 @@
     "Int8DynActInt4WeightQuantizer",
     "Float8DynamicActivationFloat8SemiSparseWeightConfig",
     "ModuleFqnToConfig",
+    "FqnToConfig",
 ]

 LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -479,7 +480,7 @@ def quantize_(

     for module_fqn, module in model.named_modules():
         if (
-            _fqn_matches_fqn_config(module_fqn, config)
+            fqn_matches_fqn_config(module_fqn, config)
             or _module_param_matches_fqn_config(module, module_fqn, config)
             or ("_default" in config.fqn_to_config and _is_linear(module))
         ):
@@ -1254,17 +1255,22 @@ def _int4_weight_only_quantize_tensor(weight, config):

 @register_quantize_module_handler(Int4WeightOnlyConfig)
 def _int4_weight_only_transform(
-    module: torch.nn.Module, config: Int4WeightOnlyConfig
+    module: torch.nn.Module,
+    config: Int4WeightOnlyConfig,
+    *,
+    parameter_name: str = "weight",
 ) -> torch.nn.Module:
     if config.set_inductor_config:
         torchao.quantization.utils.recommended_inductor_config_setter()

-    assert hasattr(module, "weight"), (
-        "applying int8 weight only quant requires module to have weight attribute"
+    assert hasattr(module, parameter_name), (
+        "applying int8 weight only quant requires module to have {parameter_name} attribute"
         + " but {module} does not have one"
     )
-    new_weight = _int4_weight_only_quantize_tensor(module.weight, config)
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    new_weight = _int4_weight_only_quantize_tensor(
+        getattr(module, parameter_name), config
+    )
+    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))
     module.extra_repr = types.MethodType(_linear_extra_repr, module)
     return module

@@ -2298,18 +2304,19 @@ def _intx_weight_only_transform(
     *,
     custom_scale: Optional[torch.Tensor] = None,
     custom_zero_point: Optional[torch.Tensor] = None,
+    parameter_name="weight",
 ) -> torch.nn.Module:
-    assert hasattr(module, "weight"), (
-        "applying intx weight only quant requires module to have weight attribute"
+    assert hasattr(module, parameter_name), (
+        "applying intx weight only quant requires module to have {parameter_name} attribute"
         + " but {module} does not have one"
     )
     new_weight = _intx_weight_only_quantize_tensor(
-        module.weight,
+        getattr(module, parameter_name),
         config,
         custom_scale=custom_scale,
         custom_zero_point=custom_zero_point,
     )
-    module.weight = torch.nn.Parameter(new_weight, requires_grad=False)
+    setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False))

     if isinstance(module, nn.Linear):
         module.extra_repr = types.MethodType(_linear_extra_repr, module)
@@ -2446,6 +2453,8 @@ def __post_init__(self):
            Float8DynamicActivationFloat8WeightConfig,
            Float8WeightOnlyConfig,
            Int8WeightOnlyConfig,
+           Int4WeightOnlyConfig,
+           IntxWeightOnlyConfig,
        }


@@ -2541,7 +2550,7 @@ def _fqn_to_config_handler(
     return module


-def _fqn_matches_fqn_config(
+def fqn_matches_fqn_config(
     fqn: str,
     config: FqnToConfig,
 ):
@@ -2586,7 +2595,7 @@ def _module_param_matches_fqn_config(
     for name, param in module.named_parameters():
         if name in dir(module):
             parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name
-            if _fqn_matches_fqn_config(parameter_fqn, config):
+            if fqn_matches_fqn_config(parameter_fqn, config):
                 return True

     return False
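Taken together, the changes let an `FqnToConfig` entry target a bare parameter rather than only a module's `weight`: `quantize_` matches parameter FQNs through `_module_param_matches_fqn_config`, and the int4/intx transforms now accept a `parameter_name`. The sketch below shows what that flow could look like; the module class and parameter names are hypothetical, `quantize_`, `Int4WeightOnlyConfig`, and `FqnToConfig` are the APIs shown in this diff, and whether a given shape, dtype, and backend actually supports int4 packing is outside the scope of this sketch.

# Illustrative sketch only; see assumptions in the lead-in above.
import torch
from torch import nn
from torchao.quantization import Int4WeightOnlyConfig, quantize_
from torchao.quantization.quant_api import FqnToConfig

class Expert(nn.Module):
    def __init__(self):
        super().__init__()
        # A plain Parameter (not an nn.Linear weight) that we want quantized.
        self.gate_proj = nn.Parameter(torch.randn(256, 256))

    def forward(self, x):
        return x @ self.gate_proj

model = nn.Sequential(Expert())

# Key the config by the parameter's FQN ("0.gate_proj"); with this commit,
# Int4WeightOnlyConfig and IntxWeightOnlyConfig are accepted for such entries,
# and the transform is applied to that attribute via parameter_name.
quantize_(model, FqnToConfig(fqn_to_config={"0.gate_proj": Int4WeightOnlyConfig()}))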
