Commit 38b1f45

enable tests on mx_formats + Blackwell (#1905)

Update [ghstack-poisoned]

1 parent: b1ecf65

File tree

5 files changed (+13, -43 lines):
- test/prototype/mx_formats/test_custom_cast.py
- test/prototype/mx_formats/test_mx_linear.py
- test/prototype/mx_formats/test_mx_mm.py
- test/prototype/mx_formats/test_mx_tensor.py
- torchao/utils.py


test/prototype/mx_formats/test_custom_cast.py (6 additions & 17 deletions)

@@ -42,10 +42,13 @@
     sem_vals_to_f32,
 )
 from torchao.prototype.mx_formats.mx_tensor import MXTensor
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100

 torch.manual_seed(0)

+if not TORCH_VERSION_AT_LEAST_2_8:
+    pytest.skip("Unsupported PyTorch version", allow_module_level=True)
+

 @pytest.mark.skip(
     reason="TODO debug CI failure, low pri since this is not used in the MX code"  # noqa: E501
@@ -311,10 +314,7 @@ def test_fp4_pack_unpack():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
+@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
 def test_fp4_triton_unscaled_cast():
     packed_vals = torch.arange(0, 255, dtype=torch.uint8, device="cuda")
     f32_ref = f4_unpacked_to_f32(unpack_uint4(packed_vals))
@@ -324,10 +324,7 @@ def test_fp4_triton_unscaled_cast():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
+@pytest.mark.skipif(is_sm_at_least_100(), reason="broken on CUDA capability 10.0")
 def test_fp4_triton_scaled_cast():
     size = (256,)
     orig_vals = torch.randn(size, dtype=torch.float, device="cuda") * 100
@@ -421,10 +418,6 @@ def test_fp6_e3m2_rounding(f32_val, f6_e3m2_enc, device):

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
 def test_fp6_e2m3_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 0.5, 7.5, -0.0], [-0.875, 1.0, -6.0, 0.125]]).to(
         "cuda"
@@ -440,10 +433,6 @@ def test_fp6_e2m3_pack_unpack():

 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
 @pytest.mark.skipif(not has_triton(), reason="unsupported without triton")
-@pytest.mark.skipif(not TORCH_VERSION_AT_LEAST_2_4, reason="requires PyTorch >= 2.4")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
 def test_fp6_e3m2_pack_unpack():
     orig_vals = torch.Tensor([[0.0, 5.0, 28.0, -0.0], [-0.25, 0.1875, 0.0625, 8.0]]).to(
         "cuda"

test/prototype/mx_formats/test_mx_linear.py (2 additions & 13 deletions)

@@ -25,14 +25,14 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_5,
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
     is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_5:
+if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -169,10 +169,6 @@ def test_activation_checkpointing():


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(),
-    reason="triton does not work yet on CUDA capability 10.0",
-)
 @pytest.mark.parametrize(
     "recipe_name",
     [
@@ -265,9 +261,6 @@ def test_inference_linear(elem_dtype, bias, input_shape):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 def test_inference_compile_simple(elem_dtype):
     """
@@ -294,10 +287,6 @@ def test_inference_compile_simple(elem_dtype):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(),
-    reason="triton does not work yet on CUDA capability 10.0",
-)
 @pytest.mark.skipif(
     not is_sm_at_least_100(),
     reason="MX gemms require CUDA capability 10.0",

test/prototype/mx_formats/test_mx_mm.py (2 additions & 2 deletions)

@@ -10,9 +10,9 @@
 from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
 from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_8, is_sm_at_least_100

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)

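The allow_module_level=True flag used in this gate is standard pytest behavior: calling pytest.skip(...) at import time normally raises a usage error, but with the flag it skips every test in the module. A minimal self-contained sketch of the pattern, with a hypothetical test function:

    import pytest
    import torch

    from torchao.utils import TORCH_VERSION_AT_LEAST_2_8

    # Module-level skip: without allow_module_level=True, calling pytest.skip()
    # outside of a test would raise an error during collection.
    if not TORCH_VERSION_AT_LEAST_2_8:
        pytest.skip("Unsupported PyTorch version", allow_module_level=True)

    def test_example():
        # hypothetical test; never collected when torch < 2.8
        assert torch.ones(2).sum().item() == 2.0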

test/prototype/mx_formats/test_mx_tensor.py (2 additions & 11 deletions)

@@ -25,14 +25,13 @@
 )
 from torchao.quantization.utils import compute_error
 from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
+    TORCH_VERSION_AT_LEAST_2_8,
     is_sm_at_least_89,
-    is_sm_at_least_100,
 )

 torch.manual_seed(2)

-if not TORCH_VERSION_AT_LEAST_2_4:
+if not TORCH_VERSION_AT_LEAST_2_8:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)


@@ -207,8 +206,6 @@ def test_transpose(elem_dtype, fp4_triton):
     """
     if elem_dtype != DTYPE_FP4 and fp4_triton:
         pytest.skip("unsupported configuration")
-    elif fp4_triton and is_sm_at_least_100():
-        pytest.skip("triton does not work yet on CUDA capability 10.0")

     M, K = 128, 256
     block_size = 32
@@ -265,9 +262,6 @@ def test_fp6_packing(elem_dtype, pack_fp6):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
 @pytest.mark.parametrize("elem_dtype", SUPPORTED_ELEM_DTYPES)
 @pytest.mark.parametrize("hp_dtype", [torch.float32, torch.bfloat16])
 @pytest.mark.parametrize("all_zeros", [False, True])
@@ -324,9 +318,6 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):


 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
-@pytest.mark.skipif(
-    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
-)
 @pytest.mark.skipif(
     not is_sm_at_least_89(),
     reason="float8 in triton requires CUDA capability 8.9 or greater",

torchao/utils.py (1 addition & 0 deletions)

@@ -356,6 +356,7 @@ def torch_version_at_least(min_version):
     return is_fbcode() or compare_versions(torch.__version__, min_version) >= 0


+TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0")
 TORCH_VERSION_AT_LEAST_2_7 = torch_version_at_least("2.7.0")
 TORCH_VERSION_AT_LEAST_2_6 = torch_version_at_least("2.6.0")
 TORCH_VERSION_AT_LEAST_2_5 = torch_version_at_least("2.5.0")
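For context, torch_version_at_least() returns True for fbcode builds or when torch.__version__ is at least min_version. A simplified sketch of the comparison, assuming the packaging library for parsing; torchao's actual compare_versions helper may treat pre-release and local-build suffixes differently:

    import torch
    from packaging import version

    def torch_version_at_least(min_version: str) -> bool:
        # Drop local build metadata such as "+cu128" so nightly and source
        # builds compare on the numeric portion of the version string.
        base = torch.__version__.split("+")[0]
        return version.parse(base) >= version.parse(min_version)

    TORCH_VERSION_AT_LEAST_2_8 = torch_version_at_least("2.8.0")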
