
Commit 8d110bf

modify cast from hp to mx to help inductor fuse (#1786)
Summary:

Thanks to investigation from @eellison, moving the reshape to after the cast helps inductor fuse the cast into a single kernel. This doesn't yet work with fp4, but let's unblock fp8 and deal with fp4 later.

Fixes #1690

Note: in the repro with swizzling from #1773, we go from 3 to 2 kernels. Further investigation is needed into whether we can fuse the swizzling.

Test Plan:

```
pytest test/prototype/mx_formats/test_mx_tensor.py -x -s -k test_to_mx_inductor_single_kernel
```

Reviewers:

Subscribers:

Tasks:

Tags:
1 parent d00ee41 commit 8d110bf
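
For context, here is a minimal standalone sketch (not the torchao implementation itself) of the pattern this commit switches to: scale and cast in the blocked layout first, reshape back to the original shape only after the cast to fp8, then count the kernels inductor generates with `run_and_get_code`. The per-block scale below is a simplified, hypothetical stand-in for the real e8m0 scale logic in `MXTensor.to_mx`.

```
import torch
from torch._inductor.utils import run_and_get_code

def cast_then_reshape(x_hp, block_size=32):
    # view the tensor as blocks for per-block scaling
    orig_shape = x_hp.shape
    x_hp = x_hp.reshape(-1, block_size)
    # hypothetical per-block scale, a stand-in for the real e8m0 scale computation
    scale = x_hp.abs().amax(dim=1, keepdim=True).clamp(min=1e-12)
    x_lp = (x_hp / scale).to(torch.float8_e4m3fn)
    # reshape after the cast, as in this commit, so inductor can fuse everything
    return x_lp.reshape(orig_shape)

if torch.cuda.is_available():
    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
    compiled = torch.compile(cast_then_reshape, fullgraph=True)
    _, code = run_and_get_code(compiled, x)
    # with the reshape last, the generated wrapper should launch a single kernel
    print(code[0].count(".run("))
```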

2 files changed: 34 additions, 1 deletion

test/prototype/mx_formats/test_mx_tensor.py

Lines changed: 24 additions & 0 deletions
```
@@ -6,6 +6,8 @@
 
 import pytest
 import torch
+from torch._inductor.utils import run_and_get_code
+from torch.testing import FileCheck
 
 from torchao.prototype.mx_formats.config import MXGemmKernelChoice
 from torchao.prototype.mx_formats.constants import (
@@ -284,3 +286,25 @@ def test_to_mx_from_mx_compile_numerics(elem_dtype, hp_dtype, all_zeros):
         use_fp4_custom_triton_dequant_kernel,
     )
     torch.testing.assert_close(x_mx_dq, x_mx_c_dq, atol=0, rtol=0)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+@pytest.mark.skipif(
+    is_sm_at_least_100(), reason="triton does not work yet on CUDA capability 10.0"
+)
+@pytest.mark.skipif(
+    not is_sm_at_least_89(),
+    reason="float8 in triton requires CUDA capability 8.9 or greater",
+)
+def test_to_mx_inductor_single_kernel():
+    """
+    Verify that inductor can fuse the cast of a high precision tensor to mx
+    into a single kernel
+    """
+    # TODO(future PR): add fp4 and fp6 here
+    # TODO(#1773): add swizzled scale format here
+    x = torch.randn(2048, 2048, dtype=torch.bfloat16, device="cuda")
+    block_size = 32
+    to_mx_c = torch.compile(MXTensor.to_mx, fullgraph=True)
+    out, code = run_and_get_code(to_mx_c, x, torch.float8_e4m3fn, block_size)
+    FileCheck().check("def call(").check_count(".run(", 1, exactly=True).run(code[0])
```

torchao/prototype/mx_formats/mx_tensor.py

Lines changed: 10 additions & 1 deletion
```
@@ -205,16 +205,25 @@ def to_mx(
     data_lp = torch.clamp(
         data_hp / scale_fp.unsqueeze(1), min=-1 * max_pos, max=max_pos
     )
-    data_lp = data_lp.reshape(orig_shape)
 
     # cast to target dtype
     if elem_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
         data_lp = data_lp.to(elem_dtype)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E2M3:
         data_lp = f32_to_f6_e2m3_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP6_E3M2:
         data_lp = f32_to_f6_e3m2_unpacked(data_lp)
+        # need to reshape at the end to help inductor fuse things
+        data_lp = data_lp.reshape(orig_shape)
     elif elem_dtype == DTYPE_FP4:
+        # can't reshape at the end without handling it in the packing code,
+        # punt until later since we'll need to rethink the torch.compile
+        # approach for fp4x2 in any case
+        data_lp = data_lp.reshape(orig_shape)
         data_lp = f32_to_f4_unpacked(data_lp)
         data_lp = pack_uint4(data_lp)
     else:
```
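
As a side note on the fp4 branch above: the reshape can't simply move to after the packing step, because packing two 4-bit values per byte halves the last dimension, so the original shape no longer describes the packed storage. A toy illustration of that shape change (not torchao's actual pack_uint4):

```
import torch

def pack_uint4_sketch(x_u8):
    # pack pairs of 4-bit values along the last dim into single bytes
    # (assumes the last dim is even and each value already fits in 4 bits)
    lo = x_u8[..., 0::2]
    hi = x_u8[..., 1::2]
    return (hi << 4) | lo

vals = torch.randint(0, 16, (4, 8), dtype=torch.uint8)
packed = pack_uint4_sketch(vals)
print(vals.shape, packed.shape)  # torch.Size([4, 8]) torch.Size([4, 4])
```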
