Commit b5fe5a7

Update on "Add generic fake quantized linear for QAT"
**Summary:** This commit adds a generic fake quantized linear module to replace the existing, more specific QAT linears. For example, `Int8DynActInt4WeightQATLinear` can be expressed as follows:

```
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)
```

The main motivation is to provide a more flexible way to perform QAT on models with linear layers. Previously, we had to create a new linear class every time we wanted to experiment with different fake quantization settings, e.g. a different group size or bit width. Now we can express all of these variants with a single linear module.

**Test Plan:**

```
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity
python test/quantization/test_qat.py -k test_fake_quantize_config_granularity_error_cases
python test/quantization/test_qat.py -k test_fake_quantize_config_mapping_type
python test/quantization/test_qat.py -k test_fake_quantized_linear_8da4w
python test/quantization/test_qat.py -k test_fake_quantized_linear_4w
```

[ghstack-poisoned]
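As a quick usage sketch (not part of the commit message; it assumes `FakeQuantizedLinear` follows the standard `nn.Linear` forward/backward contract, which the summary implies):

```python
import torch
from torchao.quantization.prototype.qat.api import FakeQuantizeConfig
from torchao.quantization.prototype.qat.linear import FakeQuantizedLinear

# Same 8da4w-style configs as in the summary: int8 asymmetric
# per-token activations, int4 weights quantized in groups of 8.
activation_config = FakeQuantizeConfig(torch.int8, "per_token", is_symmetric=False)
weight_config = FakeQuantizeConfig(torch.int4, group_size=8)

# Drop-in replacement for nn.Linear(16, 32, bias=False) during QAT.
fq_linear = FakeQuantizedLinear(16, 32, False, activation_config, weight_config)

x = torch.randn(4, 16)
y = fq_linear(x)      # forward applies fake quantization to activations and weights
y.sum().backward()    # gradients still flow to the float weights (straight-through)
```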
2 parents: 5642f44 + 622b6df

File tree

1 file changed: +1 −0 lines

test/integration/test_integration.py (1 addition, 0 deletions)

```diff
@@ -328,6 +328,7 @@ def test_swap(self):
         assert torch.allclose(y_ref, y)
 
     @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
+    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
     def test_weight_t_and_non_t_numerics_match(self):
         # verify that numerics match whether weight is stored
         # in transposed format (for cuBLAS) vs non-transposed format
```
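The added decorator gates the test on the PyTorch version in addition to CUDA availability. A minimal sketch of the pattern (the `TORCH_VERSION_AT_LEAST_2_3` flag below is a hypothetical local stand-in computed from `torch.__version__`; in torchao it is imported from the library's utilities):

```python
import unittest
import torch

# Hypothetical stand-in for torchao's TORCH_VERSION_AT_LEAST_2_3 flag,
# shown only to illustrate the version-gating pattern in the diff above.
TORCH_VERSION_AT_LEAST_2_3 = tuple(
    int(p) for p in torch.__version__.split("+")[0].split(".")[:2]
) >= (2, 3)

class TestNumerics(unittest.TestCase):
    @unittest.skipIf(not torch.cuda.is_available(), "Need CUDA available")
    @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "newer dtypes not supported")
    def test_weight_t_and_non_t_numerics_match(self):
        # Skipped entirely on CPU-only machines or PyTorch < 2.3.
        ...
```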
