Skip to content

Commit 16b3979

Browse files
skip
1 parent 4b88373 commit 16b3979

File tree

1 file changed

+11
-13
lines changed

1 file changed

+11
-13
lines changed

test/prototype/scaled_grouped_mm/test_kernels.py

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -4,28 +4,26 @@
44
# This source code is licensed under the BSD 3-Clause license found in the
55
# LICENSE file in the root directory of this source tree.
66

7-
import logging
8-
97
import pytest
108
import torch
119

10+
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
11+
12+
# We need to skip before doing any imports which would use triton, since
13+
# triton won't be available on CPU builds and torch < 2.5
14+
if not (TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available()):
15+
pytest.skip("Unsupported PyTorch version", allow_module_level=True)
16+
17+
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
18+
triton_fp8_col_major_jagged_colwise_scales,
19+
triton_fp8_row_major_jagged_rowwise_scales,
20+
)
1221
from torchao.prototype.scaled_grouped_mm.utils import (
1322
_is_column_major,
1423
_to_2d_jagged_float8_tensor_colwise,
1524
_to_2d_jagged_float8_tensor_rowwise,
1625
)
1726

18-
logging.basicConfig(level=logging.INFO)
19-
logger = logging.getLogger(__name__)
20-
21-
# triton only ships with pytorch cuda builds, so do import conditionally.
22-
if torch.cuda.is_available():
23-
from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import (
24-
triton_fp8_col_major_jagged_colwise_scales,
25-
triton_fp8_row_major_jagged_rowwise_scales,
26-
)
27-
from torchao.utils import TORCH_VERSION_AT_LEAST_2_5
28-
2927

3028
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
3129
@pytest.mark.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "torch 2.5+ required")

0 commit comments

Comments
 (0)