Skip to content

Commit 60d63a6

Browse files
authored
Rename AOPerModuleConfig to ModuleFqnToConfig (#2243)
* Rename AOPerModuleConfig to ModuleFqnToConfig. Summary: to be more explicit about what this config means. Test Plan: CI. Reviewers: — Subscribers: — Tasks: — Tags: — * renaming rest
1 parent efac465 commit 60d63a6

File tree

4 files changed

+21
-20
lines changed

4 files changed

+21
-20
lines changed

test/quantization/test_config_serialization.py

Lines changed: 4 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -20,7 +20,6 @@
2020
config_to_dict,
2121
)
2222
from torchao.quantization.quant_api import (
23-
AOPerModuleConfig,
2423
Float8DynamicActivationFloat8WeightConfig,
2524
Float8WeightOnlyConfig,
2625
FPXWeightOnlyConfig,
@@ -30,6 +29,7 @@
3029
Int8DynamicActivationInt4WeightConfig,
3130
Int8DynamicActivationInt8WeightConfig,
3231
Int8WeightOnlyConfig,
32+
ModuleFqnToConfig,
3333
PerRow,
3434
UIntXWeightOnlyConfig,
3535
)
@@ -68,9 +68,9 @@
6868
# Sparsity configs
6969
SemiSparseWeightConfig(),
7070
BlockSparseWeightConfig(blocksize=128),
71-
AOPerModuleConfig({}),
72-
AOPerModuleConfig({"_default": Int4WeightOnlyConfig(), "linear1": None}),
73-
AOPerModuleConfig(
71+
ModuleFqnToConfig({}),
72+
ModuleFqnToConfig({"_default": Int4WeightOnlyConfig(), "linear1": None}),
73+
ModuleFqnToConfig(
7474
{
7575
"linear1": Int4WeightOnlyConfig(),
7676
"linear2": Int8DynamicActivationInt4WeightConfig(),

test/quantization/test_quant_api.py

Lines changed: 9 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -38,11 +38,11 @@
3838
PerGroup,
3939
)
4040
from torchao.quantization.quant_api import (
41-
AOPerModuleConfig,
4241
Int4WeightOnlyConfig,
4342
Int8DynamicActivationInt4WeightConfig,
4443
Int8WeightOnlyConfig,
4544
IntxWeightOnlyConfig,
45+
ModuleFqnToConfig,
4646
Quantizer,
4747
TwoStepQuantizer,
4848
_replace_with_custom_fn_if_matches_filter,
@@ -946,10 +946,10 @@ def test_workflow_e2e_numerics(self, config):
946946
assert sqnr >= 16.5, f"SQNR {sqnr} is too low"
947947

948948
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
949-
def test_ao_per_module_config_default(self):
949+
def test_module_fqn_to_config_default(self):
950950
config1 = Int4WeightOnlyConfig(group_size=32)
951951
config2 = Int8WeightOnlyConfig()
952-
config = AOPerModuleConfig({"_default": config1, "linear2": config2})
952+
config = ModuleFqnToConfig({"_default": config1, "linear2": config2})
953953
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
954954
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
955955
quantize_(model, config)
@@ -960,10 +960,10 @@ def test_ao_per_module_config_default(self):
960960
assert isinstance(model.linear2.weight._layout, PlainLayout)
961961

962962
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
963-
def test_ao_per_module_config_module_name(self):
963+
def test_module_fqn_to_config_module_name(self):
964964
config1 = Int4WeightOnlyConfig(group_size=32)
965965
config2 = Int8WeightOnlyConfig()
966-
config = AOPerModuleConfig({"linear1": config1, "linear2": config2})
966+
config = ModuleFqnToConfig({"linear1": config1, "linear2": config2})
967967
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
968968
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
969969
quantize_(model, config)
@@ -974,7 +974,7 @@ def test_ao_per_module_config_module_name(self):
974974
assert isinstance(model.linear2.weight._layout, PlainLayout)
975975

976976
@unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_6, "Need torch 2.6+")
977-
def test_ao_per_module_config_embedding_linear(self):
977+
def test_module_fqn_to_config_embedding_linear(self):
978978
weight_dtype = torch.int8
979979
granularity = PerGroup(8)
980980
mapping_type = MappingType.SYMMETRIC
@@ -987,7 +987,7 @@ def test_ao_per_module_config_embedding_linear(self):
987987
# example model linear is Linear(16, 8)
988988
linear_config = Int8DynamicActivationInt4WeightConfig(group_size=16)
989989

990-
config = AOPerModuleConfig({"emb": embedding_config, "linear": linear_config})
990+
config = ModuleFqnToConfig({"emb": embedding_config, "linear": linear_config})
991991
indices = torch.randint(0, 10, (32,))
992992
indices = indices.unsqueeze(0)
993993
example_inputs = (indices,)
@@ -1006,9 +1006,9 @@ def test_ao_per_module_config_embedding_linear(self):
10061006
assert isinstance(model.linear.weight, LinearActivationQuantizedTensor)
10071007

10081008
@unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
1009-
def test_ao_per_module_config_skip(self):
1009+
def test_module_fqn_to_config_skip(self):
10101010
config1 = Int4WeightOnlyConfig(group_size=32)
1011-
config = AOPerModuleConfig({"_default": config1, "linear2": None})
1011+
config = ModuleFqnToConfig({"_default": config1, "linear2": None})
10121012
model = ToyLinearModel().cuda().to(dtype=torch.bfloat16)
10131013
example_inputs = model.example_inputs(device="cuda", dtype=torch.bfloat16)
10141014
quantize_(model, config)

torchao/quantization/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -39,7 +39,6 @@
3939
AffineQuantizedObserverBase,
4040
)
4141
from .quant_api import (
42-
AOPerModuleConfig,
4342
CutlassInt4PackedLayout,
4443
Float8DynamicActivationFloat8SemiSparseWeightConfig,
4544
Float8DynamicActivationFloat8WeightConfig,
@@ -55,6 +54,7 @@
5554
Int8DynamicActivationIntxWeightConfig,
5655
Int8WeightOnlyConfig,
5756
IntxWeightOnlyConfig,
57+
ModuleFqnToConfig,
5858
PlainLayout,
5959
TensorCoreTiledLayout,
6060
UIntXWeightOnlyConfig,
@@ -147,7 +147,7 @@
147147
"IntxWeightOnlyConfig",
148148
"FPXWeightOnlyConfig",
149149
"GemliteUIntXWeightOnlyConfig",
150-
"AOPerModuleConfig",
150+
"ModuleFqnToConfig",
151151
# smooth quant - subject to change
152152
"get_scale",
153153
"SmoothFakeDynQuantMixin",

torchao/quantization/quant_api.py

Lines changed: 6 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -136,6 +136,7 @@
136136
"Int8DynActInt4WeightQuantizer",
137137
"Int8DynActInt4WeightGPTQQuantizer",
138138
"Float8DynamicActivationFloat8SemiSparseWeightConfig",
139+
"ModuleFqnToConfig",
139140
]
140141

141142
LAYOUT_TO_ZERO_POINT_DOMAIN = {
@@ -596,10 +597,10 @@ def quantize_(
596597
"""
597598
filter_fn = _is_linear if filter_fn is None else filter_fn
598599

599-
if isinstance(config, AOPerModuleConfig):
600+
if isinstance(config, ModuleFqnToConfig):
600601
_replace_with_custom_fn_if_matches_filter_with_name(
601602
model,
602-
_ao_per_module_config_handler,
603+
_module_fqn_to_config_handler,
603604
filter_fn,
604605
device=device,
605606
extra_args=(config,),
@@ -2002,7 +2003,7 @@ def _fpx_weight_only_transform(
20022003

20032004

20042005
@dataclass
2005-
class AOPerModuleConfig(AOBaseConfig):
2006+
class ModuleFqnToConfig(AOBaseConfig):
20062007
"""Per module configurations for torchao quantize_ API
20072008
20082009
Args:
@@ -2018,8 +2019,8 @@ class AOPerModuleConfig(AOBaseConfig):
20182019
)
20192020

20202021

2021-
def _ao_per_module_config_handler(
2022-
module: torch.nn.Module, module_fqn: str, config: AOPerModuleConfig
2022+
def _module_fqn_to_config_handler(
2023+
module: torch.nn.Module, module_fqn: str, config: ModuleFqnToConfig
20232024
):
20242025
c = None
20252026
if module_fqn in config.module_fqn_to_config:

0 commit comments

Comments
 (0)