2 changes: 2 additions & 0 deletions torchao/quantization/__init__.py
@@ -69,6 +69,7 @@
     float8_static_activation_float8_weight,
     float8_weight_only,
     fpx_weight_only,
+    fqn_matches_fqn_config,
     gemlite_uintx_weight_only,
     int4_dynamic_activation_int4_weight,
     int4_weight_only,
@@ -142,6 +143,7 @@
"float8_static_activation_float8_weight",
"uintx_weight_only",
"fpx_weight_only",
"fqn_matches_fqn_config",
"gemlite_uintx_weight_only",
"swap_conv2d_1x1_to_linear",
"Int4DynamicActivationInt4WeightConfig",
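Note: with `fqn_matches_fqn_config` added to the import block and `__all__` above, the helper becomes importable from the package root:

    from torchao.quantization import fqn_matches_fqn_config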
15 changes: 6 additions & 9 deletions torchao/quantization/quant_api.py
@@ -480,15 +480,17 @@ def quantize_(

         for module_fqn, module in model.named_modules():
             if (
-                _fqn_matches_fqn_config(module_fqn, config)
+                fqn_matches_fqn_config(module_fqn, config)
                 or _module_param_matches_fqn_config(module, module_fqn, config)
                 or ("_default" in config.fqn_to_config and _is_linear(module))
             ):
                 module_name = (
                     module_fqn.rsplit(".", 1) if "." in module_fqn else module_fqn
                 )
                 # this replaces inplace, so no need to reassign
-                _fqn_to_config_handler(module, module_name, config, device)
+                _fqn_to_config_handler(module, module_name, config)
+                if device is not None:
+                    module.to(device=device)
         return
     if isinstance(config, AOBaseConfig):
         filter_fn = _is_linear if filter_fn is None else filter_fn
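The hunk above moves device placement out of `_fqn_to_config_handler`: `quantize_` now moves each matched module itself, after the handler has applied the config. A minimal usage sketch; the `Int8WeightOnlyConfig` choice and the dict-style `FqnToConfig` constructor are assumptions for illustration, not part of this diff:

    import torch.nn as nn
    from torchao.quantization import FqnToConfig, Int8WeightOnlyConfig, quantize_

    model = nn.Sequential(nn.Linear(64, 64), nn.Linear(64, 64))
    config = FqnToConfig({
        "0": Int8WeightOnlyConfig(),         # exact fqn of the first child
        "_default": Int8WeightOnlyConfig(),  # picked up by the _is_linear fallback above
    })
    # With device left as None, the new `module.to(device=device)` branch is skipped.
    quantize_(model, config)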
@@ -2470,7 +2472,6 @@ def _fqn_to_config_handler(
     module: torch.nn.Module,
     fqn: str,
     config: FqnToConfig,
-    device: Optional[torch.device] = None,
 ):
     """This function expects a module that either is specified in FqnToConfig or has a parameter that is specified in FqnToConfig.
 
@@ -2479,17 +2480,13 @@
         fqn (str): The fully qualified name of the module containing the parameters.
         config (FqnToConfig): Configuration object containing regex patterns / fqn mapped
             to quantization configurations.
-        device (Optional[torch.device]): The device to move the module to as part of quantization
 
     Returns:
         torch.nn.Module: The modified module with quantized parameters.
 
     Raises:
         NotImplementedError: If the quantization configuration is not yet supported for parameter quantization.
     """
-    if device is not None:
-        module = module.to(device)
-
     parameter_config_found = False
     top_level_params = []
     for i, (parameter_name, param) in enumerate(list(module.named_parameters())):
@@ -2563,7 +2560,7 @@ def _fqn_to_config_handler(
     return module
 
 
-def _fqn_matches_fqn_config(
+def fqn_matches_fqn_config(
     fqn: str,
     config: FqnToConfig,
 ):
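With the leading underscore dropped, the matcher is public and can be called directly to test whether an fqn is covered by a config. A small sketch; the `re:` regex-key convention is assumed from torchao's `ModuleFqnToConfig` and is not shown in this diff:

    from torchao.quantization import FqnToConfig, Int8WeightOnlyConfig, fqn_matches_fqn_config

    config = FqnToConfig({
        "model.out_proj": Int8WeightOnlyConfig(),                          # exact match
        r"re:model\.layers\.\d+\.mlp\.gate_proj": Int8WeightOnlyConfig(),  # regex key (assumed)
    })
    assert fqn_matches_fqn_config("model.out_proj", config)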
@@ -2608,7 +2605,7 @@ def _module_param_matches_fqn_config(
     for name, param in module.named_parameters():
         if name in dir(module):
             parameter_fqn = f"{fqn}.{name}" if len(fqn) > 0 else name
-            if _fqn_matches_fqn_config(parameter_fqn, config):
+            if fqn_matches_fqn_config(parameter_fqn, config):
                 return True
 
     return False
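For parameter-level matching, `_module_param_matches_fqn_config` joins the module's fqn with each top-level parameter name and defers to `fqn_matches_fqn_config`. An illustrative sketch; note this helper keeps its underscore and stays internal to `quant_api`:

    import torch.nn as nn
    from torchao.quantization import FqnToConfig, Int8WeightOnlyConfig
    from torchao.quantization.quant_api import _module_param_matches_fqn_config

    config = FqnToConfig({"model.out_proj.weight": Int8WeightOnlyConfig()})
    # "model.out_proj" + ".weight" -> "model.out_proj.weight", present in the config.
    assert _module_param_matches_fqn_config(nn.Linear(8, 8), "model.out_proj", config)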