
Commit 8554cb0

Rachmanino and tzj-fxz authored
[Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE (#811)
* [Enhancement] Enhance dequantization examples and utilities
  - Added a new example for grouped matrix multiplication with experts in `example_dequant_groupgemm_bf16_mxfp4_hopper.py`.
  - Improved dequantization logic in existing examples by replacing nested loops with vectorized operations for better performance.
  - Updated `torch_convert_bit_twiddling` function in `utils.py` to utilize parallel processing, enhancing efficiency and clarity in the conversion process.
  Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
* fix typos in docstrings
* remove redundant code
* [Format] Unreproducible debug with T.print
* [BugFix] Correct dtype in ref dequantize; larger data distribution
* [Format]
* [Refactor] Clean up and optimize example_dequant_groupgemm_bf16_mxfp4_hopper.py and utils.py
  - Removed unnecessary cache disabling and manual seed setting in the example.
  - Simplified nested loops into parallelized operations for better readability and performance.
  - Updated the assertion function in utils.py to print detailed error messages.
  - Adjusted tensor sizes in examples
* [Refactor] Update import path in example_dequant_gemm_fine_grained.py
  - Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.
* lint
* rename and add test
* lint
* [Feature] Enhance autotuning and configuration generation in example_dequant_groupedgemm_bf16_mxfp4_hopper.py
  - Added a new function `get_configs()` to generate hyperparameter configurations for tuning.
  - Updated the `matmul` function to utilize autotuning with the new configurations.
  - Improve kernel performance via vectorization and threadblock swizzle.
  - Enhanced the main function to support the new autotuning inputs and updated parameters for better performance.
* lint
* fix typo
* fix typo and lint
* make ci format check happy
* fix ci

---------

Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
Co-authored-by: tzj-fxz <tzjfxz@gmail.com>
1 parent e4a346f commit 8554cb0

File tree

7 files changed: +603 −47 lines changed


examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 4 additions & 12 deletions
@@ -389,9 +389,7 @@ def ref_program_twiddling(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -414,9 +412,7 @@ def ref_program_twiddling_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert_bit_twiddling(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -440,9 +436,7 @@ def ref_program_simple(A, qB, Scale, Bias=None):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -470,9 +464,7 @@ def ref_program_simple_with_bias(A, qB, Scale, Bias):
     """
     dtypeC = "bfloat16"
     B = torch_convert(qB)
-    for i in range(B.shape[0]):
-        for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
+    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // 32)])
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float)) + Bias
     C = C.to(torch.__getattribute__(dtypeC))
     return C
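
All four hunks make the same change: the per-element loop that applied the MXFP4 group scale is replaced by a single broadcast multiply. The following minimal sketch is not taken from the commit; the sizes, the group width of 32, and the use of float32 instead of bfloat16 are illustrative assumptions. It only checks that the two forms agree.

# Minimal sketch (not from the commit): the vectorized group scaling matches
# the nested-loop reference it replaces. N, K, the group size of 32, and
# float32 (instead of bfloat16) are illustrative assumptions.
import torch

N, K, group = 4, 64, 32
B_loop = torch.randn(N, K)
B_vec = B_loop.clone()
Scale = torch.randint(0, 4, (N, K // group))  # made-up per-group exponents

# Old form: scale every element by 2**Scale of its 32-wide column group.
for i in range(N):
    for j in range(K):
        B_loop[i][j] = B_loop[i][j] * (2**(Scale[i][j // group]))

# New form: gather each column's group exponent, then broadcast-multiply.
B_vec *= 2**(Scale[:, torch.arange(K) // group])

assert torch.allclose(B_loop, B_vec)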

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py

Lines changed: 1 addition & 1 deletion
@@ -23,7 +23,7 @@ def matmul(
     threads,
     num_bits=4,
 ):
-    from bitblas.quantization import _tir_packed_to_unsigned_convert
+    from tilelang.quantize import _tir_packed_to_unsigned_convert
     num_elems_per_byte = 8 // num_bits
     storage_dtype = "int8"
     storage_nbit = int("".join(c for c in storage_dtype if c.isdigit()))

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

Lines changed: 511 additions & 0 deletions
Large diffs are not rendered by default.
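
The 511-line example itself is not rendered in this diff. Purely as a hypothetical sketch of the `get_configs()` idea described in the commit message, a tuning-configuration generator is usually a cartesian product over tile sizes, pipeline stages, and thread counts; the parameter names and value ranges below are invented for illustration and are not taken from the file.

# Hypothetical sketch only: the real get_configs() in the new example is not
# shown in this diff; parameter names and value ranges here are invented.
import itertools

def get_configs():
    block_M = [64, 128]
    block_N = [64, 128, 256]
    block_K = [64, 128]
    num_stages = [2, 3]
    threads = [128, 256]
    keys = ("block_M", "block_N", "block_K", "num_stages", "threads")
    return [
        dict(zip(keys, combo))
        for combo in itertools.product(block_M, block_N, block_K, num_stages, threads)
    ]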

examples/dequantize_gemm/test_example_dequantize_gemm.py

Lines changed: 8 additions & 0 deletions
@@ -4,6 +4,7 @@
 import example_dequant_gemm_fp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper
 import example_dequant_gemm_bf16_mxfp4_hopper_tma
+import example_dequant_groupedgemm_bf16_mxfp4_hopper
 import example_dequant_gemm_w4a8


@@ -31,6 +32,13 @@ def test_example_dequant_gemm_bf16_mxfp4_hopper_tma():


 @tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+
+
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
 def test_example_dequant_gemm_w4a8():
     example_dequant_gemm_w4a8.main()


examples/dequantize_gemm/utils.py

Lines changed: 76 additions & 32 deletions
@@ -3,8 +3,6 @@
 
 def torch_convert_bit_twiddling(tensor):
     """
-    Convert a 2-D uint8 tensor into a bfloat16 tensor by decoding pairs of input bytes with a bit-twiddling scheme.
-
     This function expects `tensor` to be a 2-D torch.Tensor of dtype `torch.uint8`. Each output element is produced by combining two input bytes and extracting a bf16-like 16-bit pattern according to one of four positional bit layouts (pos 0..3). The result is scaled by 2**126 to adjust the exponent bias and returned as dtype `torch.bfloat16`.
 
     Parameters:
@@ -16,38 +14,46 @@ def torch_convert_bit_twiddling(tensor):
     Raises:
         AssertionError: If any byte inputs used for a conversion are not dtype `torch.uint8`.
     """
+    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    N, K = tensor.shape
+    assert K % 2 == 0, "Number of columns must be even"
 
-    def _convert(val0, val1, pos) -> torch.bfloat16:
-        assert val0.dtype == torch.uint8
-        assert val1.dtype == torch.uint8
-        val0 = val0.view(torch.uint8)
-        val1 = val1.view(torch.uint8)
-        val_concat = (val0.item() << 8) | val1.item()
-        mask = 0b1000000111000000
-        if pos == 0:
-            bf16 = val_concat & mask
-        elif pos == 1:
-            bf16 = (val_concat << 3) & mask
-        elif pos == 2:
-            bf16 = (val_concat << 6) & mask
-        elif pos == 3:
-            mask1 = 0b1000000000000000
-            mask2 = 0b0000000110000000
-            mask3 = 0b0000000001000000
-            bf16 = ((val_concat << 1) & mask1) | ((val_concat >> 3) & mask2) | (
-                (val_concat >> 7) & mask3)
-        bf16_new = torch.tensor([bf16], dtype=torch.uint16, device=val0.device).view(torch.bfloat16)
-        # Add bias for change from fp4 to bf16
-        bf16_new = bf16_new.item() * (2**126)
-        return bf16_new
+    # Combine pairs of uint8 values into int32 for safe bitwise ops on CUDA
+    val0 = tensor[:, 0::2].to(torch.int32)
+    val1 = tensor[:, 1::2].to(torch.int32)
+    val_concat = (val0 << 8) | val1  # (N, K//2), int32
 
-    N = tensor.shape[0]
-    K = tensor.shape[1]
-    new_tensor = torch.empty(N, K * 2, dtype=torch.bfloat16, device=tensor.device)
-    for i in range(new_tensor.shape[0]):
-        for j in range(new_tensor.shape[1]):
-            new_tensor[i][j] = _convert(tensor[i][j // 4 * 2], tensor[i][j // 4 * 2 + 1], j % 4)
-    return new_tensor
+    # Expand to match output shape where each pair generates 4 values
+    val_concat_expanded = val_concat.repeat_interleave(4, dim=1)  # (N, K//2*4)
+
+    # Positional encoding for bit-twiddling logic
+    pos = torch.arange(K * 2, device=tensor.device) % 4  # (K*2,)
+
+    # Bit masks for decoding
+    mask = 0b1000000111000000
+    mask1 = 0b1000000000000000
+    mask2 = 0b0000000110000000
+    mask3 = 0b0000000001000000
+
+    # Calculate results for all 4 positions in parallel
+    res0 = val_concat_expanded & mask
+    res1 = (val_concat_expanded << 3) & mask
+    res2 = (val_concat_expanded << 6) & mask
+    res3 = ((val_concat_expanded << 1) & mask1) | ((val_concat_expanded >> 3) & mask2) | (
+        (val_concat_expanded >> 7) & mask3)
+
+    # Select the correct result based on position
+    bf16 = torch.where(pos == 0, res0, torch.where(pos == 1, res1,
+                                                   torch.where(pos == 2, res2, res3)))
+
+    # Convert to uint16 so .view(torch.bfloat16) reinterprets the bit pattern
+    bf16_uint16 = (bf16 & 0xFFFF).to(torch.uint16)
+    bf16_bf16 = bf16_uint16.view(torch.bfloat16)
+
+    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
+    bf16_new = bf16_bf16 * (2.0**126)
+
+    return bf16_new
 
 
 def torch_convert(tensor, scale_size=None, Scale=None):
@@ -106,3 +112,41 @@ def print_bit(name, val):
     val_cpu = val.cpu().item()
     binary_repr = f'{val_cpu:032b}'
     print(name, binary_repr)
+
+
+def print_red_warning(message):
+    print(f"\033[31mWARNING: {message}\033[0m")
+
+
+def calc_sim(x, y, name="tensor"):
+    x, y = x.data.double(), y.data.double()
+    denominator = (x * x + y * y).sum()
+    if denominator == 0:
+        print_red_warning(f'{name} all zero')
+        return 1
+    sim = 2 * (x * y).sum() / denominator
+    return sim
+
+
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+    x_mask = torch.isfinite(x)
+    y_mask = torch.isfinite(y)
+    if not torch.all(x_mask == y_mask):
+        print_red_warning(f'{name} Error: isfinite mask mismatch')
+        if raise_assert:
+            raise AssertionError
+    if not torch.isclose(
+            x.masked_fill(x_mask, 0), y.masked_fill(y_mask, 0), rtol=0, atol=0,
+            equal_nan=True).all():
+        print_red_warning(f'{name} Error: nonfinite value mismatch')
+        if raise_assert:
+            raise AssertionError
+    x = x.masked_fill(~x_mask, 0)
+    y = y.masked_fill(~y_mask, 0)
+    sim = calc_sim(x, y, name)
+    diff = (1. - sim).item()
+    print(f'{diff=}')
+    if not (0 <= diff <= eps):
+        print_red_warning(f'{name} Error: {diff=}')
+        if raise_assert:
+            raise AssertionError
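
As a short usage sketch (not part of the commit), the vectorized decoder and the new similarity check can be exercised together as below. The import path assumes the script runs from examples/dequantize_gemm, the shapes are arbitrary, and a PyTorch build with torch.uint16 support is assumed since the decoder relies on it.

# Usage sketch (not part of the commit). Assumes it runs from
# examples/dequantize_gemm so utils.py is importable, and a PyTorch
# version that supports torch.uint16 (the decoder relies on it).
import torch
from utils import torch_convert_bit_twiddling, assert_similar

qB = torch.randint(0, 256, (16, 32), dtype=torch.uint8)  # packed payload, even column count
B = torch_convert_bit_twiddling(qB)                       # -> (16, 64) bfloat16
assert_similar(B, B.clone(), eps=1e-8, name="decode")     # prints diff=0.0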

tilelang/language/builtin.py

Lines changed: 2 additions & 2 deletions
@@ -331,13 +331,13 @@ def shfl_up(value: Union[int, PrimExpr, tir.Call], offset: Union[int, PrimExpr,


 def sync_threads():
-    """Synchronize all threads in a warp.
+    """Synchronize all threads in a block.
     """
     return tir.op.tvm_storage_sync("shared")


 def sync_global():
-    """Synchronize all threads in a block.
+    """Synchronize all threads in the entire grid.
     """
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()

tilelang/quantize/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -5,6 +5,7 @@
     _tir_packed_to_fp4_to_f16,  # noqa: F401
     _tir_u8_to_f8_e4m3_to_f16,  # noqa: F401
     _tir_packed_to_unsigned_convert_with_zeros,  # noqa: F401
+    _tir_u8_to_f4_to_bf16,  # noqa: F401
 )

 from .utils import (
