18
18
fake_quantize_per_token ,
19
19
)
20
20
from torchao .quantization .quant_primitives import get_group_qparams_symmetric
21
- from torchao .quantization .utils import TORCH_VERSION_AFTER_2_3
21
+ from torchao .quantization .utils import TORCH_VERSION_AFTER_2_4
22
22
23
23
24
24
# TODO: put this in a common test utils file
@@ -58,7 +58,7 @@ def _get_qmin_qmax(self, n_bit: int):
58
58
qmax = 2 ** (n_bit - 1 ) - 1
59
59
return (qmin , qmax )
60
60
61
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
61
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
62
62
def test_fake_quantize_per_channel_group (self ):
63
63
n_bit = 4
64
64
(qmin , qmax ) = self ._get_qmin_qmax (n_bit )
@@ -84,7 +84,7 @@ def test_fake_quantize_per_channel_group(self):
84
84
)
85
85
torch .testing .assert_close (out , out_ptq , atol = 0 , rtol = 0 )
86
86
87
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
87
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
88
88
def test_fake_quantize_per_token (self ):
89
89
(qmin , qmax ) = self ._get_qmin_qmax (8 )
90
90
@@ -130,7 +130,7 @@ def _set_ptq_weight(
130
130
ptq_linear .scales = s
131
131
ptq_linear .zeros = zp
132
132
133
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
133
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
134
134
def test_qat_8da4w_linear (self ):
135
135
from torchao .quantization .prototype .qat import Int8DynActInt4WeightQATLinear
136
136
from torchao .quantization .GPTQ import Int8DynActInt4WeightLinear
@@ -155,7 +155,7 @@ def test_qat_8da4w_linear(self):
155
155
ptq_out = ptq_linear (x2 )
156
156
torch .testing .assert_close (ptq_out , qat_out , atol = 0 , rtol = 0 )
157
157
158
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
158
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
159
159
def test_qat_8da4w_quantizer (self ):
160
160
from torchao .quantization .prototype .qat import Int8DynActInt4WeightQATQuantizer
161
161
from torchao .quantization .GPTQ import Int8DynActInt4WeightQuantizer
@@ -189,7 +189,7 @@ def test_qat_8da4w_quantizer(self):
189
189
for k in ptq_state_dict .keys ():
190
190
torch .testing .assert_close (ptq_state_dict [k ], converted_state_dict [k ], atol = 0 , rtol = 0 )
191
191
192
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
192
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
193
193
def test_qat_8da4w_quantizer_meta_weights (self ):
194
194
from torchao .quantization .prototype .qat import Int8DynActInt4WeightQATQuantizer
195
195
@@ -201,7 +201,7 @@ def test_qat_8da4w_quantizer_meta_weights(self):
201
201
qat_model = qat_quantizer .prepare (m )
202
202
self .assertTrue (all (v .is_meta for v in qat_model .state_dict ().values ()))
203
203
204
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
204
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
205
205
def test_qat_8da4w_quantizer_disable_fake_quant (self ):
206
206
"""
207
207
Test that 8da4w QAT with disabled fake quant matches nn.Linear in forward.
@@ -254,7 +254,7 @@ def test_qat_8da4w_quantizer_disable_fake_quant(self):
254
254
qat_out2 = qat_model2 (* x2 )
255
255
torch .testing .assert_close (qat_out , qat_out2 , atol = 0 , rtol = 0 )
256
256
257
- @unittest .skipIf (not TORCH_VERSION_AFTER_2_3 , "skipping when torch verion is 2.3 or lower" )
257
+ @unittest .skipIf (not TORCH_VERSION_AFTER_2_4 , "skipping when torch verion is 2.4 or lower" )
258
258
def test_qat_8da4w_quantizer_disable_fake_quant_backward (self ):
259
259
"""
260
260
Test that 8da4w QAT with disabled fake quant matches nn.Linear in backward.
0 commit comments