|
160 | 160 | "Int8DynActInt4WeightQuantizer", |
161 | 161 | "Float8DynamicActivationFloat8SemiSparseWeightConfig", |
162 | 162 | "ModuleFqnToConfig", |
| 163 | + "FqnToConfig", |
163 | 164 | ] |
164 | 165 |
|
165 | 166 | LAYOUT_TO_ZERO_POINT_DOMAIN = { |
@@ -479,7 +480,7 @@ def quantize_( |
479 | 480 |
|
480 | 481 | for module_fqn, module in model.named_modules(): |
481 | 482 | if ( |
482 | | - _fqn_matches_fqn_config(module_fqn, config) |
| 483 | + fqn_matches_fqn_config(module_fqn, config) |
483 | 484 | or _module_param_matches_fqn_config(module, module_fqn, config) |
484 | 485 | or ("_default" in config.fqn_to_config and _is_linear(module)) |
485 | 486 | ): |
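
The loop above dispatches on module FQNs, parameter FQNs, and a `_default` fallback for `nn.Linear`. A minimal standalone sketch of that dispatch pattern (illustrative names only, not the torchao internals):

```python
from torch import nn


def apply_by_fqn(model: nn.Module, fqn_to_handler: dict) -> None:
    """Apply a handler when a module FQN, one of its parameter FQNs, or a
    "_default" entry (for nn.Linear) matches the mapping."""
    for module_fqn, module in model.named_modules():
        param_fqns = [
            f"{module_fqn}.{name}" if module_fqn else name
            for name, _ in module.named_parameters(recurse=False)
        ]
        if (
            module_fqn in fqn_to_handler
            or any(p in fqn_to_handler for p in param_fqns)
            or ("_default" in fqn_to_handler and isinstance(module, nn.Linear))
        ):
            # Resolve the most specific matching key: module FQN, then
            # parameter FQN, then the "_default" fallback.
            key = (
                module_fqn
                if module_fqn in fqn_to_handler
                else next((p for p in param_fqns if p in fqn_to_handler), "_default")
            )
            fqn_to_handler[key](module)


# Usage: "0" targets the first Linear by module FQN; "_default" catches the rest.
model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8))
apply_by_fqn(model, {"0": print, "_default": lambda m: None})
```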
@@ -1254,17 +1255,22 @@ def _int4_weight_only_quantize_tensor(weight, config): |
1254 | 1255 |
|
1255 | 1256 | @register_quantize_module_handler(Int4WeightOnlyConfig) |
1256 | 1257 | def _int4_weight_only_transform( |
1257 | | - module: torch.nn.Module, config: Int4WeightOnlyConfig |
| 1258 | + module: torch.nn.Module, |
| 1259 | + config: Int4WeightOnlyConfig, |
| 1260 | + *, |
| 1261 | + parameter_name: str = "weight", |
1258 | 1262 | ) -> torch.nn.Module: |
1259 | 1263 | if config.set_inductor_config: |
1260 | 1264 | torchao.quantization.utils.recommended_inductor_config_setter() |
1261 | 1265 |
|
1262 | | - assert hasattr(module, "weight"), ( |
1263 | | - "applying int8 weight only quant requires module to have weight attribute" |
| 1266 | + assert hasattr(module, parameter_name), ( |
| 1267 | + "applying int8 weight only quant requires module to have {parameter_name} attribute" |
1264 | 1268 | + " but {module} does not have one" |
1265 | 1269 | ) |
1266 | | - new_weight = _int4_weight_only_quantize_tensor(module.weight, config) |
1267 | | - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) |
| 1270 | + new_weight = _int4_weight_only_quantize_tensor( |
| 1271 | + getattr(module, parameter_name), config |
| 1272 | + ) |
| 1273 | + setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False)) |
1268 | 1274 | module.extra_repr = types.MethodType(_linear_extra_repr, module) |
1269 | 1275 | return module |
1270 | 1276 |
|
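
The `parameter_name` keyword above (and the matching change to the intx transform further down) generalizes the handler from `module.weight` to an arbitrary parameter via `getattr`/`setattr`. A standalone sketch of that pattern, with the actual quantization replaced by a placeholder:

```python
import torch
from torch import nn


def _fake_quantize(t: torch.Tensor) -> torch.Tensor:
    # Placeholder standing in for _int4_weight_only_quantize_tensor; just detaches.
    return t.detach()


def transform(module: nn.Module, *, parameter_name: str = "weight") -> nn.Module:
    assert hasattr(module, parameter_name), (
        f"module {module} has no attribute {parameter_name!r}"
    )
    new_param = _fake_quantize(getattr(module, parameter_name))
    # Re-register under the same name so the module keeps a frozen Parameter.
    setattr(module, parameter_name, nn.Parameter(new_param, requires_grad=False))
    return module


lin = transform(nn.Linear(4, 4))                          # default: "weight"
lin = transform(nn.Linear(4, 4), parameter_name="bias")   # any other parameter
```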
@@ -2298,18 +2304,19 @@ def _intx_weight_only_transform( |
2298 | 2304 | *, |
2299 | 2305 | custom_scale: Optional[torch.Tensor] = None, |
2300 | 2306 | custom_zero_point: Optional[torch.Tensor] = None, |
| 2307 | + parameter_name: str = "weight", |
2301 | 2308 | ) -> torch.nn.Module: |
2302 | | - assert hasattr(module, "weight"), ( |
2303 | | - "applying intx weight only quant requires module to have weight attribute" |
| 2309 | + assert hasattr(module, parameter_name), ( |
| 2310 | + "applying intx weight only quant requires module to have {parameter_name} attribute" |
2304 | 2311 | + " but {module} does not have one" |
2305 | 2312 | ) |
2306 | 2313 | new_weight = _intx_weight_only_quantize_tensor( |
2307 | | - module.weight, |
| 2314 | + getattr(module, parameter_name), |
2308 | 2315 | config, |
2309 | 2316 | custom_scale=custom_scale, |
2310 | 2317 | custom_zero_point=custom_zero_point, |
2311 | 2318 | ) |
2312 | | - module.weight = torch.nn.Parameter(new_weight, requires_grad=False) |
| 2319 | + setattr(module, parameter_name, torch.nn.Parameter(new_weight, requires_grad=False)) |
2313 | 2320 |
|
2314 | 2321 | if isinstance(module, nn.Linear): |
2315 | 2322 | module.extra_repr = types.MethodType(_linear_extra_repr, module) |
@@ -2446,6 +2453,8 @@ def __post_init__(self): |
2446 | 2453 | Float8DynamicActivationFloat8WeightConfig, |
2447 | 2454 | Float8WeightOnlyConfig, |
2448 | 2455 | Int8WeightOnlyConfig, |
| 2456 | + Int4WeightOnlyConfig, |
| 2457 | + IntxWeightOnlyConfig, |
2449 | 2458 | } |
2450 | 2459 |
|
2451 | 2460 |
|
@@ -2541,7 +2550,7 @@ def _fqn_to_config_handler( |
2541 | 2550 | return module |
2542 | 2551 |
|
2543 | 2552 |
|
2544 | | -def _fqn_matches_fqn_config( |
| 2553 | +def fqn_matches_fqn_config( |
2545 | 2554 | fqn: str, |
2546 | 2555 | config: FqnToConfig, |
2547 | 2556 | ): |
@@ -2586,7 +2595,7 @@ def _module_param_matches_fqn_config( |
2586 | 2595 | for name, param in module.named_parameters(): |
2587 | 2596 | if name in dir(module): |
2588 | 2597 | parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name |
2589 | | - if _fqn_matches_fqn_config(parameter_fqn, config): |
| 2598 | + if fqn_matches_fqn_config(parameter_fqn, config): |
2590 | 2599 | return True |
2591 | 2600 |
|
2592 | 2601 | return False |
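
End-to-end, these changes let `Int4WeightOnlyConfig` / `IntxWeightOnlyConfig` be keyed by parameter FQNs in `FqnToConfig`. A hedged usage sketch, assuming `FqnToConfig` accepts a mapping of FQNs to configs as its first argument (mirroring `ModuleFqnToConfig`) and that the environment supports int4 weight-only quantization:

```python
import torch
from torch import nn

from torchao.quantization import quantize_
from torchao.quantization.quant_api import FqnToConfig, Int4WeightOnlyConfig

# Int4 weight-only packing typically expects bf16 weights and a CUDA device.
model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64)).to(torch.bfloat16)

config = FqnToConfig(
    {
        "0.weight": Int4WeightOnlyConfig(group_size=32),  # parameter-FQN key
        "_default": Int4WeightOnlyConfig(group_size=32),  # remaining nn.Linear modules
    }
)
quantize_(model, config)
```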
|