Skip to content

Commit db17c3d

Browse files
authored
Merge branch 'pytorch:main' into rocm_autoquant
2 parents 66b89a2 + 2c901b3 commit db17c3d

34 files changed

+2772
-151
lines changed

.github/workflows/build_wheels_linux.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -28,7 +28,7 @@ jobs:
2828
os: linux
2929
with-cpu: enable
3030
with-cuda: enable
31-
with-rocm: disable
31+
with-rocm: enable
3232
with-xpu: enable
3333
# Note: if free-threaded python is required add py3.13t here
3434
python-versions: '["3.9"]'

benchmarks/mx_formats/cast_bench.py

Lines changed: 7 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,11 +1,17 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the BSD 3-Clause license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
17
from typing import Callable, Tuple
28

39
import fire
410
import torch
511
import triton
612
from torch._inductor.utils import do_bench_using_profiling
713

8-
from torchao.prototype.mx_formats.custom_cast import (
14+
from torchao.prototype.mx_formats.kernels import (
915
triton_to_mxfp8_dim1,
1016
)
1117
from torchao.prototype.mx_formats.mx_tensor import to_mx

setup.py

Lines changed: 11 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -311,16 +311,17 @@ def get_extensions():
311311
glob.glob(os.path.join(extensions_cuda_dir, "**/*.cu"), recursive=True)
312312
)
313313

314-
extensions_hip_dir = os.path.join(
315-
extensions_dir, "cuda", "tensor_core_tiled_layout"
316-
)
317-
hip_sources = list(
318-
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
319-
)
320-
extensions_hip_dir = os.path.join(extensions_dir, "cuda", "sparse_marlin")
321-
hip_sources += list(
322-
glob.glob(os.path.join(extensions_hip_dir, "*.cu"), recursive=True)
323-
)
314+
# Define HIP source directories
315+
hip_source_dirs = [
316+
os.path.join(extensions_dir, "cuda", "tensor_core_tiled_layout"),
317+
# TODO: Add sparse_marlin back in once we have a ROCm build for it
318+
# os.path.join(extensions_dir, "cuda", "sparse_marlin")
319+
]
320+
321+
# Collect all HIP sources from the defined directories
322+
hip_sources = []
323+
for hip_dir in hip_source_dirs:
324+
hip_sources.extend(glob.glob(os.path.join(hip_dir, "*.cu"), recursive=True))
324325

325326
# Collect CUDA source files if needed
326327
if not IS_ROCM and use_cuda:

test/prototype/mx_formats/test_custom_cast.py renamed to test/prototype/mx_formats/test_kernels.py

Lines changed: 33 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,17 @@
1616
F6_E2M3_EXP_BIAS,
1717
F6_E3M2_EXP_BIAS,
1818
)
19-
from torchao.prototype.mx_formats.custom_cast import (
19+
from torchao.prototype.mx_formats.fp_format_spec import (
20+
_assert_equals,
21+
dtype_to_interesting_values,
22+
float4_e2m1_interesting_values,
23+
float6_e2m3_interesting_values,
24+
float6_e3m2_interesting_values,
25+
get_sem_bits,
26+
sem_bits_to_sem_vals,
27+
sem_vals_to_f32,
28+
)
29+
from torchao.prototype.mx_formats.kernels import (
2030
f4_unpacked_to_f32,
2131
f6_e2m3_unpacked_to_f32,
2232
f6_e3m2_unpacked_to_f32,
@@ -33,17 +43,8 @@
3343
triton_to_mxfp8_dim1_reference,
3444
unpack_uint4,
3545
)
36-
from torchao.prototype.mx_formats.fp_format_spec import (
37-
_assert_equals,
38-
dtype_to_interesting_values,
39-
float4_e2m1_interesting_values,
40-
float6_e2m3_interesting_values,
41-
float6_e3m2_interesting_values,
42-
get_sem_bits,
43-
sem_bits_to_sem_vals,
44-
sem_vals_to_f32,
45-
)
4646
from torchao.prototype.mx_formats.mx_tensor import MXTensor
47+
from torchao.prototype.mx_formats.utils import to_blocked
4748
from torchao.utils import (
4849
TORCH_VERSION_AT_LEAST_2_8,
4950
is_sm_at_least_89,
@@ -465,3 +466,24 @@ def test_triton_mxfp8_dim1_randn(M, K):
465466
x_mx_t, x_s_t = triton_to_mxfp8_dim1(x, inner_block_size=32)
466467
torch.testing.assert_close(x_mx_t, x_mx_ref, rtol=0, atol=0)
467468
torch.testing.assert_close(x_s_t, x_s_ref, rtol=0, atol=0)
469+
470+
471+
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
472+
@pytest.mark.parametrize(
473+
"shape",
474+
[
475+
(63, 1023),
476+
(128, 4),
477+
(128, 8),
478+
(256, 8),
479+
(300, 9),
480+
(133, 512),
481+
(528, 512),
482+
(128, 1),
483+
],
484+
)
485+
def test_rearrange(shape):
486+
scales = torch.randint(256, size=shape, device="cuda", dtype=torch.uint8)
487+
eager = to_blocked(scales, False)
488+
triton = to_blocked(scales, True)
489+
torch.testing.assert_close(eager, triton, atol=0, rtol=0)

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -17,7 +17,7 @@
1717
DTYPE_FP6_E3M2,
1818
SUPPORTED_ELEM_DTYPES,
1919
)
20-
from torchao.prototype.mx_formats.custom_cast import pack_uint4, pack_uint6
20+
from torchao.prototype.mx_formats.kernels import pack_uint4, pack_uint6
2121
from torchao.prototype.mx_formats.mx_tensor import (
2222
MXTensor,
2323
ScaleCalculationMode,

0 commit comments

Comments
 (0)