|
4 | 4 | # This source code is licensed under the BSD 3-Clause license found in the
|
5 | 5 | # LICENSE file in the root directory of this source tree.
|
6 | 6 |
|
7 |
| -import logging |
8 |
| - |
9 | 7 | import pytest
|
10 | 8 | import torch
|
11 | 9 |
|
| 10 | +from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
| 11 | + |
| 12 | +# We need to skip before doing any imports which would use triton, since |
| 13 | +# triton won't be available on CPU builds and torch < 2.5 |
| 14 | +if not (TORCH_VERSION_AT_LEAST_2_5 and torch.cuda.is_available()): |
| 15 | + pytest.skip("Unsupported PyTorch version", allow_module_level=True) |
| 16 | + |
| 17 | +from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( |
| 18 | + triton_fp8_col_major_jagged_colwise_scales, |
| 19 | + triton_fp8_row_major_jagged_rowwise_scales, |
| 20 | +) |
12 | 21 | from torchao.prototype.scaled_grouped_mm.utils import (
|
13 | 22 | _is_column_major,
|
14 | 23 | _to_2d_jagged_float8_tensor_colwise,
|
15 | 24 | _to_2d_jagged_float8_tensor_rowwise,
|
16 | 25 | )
|
17 | 26 |
|
18 |
| -logging.basicConfig(level=logging.INFO) |
19 |
| -logger = logging.getLogger(__name__) |
20 |
| - |
21 |
| -# triton only ships with pytorch cuda builds, so do import conditionally. |
22 |
| -if torch.cuda.is_available(): |
23 |
| - from torchao.prototype.scaled_grouped_mm.kernels.jagged_float8_scales import ( |
24 |
| - triton_fp8_col_major_jagged_colwise_scales, |
25 |
| - triton_fp8_row_major_jagged_rowwise_scales, |
26 |
| - ) |
27 |
| -from torchao.utils import TORCH_VERSION_AT_LEAST_2_5 |
28 |
| - |
29 | 27 |
|
30 | 28 | @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
31 | 29 | @pytest.mark.skipIf(not TORCH_VERSION_AT_LEAST_2_5, "torch 2.5+ required")
|
|
0 commit comments