Skip to content

Commit 28d68e2

Browse files
authored
enforce AOBaseConfig type in quantize_'s config argument (#1861)
* Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned] * Update [ghstack-poisoned]
1 parent afeac2f commit 28d68e2

File tree

1 file changed

+3
-14
lines changed

1 file changed

+3
-14
lines changed

torchao/quantization/quant_api.py

Lines changed: 3 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -486,7 +486,7 @@ def insert_subclass(lin):
486486

487487
def quantize_(
488488
model: torch.nn.Module,
489-
config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
489+
config: AOBaseConfig,
490490
filter_fn: Optional[Callable[[torch.nn.Module, str], bool]] = None,
491491
set_inductor_config: Optional[bool] = None,
492492
device: Optional[torch.types.Device] = None,
@@ -495,7 +495,7 @@ def quantize_(
495495
496496
Args:
497497
model (torch.nn.Module): input model
498-
config (Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]]): either (1) a workflow configuration object or (2) a function that applies tensor subclass conversion to the weight of a module and returns the module (e.g. convert the weight tensor of linear to affine quantized tensor). Note: (2) will be deleted in a future release.
498+
config (AOBaseConfig): a workflow configuration object.
499499
filter_fn (Optional[Callable[[torch.nn.Module, str], bool]]): function that takes a nn.Module instance and fully qualified name of the module, returns True if we want to run `config` on
500500
the weight of the module
501501
set_inductor_config (bool, optional): Whether to automatically use recommended inductor config settings (defaults to None)
@@ -546,21 +546,10 @@ def quantize_(
546546
)
547547

548548
else:
549-
# old behavior, keep to avoid breaking BC
550-
warnings.warn(
549+
raise AssertionError(
551550
"""Passing a generic Callable to `quantize_` is no longer recommended and will be deprecated at a later release. Please see https://github.com/pytorch/ao/issues/1690 for instructions on how to pass in workflow configuration instead."""
552551
)
553552

554-
# make the variable name make sense
555-
apply_tensor_subclass = config
556-
557-
_replace_with_custom_fn_if_matches_filter(
558-
model,
559-
apply_tensor_subclass,
560-
_is_linear if filter_fn is None else filter_fn,
561-
device=device,
562-
)
563-
564553

565554
def _int8_asymm_per_token_quant(x: torch.Tensor) -> torch.Tensor:
566555
"""This is defined here instead of local function to support serialization"""

0 commit comments

Comments (0)