Description
summary
This issue tracks the migration of quantize_ per-workflow configuration from callables to configs.
We are migrating the way quantize_ workflows are configured from callables (tensor subclass inserters) to direct configuration (config objects). Motivation: align with the rest of the ecosystem, enable inspection of configs after instantiation, and remove a common source of confusion.
What is changing:
Specifically, here is how the signature of the second argument to quantize_ will change:
```python
#
# torchao v0.8.0 and before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# torchao v0.9.0
#
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...

#
# torchao v0.10.0 or later (exact version TBD)
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
```
- the name of the second argument to quantize_ changed from apply_tensor_subclass to config. Since the vast majority of callsites today pass the configuration as a positional argument, this change should not affect most people.
- the type of the second argument to quantize_ will change from Callable[[torch.nn.Module], torch.nn.Module] to AOBaseConfig, following a deprecation process detailed below.
- for individual workflows, the user facing API name changed from snake case (int8_weight_only) to camel case (Int8WeightOnlyConfig). All argument names for each config are kept as-is. We will keep the old snake case names (int8_weight_only) around and alias them to the new names (int8_weight_only = Int8WeightOnlyConfig) to avoid breaking callsites; we plan to keep the old names forever (see the aliasing sketch after the table). Here are all the workflow config name changes:
| old name (will keep working) | new name (recommended) |
|---|---|
| int4_weight_only | Int4WeightOnlyConfig |
| float8_dynamic_activation_float8_weight | Float8DynamicActivationFloat8WeightConfig |
| float8_static_activation_float8_weight | Float8StaticActivationFloat8WeightConfig |
| float8_weight_only | Float8WeightOnlyConfig |
| fpx_weight_only | FPXWeightOnlyConfig |
| gemlite_uintx_weight_only | GemliteUIntXWeightOnlyConfig |
| int4_dynamic_activation_int4_weight | Int4DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int4_weight | Int8DynamicActivationInt4WeightConfig |
| int8_dynamic_activation_int8_semi_sparse_weight | n/a (deprecated) |
| int8_dynamic_activation_int8_weight | Int8DynamicActivationInt8WeightConfig |
| int8_weight_only | Int8WeightOnlyConfig |
| uintx_weight_only | UIntXWeightOnlyConfig |
Configuration for prototype workflows using quantize_ will be migrated at a later time. sparsify_ will be migrated in a similar fashion at a later time.
How these changes can affect you:
- If you are a user of existing quantize_ API workflows and are passing in the config as a positional argument (quantize_(model, int8_weight_only(group_size=128))), you are not affected. This syntax will keep working going forward. You have the option to migrate your callsite to the new config name (quantize_(model, Int8WeightOnlyConfig(group_size=128))) at your own pace.
- If you are a user of existing quantize_ API workflows and are passing in the config as a keyword argument (quantize_(model, apply_tensor_subclass=int8_weight_only(group_size=128))), your callsite will break. You will need to change your callsite to quantize_(model, config=int8_weight_only(group_size=128)); see the sketch after this list. We don't expect many people to be in this bucket.
- If you are a developer writing new workflows for the quantize_ API, you will need to use the new configuration system. Please see "migration of quantize_ workflow configuration from callables to configs" #1690 for details.
- If you are a user of sparsify_, you are not affected for now; a similar change will happen in a future version of torchao.
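To make the keyword-argument case above concrete, here is a minimal sketch of the callsite change (the toy model and group size are purely illustrative):

```python
import torch
from torchao.quantization import quantize_, int8_weight_only

model = torch.nn.Sequential(torch.nn.Linear(128, 128))  # toy model for illustration

# before (breaks once the argument is renamed):
#   quantize_(model, apply_tensor_subclass=int8_weight_only(group_size=128))

# after: use the new keyword name ...
quantize_(model, config=int8_weight_only(group_size=128))
# ... or simply pass the config positionally, which was and remains supported:
#   quantize_(model, int8_weight_only(group_size=128))
```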
This migration will be a two-step process:
- in torchao v0.9.0, we will enable the new syntax while starting the deprecation process for the old syntax.
- in torchao v0.10.0 or later, we will remove the old syntax.
We will keep the old callable syntax supported by quantize_ for one release cycle, and delete it afterwards. We will keep the old names as aliases for the new names going forward (example: int4_weight_only as an alias of Int4WeightOnlyConfig) to keep existing callsites working without changes.
impact on API users
If you are just using the torchao quantize_ API as specified in the README, this is not BC-breaking. For example, the following syntax will keep working.
quantize_(model, int8_weight_only())
Note that the type of the object created by int8_weight_only() will change from a Callable to a config. You have the option to migrate to the explicit config creation, as follows:
quantize_(model, Int8WeightOnlyConfig())
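A practical benefit of this change (part of the motivation above) is that the object you pass in is now a plain config whose settings can be inspected after instantiation, instead of a callable whose configuration is hidden in captured local variables. A minimal sketch, assuming Int8WeightOnlyConfig is a dataclass as shown in the developer section below:

```python
from torchao.quantization import Int8WeightOnlyConfig

config = Int8WeightOnlyConfig(group_size=128)

# the config's settings are visible after creation
print(config)             # e.g. Int8WeightOnlyConfig(group_size=128, ...)
print(config.group_size)  # 128
```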
user facing API changes
signature of quantize_
```python
#
# before
#
def quantize_(
    model: torch.nn.Module,
    apply_tensor_subclass: Callable[[torch.nn.Module], torch.nn.Module],
    ...,
): ...

#
# after - intermediate state, support both old and new for one release
#
def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
    ...,
): ...

#
# after - long term state
#
def quantize_(
    model: torch.nn.Module,
    config: AOBaseConfig,
    ...,
): ...
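```

For the intermediate state, one way to picture how both forms can be supported during the deprecation window is a simple branch on the argument type. This is only an illustration (the AOBaseConfig stub and the warning wording are assumptions, not the actual torchao implementation):

```python
import abc
import warnings
from typing import Callable, Union

import torch

class AOBaseConfig(abc.ABC):  # stub of the config base class, see developer section below
    pass

def quantize_(
    model: torch.nn.Module,
    config: Union[AOBaseConfig, Callable[[torch.nn.Module], torch.nn.Module]],
) -> None:
    if isinstance(config, AOBaseConfig):
        ...  # new path: dispatch to the handler registered for this config type
    else:
        # old path: config is a module -> module callable; warn but keep it working
        warnings.warn(
            "passing a Callable to quantize_ is deprecated, pass an AOBaseConfig instead",
            DeprecationWarning,
        )
        ...
```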
usage example
An example for int4_weight_only:

```python
#
# before
#
quantize_(m, int4_weight_only(group_size=32))

#
# after, with new user facing names
#
quantize_(m, Int4WeightOnlyConfig(group_size=32))

#
# AND, after, with BC names
#
quantize_(m, int4_weight_only(group_size=32))
```
developer facing changes
See the PR details for examples, but they can be summarized as:
```python
#
# old
#
# quantize_ calls the callable returned by this function on each module of the model
def int4_weight_only(group_size: int, ...) -> Callable:
    def new_callable(weight: torch.Tensor):
        # configuration is captured here via local variables
        ...
    # return type is a Callable
    return _get_linear_subclass_inserter(new_callable)

#
# new
#

# config base class
class AOBaseConfig(abc.ABC):
    pass

# user facing configuration of a workflow
@dataclass
class Int4WeightOnlyConfig(AOBaseConfig):
    group_size: int = 128
    ...

# not user facing: transform of a module according to a workflow's configuration
@register_quantize_module_handler(Int4WeightOnlyConfig)
def _int4_weight_only_transform(
    module: torch.nn.Module,
    config: Int4WeightOnlyConfig,
) -> torch.nn.Module:
    # map to AQT, not user facing
    ...
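```

register_quantize_module_handler is not spelled out in this summary; one way to picture it (a sketch under assumed names, not the actual torchao implementation) is a registry mapping config types to transform functions, which quantize_ then consults for each module:

```python
from typing import Callable, Dict

# sketch only: a registry from config type to its module transform function
_CONFIG_TO_HANDLER: Dict[type, Callable] = {}

def register_quantize_module_handler(config_type: type) -> Callable:
    # decorator recording "this function knows how to apply this config to a module"
    def decorator(handler: Callable) -> Callable:
        _CONFIG_TO_HANDLER[config_type] = handler
        return handler
    return decorator

# inside quantize_, the handler for the user's config would then be looked up by
# type and applied to each matching module, roughly:
#   handler = _CONFIG_TO_HANDLER[type(config)]
#   new_module = handler(module, config)
```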
migration status
quantize_ non-prototype workflow configuration
- int4_weight_only - [bc-breaking] enable direct configuration in quantize_ #1595
- qat: intx_quantization_aware_training - [bc-breaking] enable direct configuration in quantize_ #1595
- qat: from_intx_quantization_aware_training - [bc-breaking] enable direct configuration in quantize_ #1595
- float8_dynamic_activation_float8_weight - config migration: float8* #1694
- float8_static_activation_float8_weight - config migration: float8* #1694
- float8_weight_only - config migration: float8* #1694
- fpx_weight_only - config migration: fpx, gemlite, uintx #1697
- gemlite_uintx_weight_only - config migration: fpx, gemlite, uintx #1697
- int4_dynamic_activation_int4_weight - config migration: int* #1696
- int8_dynamic_activation_int4_weight - config migration: int* #1696
- int8_dynamic_activation_int8_semi_sparse_weight - marked for deprecation, confirmed with @jcaip we can delete this
- int8_dynamic_activation_int8_weight - config migration: int* #1696
- int8_weight_only - config migration: int* #1696
- uintx_weight_only - config migration: fpx, gemlite, uintx #1697
- nf4 - migrate nf4 to configs #1857
quantize_ prototype workflow configuration
Grep for callsites:
grep -r "quantize_(" torchao/prototype
- smoothquant: config migration: smoothquant #1851
- sparsity/superblock: using already migrated configs
- autoround: add assertion error about config migration to prototype/autoround #1852 (skipped)
- awq: migrate prototype/awq to configs #1853
- quantization/mixed_precision/scripts/mp_quant_eval: migrates prototype/mixed_precision to configs #1854
- quantization/mixed_precision/scripts/naive_intNwo: migrates prototype/mixed_precision to configs #1854
- quantization/mixed_precision/scripts/utils: migrates prototype/mixed_precision to configs #1854
- quantized_training: migrate prototype/quantized_training to configs #1855
- parq - the quantize_ used here is a different function, so nothing to do
- codebook: migrate prototype codebook quant to configs #1858
experimental
sparsify_
- everything: migrate sparsify_ to configs #1856
tutorials (replace with new registration API)
- calibration_flow/static_quant.py - migrate static quant tutorials to direct configuration #1710
- calibration_flow/gptq_like.py - migrate static quant tutorials to direct configuration #1710
- calibration_flow/awq_like.py - migrate static quant tutorials to direct configuration #1710
replace docblocks and public facing descriptions with new names
- README.md - update torchao READMEs with new configuration APIs #1711
- QAT README.md - update torchao READMEs with new configuration APIs #1711
- quantization README.md - update torchao READMEs with new configuration APIs #1711
verify partner integrations still work
- HF callsite: https://github.com/huggingface/transformers/blob/1feebb5b4150882deabddd190a541f336f3be817/src/transformers/quantizers/quantizer_torchao.py#L199
- SGLANG callsite: https://github.com/sgl-project/sglang/blob/2f47d710ae9cb1bdbbe0fe2392a0634827d257b3/python/sglang/srt/layers/torchao_utils.py#L39
- Diffusers callsite: https://github.com/huggingface/diffusers/blob/7fb481f840b5d73982cafd1affe89f21a5c0b20b/src/diffusers/quantizers/torchao/torchao_quantizer.py#L234
confirmed two out of three here: vkuzo/pytorch_scripts#28