File tree: 1 file changed, +4 −2 lines changed
4
5
5
from torch import nn
6
6
from torch.testing._internal.common_utils import TestCase, run_tests
7
- from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
7
+ from torchao.utils import TORCH_VERSION_AT_LEAST_2_5, unwrap_tensor_subclass
8
8
from torchao.dtypes import MarlinSparseLayoutType
9
9
from torchao.sparsity.sparse_api import apply_fake_sparsity
10
10
from torchao.quantization.quant_api import int4_weight_only, quantize_
@@ -55,7 +55,6 @@ def test_quant_sparse_marlin_layout_eager(self):
55
55
56
56
assert torch.allclose(dense_result, sparse_result, atol=3e-1), "Results are not close"
57
57
58
- @pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_5, reason="Needs PyTorch 2.5")
59
58
@pytest.mark.skipif(not torch.cuda.is_available(), reason="Need CUDA available")
60
59
def test_quant_sparse_marlin_layout_compile(self):
61
60
apply_fake_sparsity(self.model)
@@ -68,6 +67,9 @@ def test_quant_sparse_marlin_layout_compile(self):
68
67
69
68
# Sparse + quantized
70
69
quantize_(self.model, int4_weight_only(layout_type=MarlinSparseLayoutType()))
70
+ if not TORCH_VERSION_AT_LEAST_2_5:
71
+     unwrap_tensor_subclass(self.model)
72
+
71
73
self.model.forward = torch.compile(self.model.forward, fullgraph=True)
72
74
sparse_result = self.model(self.input)
73
75
You can’t perform that action at this time.
0 commit comments