Commit 90ee0ff

Add mx_fp4_kernel
stack-info: PR: #1661, branch: drisspg/stack/34
1 parent 7b9df4d commit 90ee0ff

3 files changed: +103 -67 lines changed


test/prototype/mx_formats/test_mx_mm.py

Lines changed: 31 additions & 59 deletions
@@ -2,63 +2,45 @@
 import torch
 
 from torchao.float8.float8_utils import compute_error
-from torchao.ops import mx_fp8_bf16
-from torchao.prototype.mx_formats.mx_tensor import MXTensor
+from torchao.ops import mx_fp4_bf16, mx_fp8_bf16
+from torchao.prototype.mx_formats.mx_tensor import DTYPE_FP4, MXTensor
 from torchao.prototype.mx_formats.utils import to_blocked
-from torchao.utils import (
-    TORCH_VERSION_AT_LEAST_2_4,
-    is_sm_at_least_100,
-)
+from torchao.utils import TORCH_VERSION_AT_LEAST_2_4, is_sm_at_least_100
 
 if not TORCH_VERSION_AT_LEAST_2_4:
     pytest.skip("Unsupported PyTorch version", allow_module_level=True)
 
 
-def run_matrix_test(M: int, K: int, N: int) -> float:
-    """
-    Run matrix multiplication test with given dimensions.
-
-    Args:
-        M, K, N: Matrix dimensions
-
-    Returns:
-        float: SQNR (Signal-to-Quantization-Noise Ratio) value
-    """
+def run_matrix_test(M: int, K: int, N: int, format) -> float:
     dtype = torch.bfloat16
     device = torch.device("cuda")
 
-    # Initialize matrices
     a = torch.rand((M, K), dtype=dtype, device=device)
     b = torch.rand((N, K), dtype=dtype, device=device)
 
-    # Convert to MX format
-    a_mx = MXTensor.to_mx(a, torch.float8_e4m3fn, 32)
-    b_mx = MXTensor.to_mx(b, torch.float8_e4m3fn, 32)
+    fmt = torch.float8_e4m3fn if format == "fp8" else DTYPE_FP4
+    mx_func = mx_fp8_bf16 if format == "fp8" else mx_fp4_bf16
 
-    a_fp8 = a_mx._data
-    b_fp8 = b_mx._data
-    assert b_fp8.is_contiguous()
-    b_fp8 = b_fp8.transpose(-1, -2)
+    a_mx = MXTensor.to_mx(a, fmt, 32)
+    b_mx = MXTensor.to_mx(b, fmt, 32)
 
-    # Get scales
-    a_scale_e8 = a_mx._scale_e8m0.view(M, K // 32)
-    b_scale_e8 = b_mx._scale_e8m0.view(N, K // 32)
+    a_data = a_mx._data
+    b_data = b_mx._data
+    assert b_data.is_contiguous()
+    b_data = b_data.transpose(-1, -2)
 
-    a_scale_block = to_blocked(a_scale_e8)
-    b_scale_block = to_blocked(b_scale_e8)
+    a_scale = a_mx._scale_e8m0.view(M, K // 32)
+    b_scale = b_mx._scale_e8m0.view(N, K // 32)
+
+    a_scale_block = to_blocked(a_scale)
+    b_scale_block = to_blocked(b_scale)
 
-    # Get reference output
     out_hp = a_mx.to_dtype(torch.bfloat16) @ b_mx.to_dtype(torch.bfloat16).transpose(
         -1, -2
     )
+    out = mx_func(a_data, b_data, a_scale_block, b_scale_block)
 
-    # Run implementation
-    out_e8_fp8 = mx_fp8_bf16(a_fp8, b_fp8, a_scale_block, b_scale_block)
-
-    # Calculate metrics
-    sqnr = compute_error(out_hp, out_e8_fp8)
-
-    return sqnr.item()
+    return compute_error(out_hp, out).item()
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
@@ -68,35 +50,25 @@ def run_matrix_test(M: int, K: int, N: int) -> float:
 @pytest.mark.parametrize(
     "size",
     [
-        # Small matrices
         (128, 128, 128),
         (256, 256, 256),
-        (384, 384, 384),
-        # Medium matrices
+        (384, 384, 384),  # Small
         (512, 512, 512),
-        (640, 640, 640),
-        (768, 768, 768),
-        # Large matrices
-        (896, 896, 896),
+        (768, 768, 768),  # Medium
         (1024, 1024, 1024),
-        # Very large matrices
-        (8192, 8192, 8192),
-        # Non-square matrices
+        (8192, 8192, 8192),  # Large
         (128, 256, 384),
-        (256, 384, 512),
-        (384, 512, 640),
-        # Non-aligned matrices
+        (256, 384, 512),  # Non-square
         (129, 256, 384),
-        (256, 384, 536),
-        (133, 512, 528),
+        (133, 512, 528),  # Non-aligned
     ],
     ids=lambda x: f"{x[0]}x{x[1]}x{x[2]}",
 )
-def test_matrix_multiplication(size):
-    """
-    Test matrix multiplication with various dimensions.
-    Verifies that the SQNR meets minimum quality threshold.
-    """
+@pytest.mark.parametrize("format", ["fp8", "fp4"])
+def test_matrix_multiplication(size, format):
     M, K, N = size
-    sqnr = run_matrix_test(M, K, N)
-    assert sqnr >= 80.0, f"SQNR {sqnr} below threshold for dims {M}x{K}x{N}"
+    sqnr = run_matrix_test(M, K, N, format)
+    threshold = 80.0
+    assert (
+        sqnr >= threshold
+    ), f"{format} SQNR {sqnr} below threshold for dims {M}x{K}x{N}"

torchao/csrc/cuda/mx_kernels/mx_fp8_bf16.cu renamed to torchao/csrc/cuda/mx_kernels/mx_fp_bf16.cu

Lines changed: 42 additions & 8 deletions
@@ -34,7 +34,7 @@ using namespace cute;
 
 template<typename Element>
 constexpr int GetAlignment() {
-  if constexpr (std::is_same_v<Element, cutlass::nv_float4_t<cutlass::float_e2m1_t>>)
+  if constexpr (std::is_same_v<Element, cutlass::mx_float4_t<cutlass::float_e2m1_t>>)
     return 32;
   return 16;
 }
@@ -46,11 +46,7 @@ template <typename ElementA,
           typename ClusterShape,
           typename PerSmTileShape_MNK>
 void run_gemm(at::Tensor& a, at::Tensor& b, at::Tensor& a_scale,
-              at::Tensor& b_scale, at::Tensor& out) {
-  int M = a.size(0);
-  int K = a.size(1);
-  int N = b.size(1);
-
+              at::Tensor& b_scale, at::Tensor& out, int M, int K, int N) {
   // A matrix configuration
   using LayoutATag = cutlass::layout::RowMajor; // Layout type for A matrix operand
   constexpr int AlignmentA = GetAlignment<ElementA>(); // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)
@@ -225,9 +221,12 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
                        at::Tensor b_scale) {
 #if defined(BUILD_MX_KERNELS_CUTLASS)
   validate(a, b, a_scale, b_scale);
+  auto M = a.size(0);
+  auto K = a.size(1);
+  auto N = b.size(1);
 
   auto out =
-      at::empty({a.size(0), b.size(1)}, a.options().dtype(at::kBFloat16));
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
   using ElementA = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementB = cutlass::mx_float8_t<cutlass::float_e4m3_t>;
   using ElementD = cutlass::bfloat16_t;
@@ -236,16 +235,51 @@ at::Tensor mx_fp8_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
   using ClusterShape = Shape<_2,_1,_1>;
   using PerSmTileShape_MNK = Shape<_128,_128,_128>;
 
-  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out);
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
   return out;
 #else
   TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
   return at::Tensor{};
 #endif
 }
 
+at::Tensor mx_fp4_bf16(at::Tensor a, at::Tensor b, at::Tensor a_scale,
+                       at::Tensor b_scale) {
+#if defined(BUILD_MX_KERNELS_CUTLASS)
+  TORCH_CHECK(a.is_cuda(), "a must be CUDA tensor");
+  TORCH_CHECK(b.is_cuda(), "b must be CUDA tensor");
+  TORCH_CHECK(a_scale.is_cuda(), "a_scale must be CUDA tensor");
+  TORCH_CHECK(b_scale.is_cuda(), "b_scale must be CUDA tensor");
+
+  auto M = a.size(0);
+  auto K = a.size(1) * 2;
+  auto N = b.size(1);
+
+  auto out =
+      at::empty({M, N}, a.options().dtype(at::kBFloat16));
+  using ElementA = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementB = cutlass::mx_float4_t<cutlass::float_e2m1_t>;
+  using ElementD = cutlass::bfloat16_t;
+
+  using MmaTileShape = Shape<_128,_128,_128>;
+  using ClusterShape = Shape<_2,_1,_1>;
+  using PerSmTileShape_MNK = Shape<_128,_128,_128>;
+
+  run_gemm<ElementA, ElementB, ElementD, MmaTileShape, ClusterShape, PerSmTileShape_MNK>(a, b, a_scale, b_scale, out, M, K, N);
+  return out;
+#else
+  TORCH_CHECK_NOT_IMPLEMENTED(false, __func__);
+  return at::Tensor{};
+#endif
+}
+
 TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
   m.impl("torchao::mx_fp8_bf16", &mx_fp8_bf16);
 }
+TORCH_LIBRARY_IMPL(torchao, CUDA, m) {
+  m.impl("torchao::mx_fp4_bf16", &mx_fp4_bf16);
+}
 
 } // namespace torchao
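One detail worth noting in the new fp4 entry point: the packed payload stores two e2m1 values per byte, so the logical contraction dimension handed to run_gemm is twice the stored byte width (K = a.size(1) * 2), while M and N come straight from the packed shapes. A small shape-bookkeeping sketch of that relationship, with illustrative values and hypothetical names:

```python
# Shape bookkeeping that mirrors mx_fp4_bf16 above (illustrative, not the kernel itself).
import torch

M, K, N = 128, 256, 384
# Packed fp4 operands: two e2m1 elements per uint8, so K elements occupy K // 2 bytes.
a_packed = torch.empty((M, K // 2), dtype=torch.uint8)
b_packed = torch.empty((N, K // 2), dtype=torch.uint8).transpose(-1, -2)

logical_K = a_packed.size(1) * 2                   # 256, the unpacked contraction dim
out_shape = (a_packed.size(0), b_packed.size(1))   # (128, 384), the bf16 output shape
```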

torchao/ops.py

Lines changed: 30 additions & 0 deletions
@@ -26,6 +26,7 @@
     "rowwise_scaled_linear_cutlass_s8s4(Tensor input, Tensor input_scale, Tensor weight, Tensor weight_scale, Tensor bias) -> Tensor"
 )
 lib.define("mx_fp8_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
+lib.define("mx_fp4_bf16(Tensor a, Tensor b, Tensor a_scale, Tensor b_scale) -> Tensor")
 
 
 def register_custom_op(name):
@@ -621,3 +622,32 @@ def mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
 def meta_mx_fp8_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
     """Meta impl for mx_fp8_bf16"""
     return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
+
+
+def mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Defines a matmul between two fp4 tensors w/ MX scales in E8M0 and returns a bf16 tensor.
+
+    This op is a prototype, subject to change.
+
+    Note: The MX scales are E8M0 values stored in uint8 tensors (for now).
+    The layout of the scales is very particular, see:
+    https://docs.nvidia.com/cuda/cublas/index.html#d-block-scaling-factors-layout
+
+    Args:
+        A: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        B: fp4 tensor (2 fp4 elements are packed into 1 byte -> elem0|elem1)
+        A_scale: E8M0 scale tensor for A with groupsize=32 in swizzled layout
+        B_scale: E8M0 scale tensor for B with groupsize=32 in swizzled layout
+
+    Returns:
+        MxN bf16 Tensor
+
+    """
+    return torch.ops.torchao.mx_fp4_bf16.default(A, B, A_scale, B_scale)
+
+
+@register_custom_op("torchao::mx_fp4_bf16")
+def meta_mx_fp4_bf16(A: Tensor, B: Tensor, A_scale: Tensor, B_scale: Tensor):
+    """Meta impl for mx_fp4_bf16"""
+    # Assume the contraction happens in the K dim, so M and N are preserved post bit-pack
+    return torch.empty((A.size(0), B.size(1)), dtype=torch.bfloat16, device=A.device)
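The registered meta impl is what lets shape and dtype propagation work without a GPU launch. A hedged sketch of that use, assuming meta-device dispatch reaches the registered meta kernel as it does for the existing mx_fp8_bf16 op; the operand shapes here are illustrative and do not reproduce the exact swizzled scale layout:

```python
# Sketch: shape/dtype propagation via the meta kernel, with no CUDA kernel launch.
import torch
from torchao.ops import mx_fp4_bf16

M, K, N = 128, 256, 384
A = torch.empty((M, K // 2), dtype=torch.uint8, device="meta")        # packed fp4
B = torch.empty((K // 2, N), dtype=torch.uint8, device="meta")        # packed fp4, transposed
A_scale = torch.empty((M, K // 32), dtype=torch.uint8, device="meta") # illustrative scale shape
B_scale = torch.empty((N, K // 32), dtype=torch.uint8, device="meta") # illustrative scale shape

out = mx_fp4_bf16(A, B, A_scale, B_scale)
print(out.shape, out.dtype)  # torch.Size([128, 384]) torch.bfloat16
```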
