
Commit f9dfd54

Skip some tests for torch 2.4 (#1981)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
1 parent: 46d9192

3 files changed, 7 insertions(+), 7 deletions(-)

neural_compressor/torch/algorithms/pt2e_quant/utility.py

Lines changed: 3 additions & 3 deletions

@@ -21,7 +21,7 @@
 from torch.ao.quantization.quantizer import QuantizationSpec
 from torch.ao.quantization.quantizer.x86_inductor_quantizer import QuantizationConfig, X86InductorQuantizer
 
-from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2
+from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5
 
 
 def create_quant_spec_from_config(dtype, sym, granularity, algo, is_dynamic=False) -> QuantizationSpec:
@@ -102,8 +102,8 @@ def create_xiq_quantizer_from_pt2e_config(config, is_dynamic=False) -> X86Induct
     # set global
     global_config = _map_inc_config_to_torch_quant_config(config, is_dynamic)
     quantizer.set_global(global_config)
-    # need torch >= 2.3.2
-    if GT_TORCH_VERSION_2_3_2:  # pragma: no cover
+    # need torch >= 2.5
+    if GT_OR_EQUAL_TORCH_VERSION_2_5:  # pragma: no cover
         op_type_config_dict, op_name_config_dict = config._get_op_name_op_type_config()
         if op_type_config_dict:
             for op_type, config in op_type_config_dict.items():
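The functional change here: the global quantizer config is always applied, but the per-op-type and per-op-name overrides are now gated on torch >= 2.5 instead of > 2.3.2, so on torch 2.4 only the global config takes effect. Below is a minimal sketch of that gating pattern; DummyQuantizer and configure are illustrative stand-ins, not the real X86InductorQuantizer API.

from packaging.version import Version

import torch

# Mirrors GT_OR_EQUAL_TORCH_VERSION_2_5 from neural_compressor.torch.utils:
# computed once at import time, True only on torch >= 2.5.
GT_OR_EQUAL_TORCH_VERSION_2_5 = Version(torch.__version__.split("+")[0]) >= Version("2.5")


class DummyQuantizer:
    """Illustrative stand-in for X86InductorQuantizer; not the real API."""

    def __init__(self):
        self.global_config = None
        self.op_type_configs = {}

    def set_global(self, config):
        self.global_config = config


def configure(quantizer, global_config, op_type_configs):
    quantizer.set_global(global_config)  # the global config is always set
    # Per-op overrides rely on quantizer support that needs torch >= 2.5,
    # so on torch 2.4 and older they are silently skipped.
    if GT_OR_EQUAL_TORCH_VERSION_2_5:
        quantizer.op_type_configs.update(op_type_configs)
    return quantizer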

neural_compressor/torch/utils/environ.py

Lines changed: 1 addition & 1 deletion

@@ -104,7 +104,7 @@ def get_torch_version():
     return version
 
 
-GT_TORCH_VERSION_2_3_2 = get_torch_version() > Version("2.3.2")
+GT_OR_EQUAL_TORCH_VERSION_2_5 = get_torch_version() >= Version("2.5")
 
 
 def get_accelerator(device_name="auto"):
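This one-line rename is the crux of the commit: a torch 2.4 build satisfies the old > Version("2.3.2") check but fails the new >= Version("2.5") check, which is what makes the guarded code paths and tests skip on 2.4. A quick check with packaging.version shows the boundary:

from packaging.version import Version

v = Version("2.4.0")  # e.g. a torch 2.4 build

print(v > Version("2.3.2"))                # True  -> the old flag let 2.4 through
print(v >= Version("2.5"))                 # False -> the new flag excludes 2.4
print(Version("2.5.0") >= Version("2.5"))  # True  -> torch 2.5+ still passes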

test/3x/torch/quantization/test_pt2e_quant.py

Lines changed: 3 additions & 3 deletions

@@ -15,7 +15,7 @@
     prepare,
     quantize,
 )
-from neural_compressor.torch.utils import GT_TORCH_VERSION_2_3_2, TORCH_VERSION_2_2_2, get_torch_version
+from neural_compressor.torch.utils import GT_OR_EQUAL_TORCH_VERSION_2_5, TORCH_VERSION_2_2_2, get_torch_version
 
 torch.manual_seed(0)
 
@@ -131,7 +131,7 @@ def calib_fn(model):
         logger.warning("out shape is %s", out.shape)
         assert out is not None
 
-    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.2")
+    @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
     def test_quantize_simple_model_with_set_local(self, force_not_import_ipex):
         model, example_inputs = self.build_simple_torch_model_and_example_inputs()
         float_model_output = model(*example_inputs)
@@ -243,7 +243,7 @@ def get_node_in_graph(graph_module):
                 nodes_in_graph[n] = 1
         return nodes_in_graph
 
-    @pytest.mark.skipif(not GT_TORCH_VERSION_2_3_2, reason="Requires torch>=2.3.0")
+    @pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
     def test_mixed_fp16_and_int8(self, force_not_import_ipex):
         model, example_inputs = self.build_model_include_conv_and_linear()
         model = export(model, example_inputs=example_inputs)
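For reference, pytest evaluates a skipif condition at collection time, so with the new flag these two tests are still collected but reported as skipped on any torch below 2.5, including the 2.4 builds this commit targets. A standalone sketch of the same pattern (the test name is illustrative):

import pytest
import torch
from packaging.version import Version

GT_OR_EQUAL_TORCH_VERSION_2_5 = Version(torch.__version__.split("+")[0]) >= Version("2.5")


@pytest.mark.skipif(not GT_OR_EQUAL_TORCH_VERSION_2_5, reason="Requires torch>=2.5")
def test_requires_torch_2_5():
    # Runs only on torch >= 2.5; otherwise pytest reports it as skipped
    # with the reason string above.
    assert True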
