## Description
Currently torchao QAT has two APIs: tensor subclass and module swap. The original plan was to deprecate and eventually remove the old module swap API in favor of the tensor subclass API. However, users are starting to rely on the module swap API for production use due to gaps in the tensor subclass API. In this RFC, we discuss the long term plans for these two APIs in torchao.
## API Today
We use a quantizer API today to abstract the implementation details from the user. Currently we support both tensor subclass and module swap APIs using different quantizers:
```python
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer
from torchao.quantization.prototype.qat._module_swap_api import Int8DynActInt4WeightQATQuantizerModuleSwap

# tensor subclass version
qat_quantizer = Int8DynActInt4WeightQATQuantizer()
# module swap version
qat_quantizer = Int8DynActInt4WeightQATQuantizerModuleSwap()

# prepare: inserts fake quantizes but keeps everything in bf16
fake_quantized_model = qat_quantizer.prepare(model)

# train or fine-tune as before
train(fake_quantized_model)

# convert: actually quantizes the model to lower bit-widths
quantized_model = qat_quantizer.convert(fake_quantized_model)
```
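For context, "fake quantize" in `prepare` means weights are quantized and immediately dequantized on the fly, so the model stays in bf16 but trains against the quantization error it will see after `convert`. Below is a minimal sketch of that idea; it is not the actual torchao implementation, and the helper name and int4 settings are illustrative only:

```python
import torch

# Illustrative sketch: symmetric per-channel int4 fake quantization.
# Real QAT also needs a straight-through estimator so gradients flow
# through the rounding step.
def fake_quantize_int4(w: torch.Tensor) -> torch.Tensor:
    # per-channel scales over the last dimension, mapping values to [-8, 7]
    scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7
    w_q = torch.clamp(torch.round(w / scales), -8, 7)
    # dequantize right away: dtype is unchanged, but quantization error is baked in
    return w_q * scales

w = torch.randn(128, 256, dtype=torch.bfloat16)
w_fq = fake_quantize_int4(w)
assert w_fq.dtype == torch.bfloat16
```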
## Module Swap vs Tensor Subclass
Although tensor subclasses are widely adopted in torchao, the main gaps today are (1) the lack of general distributed support, and (2) a steep learning curve. For these two reasons, some users prefer the module swap flow and have begun implementing new features in it, such as embedding quantization and static QAT.
To summarize the pros and cons of both approaches:
| | Tensor subclass | Module swap |
|---|---|---|
| Consistency | ✓ The rest of torchao uses tensor subclasses, including PTQ, quantized training, float8 inference, and sparsity. These all use the same `quantize_` API. | ✖ Diverges from the PTQ flow (not necessarily a con) |
| Composability | ✓ Better composability with other tensor subclasses such as DTensor and NestedJaggedTensor | ✖ Pure module swap misses out on potential composability benefits with other tensor subclasses (no clear benefits for QAT today) |
| Distributed support | ✖ Currently only supports FSDP2. The internal implementation of each distributed strategy is exposed to the subclass. There are problems with how tensor subclasses interact with FSDP1 and DDP, and fixing these is not a priority for the distributed team. | ✓ Works with any distributed strategy, including non-PyTorch ones like FAIR FSDP |
| Developer experience | ✖ Steep learning curve, difficult for new users to extend, confusing error messages | ✓ Easy to understand and extend. Supports module-level features like range learning |
We can separate tensor subclass usage into two categories:
- **Injection.** This refers to how we insert fake quantization logic into the model. For example, tensor subclass injection means we look for `nn.Linear` modules and swap out the weight tensor, while module swap injection means we look for `nn.Linear` modules and swap out the whole module with our custom `QATLinear` (a minimal sketch follows this list). Today, the tensor subclass flow in torchao uses the former, while the module swap flow uses the latter.
- **Fake quantization implementation.** This refers to how we represent fake quantization during training. We can use our custom `AffineFakeQuantizedTensor` to encode the desired fake quantization configurations, or we can use plain `torch.Tensor`.
- We can combine these two in the same flow. For example, use module swap for injection and tensor subclass for data representation.
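To make the distinction concrete, here is a hedged sketch of module swap injection. `QATLinear`, `swap_linears`, and the int4 settings are hypothetical names for illustration, not torchao's actual implementation; tensor subclass injection would instead leave the module alone and replace `linear.weight` with a fake quantized tensor subclass.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class QATLinear(nn.Linear):
    """Hypothetical nn.Linear that fake quantizes its weight on every forward pass."""
    def forward(self, x):
        w = self.weight
        scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7
        w_q = torch.clamp(torch.round(w / scales), -8, 7) * scales
        # straight-through estimator: forward uses the quantized weight,
        # backward sees the identity
        w_fq = w + (w_q - w).detach()
        return F.linear(x, w_fq, self.bias)

def swap_linears(module: nn.Module) -> nn.Module:
    """Module swap injection: recursively replace nn.Linear with QATLinear."""
    for name, child in module.named_children():
        if isinstance(child, nn.Linear):
            new_child = QATLinear(
                child.in_features,
                child.out_features,
                bias=child.bias is not None,
                device=child.weight.device,
                dtype=child.weight.dtype,
            )
            new_child.load_state_dict(child.state_dict())
            setattr(module, name, new_child)
        else:
            swap_linears(child)
    return module
```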
## Long Term Flow
We propose to use module swap for injection and tensor subclass for implementing fake quantization in the long term. This has the following pros and cons compared to the alternatives:
- ✓ Single QAT flow in torchao
- ✓ Lower barrier to entry; new users can continue to contribute features quickly
- ✓ Consistent with float8 training in torchao
- ✓ Composes well with other tensor subclasses like DTensor (e.g. cast to int8 before all-gather)
- ✖ Need additional work to support all distributed strategies
Note: In the short term, we will continue to use plain `torch.Tensor`s for fake quantization due to the lack of general distributed support for tensor subclasses. The distributed strategies we should support before migrating to the long term flow include DDP and FSDP1. Additionally, we should migrate only if tensor subclass composability provides meaningful performance benefits, such as faster fake quantization through efficient int8 kernels.
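To illustrate the proposed combination, the sketch below shows what "tensor subclass for fake quantization" could look like, using a hypothetical `FakeQuantizedTensor` based on `__torch_function__`. This is not torchao's `AffineFakeQuantizedTensor`, and it omits gradients, serialization, and distributed handling; module swap would still decide where such a weight gets installed, while the subclass decides how fake quantization is applied when the weight reaches `F.linear`.

```python
import torch
import torch.nn.functional as F

class FakeQuantizedTensor(torch.Tensor):
    """Hypothetical subclass that fake quantizes itself when used as a linear weight."""
    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        kwargs = kwargs or {}
        if func is F.linear:
            x, w, *rest = args
            # symmetric per-channel int4 fake quantization of the weight
            scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-6) / 7
            w_fq = torch.clamp(torch.round(w / scales), -8, 7) * scales
            # drop back to a plain tensor to avoid re-entering this handler
            return func(x, w_fq.as_subclass(torch.Tensor), *rest, **kwargs)
        return super().__torch_function__(func, types, args, kwargs)

x = torch.randn(2, 16)
w = torch.randn(32, 16).as_subclass(FakeQuantizedTensor)
y = F.linear(x, w)  # the weight is fake quantized inside the op
```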