|
72 | 72 | AQInt8WeightOnlyQuantizedLinearWeight2,
|
73 | 73 | AQInt8WeightOnlyQuantizedLinearWeight3,
|
74 | 74 | AutoQuantizableLinearWeight,
|
75 |
| - |
| 75 | + AQFloat8WeightOnlyQuantizedLinearWeight, |
76 | 76 | )
|
77 | 77 | from torch.ao.quantization.quantize_fx import convert_to_reference_fx, prepare_fx
|
78 | 78 | import os
|
|
98 | 98 | COMMON_DTYPES = [torch.float32, torch.float16, torch.bfloat16]
|
99 | 99 |
|
100 | 100 | COMMON_DEVICE_DTYPE = list(itertools.product(COMMON_DEVICES, COMMON_DTYPES)).copy()
|
| 101 | +is_H100 = torch.cuda.is_available() and torch.cuda.get_device_capability() >= (8, 9)  # NOTE(review): (8, 9) is sm_89 (Ada, e.g. L4/RTX 4090), not H100 (sm_90 = (9, 0)); name and "Need H100" skip message overstate the requirement — consider renaming to is_sm_89 or tightening the check |
101 | 102 |
|
102 | 103 | def _int8wo_api(mod):
|
103 | 104 | if TORCH_VERSION_AT_LEAST_2_4:
|
@@ -744,6 +745,14 @@ def test_aq_int8_weight_only_quant_3_subclass(self, device, dtype):
|
744 | 745 | AQInt8WeightOnlyQuantizedLinearWeight3.from_float, device, 35, test_dtype=dtype
|
745 | 746 | )
|
746 | 747 |
|
| 748 | + @parameterized.expand(COMMON_DEVICE_DTYPE) |
| 749 | + @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "autoquant+aqt needs newer pytorch") |
| 750 | + @unittest.skipIf(not is_H100, "Need H100 to run") |
| 751 | + def test_aq_float8_weight_only_quant_subclass(self, device, dtype): |
| 752 | + self._test_lin_weight_subclass_impl( |
| 753 | + AQFloat8WeightOnlyQuantizedLinearWeight.from_float, device, 30, test_dtype=dtype |
| 754 | + ) |
| 755 | + |
747 | 756 | @parameterized.expand(COMMON_DEVICE_DTYPE)
|
748 | 757 | @unittest.skipIf(not TORCH_VERSION_AT_LEAST_2_3, "int4 requires torch nightly.")
|
749 | 758 | # @unittest.skipIf(TORCH_VERSION_AT_LEAST_2_5, "int4 skipping 2.5+ for now")
|
|
0 commit comments