
Commit 796b3bb

tzj-fxz and LeiWang1999 authored
[MXFP4] Fix bugs and optimize exponential operation (#750)
* [MXFP4] Fix bugs
  - Optimize exp2 with a shift operation to boost performance
  - Fix a bug in the simple dequantization function call
  - Fix a bug in the scaling factor with bias

* [Lint]

---------

Co-authored-by: LeiWang1999 <leiwang1999@outlook.com>
1 parent e835762 commit 796b3bb

File tree

1 file changed: +21 -12 lines changed

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py

Lines changed: 21 additions & 12 deletions
@@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
     # Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
     e_bf16 = e_f4 + tir.const(126, "uint16")
     # Scale is the exponential part, within the representation of uint8
-    # To handle the overflow, we use the max function to limit the exponential part to 8 bits
-    e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
+    # To handle the overflow, we may use the min function to limit the exponential part to 8 bits
+    # e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
     m_f4 = f4 & tir.const(1, "uint16")
     val_bf16 = tir.reinterpret("bfloat16",
                                ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
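
For readers following the bit manipulation: the helper rebiases the 2-bit f4 exponent into the 8-bit bf16 exponent and drops the single f4 mantissa bit into the most significant bit of the bf16 mantissa. A minimal scalar sketch in plain Python (a hypothetical helper mirroring the TIR above; special cases such as the zero/subnormal nibble are elided):

    def f4_to_bf16_bits(f4: int) -> int:
        # e2m1 nibble: 1 sign bit, 2 exponent bits, 1 mantissa bit.
        s = (f4 >> 3) & 0x1
        e_f4 = (f4 >> 1) & 0x3
        m_f4 = f4 & 0x1
        # Rebias the exponent: bf16 bias (127) minus f4 bias (1) = 126.
        e_bf16 = e_f4 + 126
        # bf16 layout: 1-bit sign, 8-bit exponent, 7-bit mantissa;
        # the f4 mantissa bit lands in the mantissa's top bit.
        return ((((s << 8) | e_bf16) << 7) | (m_f4 << 6)) & 0xFFFF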
@@ -218,7 +218,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
         B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
         B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
         Scale_local_thread = T.alloc_local((1,), storage_dtype)
-        Scale_local_thread_exponent = T.alloc_local((1,), "float32")
+        Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)

         for i in T.serial(0, block_N * block_K // threads // local_size):
             # First, load data from share memory to register.
@@ -231,8 +231,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
             si = index_scale // (block_K // scale_size)
             sj = index_scale % (block_K // scale_size)
             Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
-            Scale_local_thread_exponent[0] = T.exp2(
-                T.cast(Scale_local_thread[0] - 127, "float"))
+            Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))

             # Then, dequant.
             T.call_extern(
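
The performance fix rests on the identity 2**e == 1 << e for non-negative integer e, so the transcendental exp2 call plus float cast collapses into a single integer shift. A quick plain-Python illustration (not the kernel code itself):

    # For non-negative integer exponents, exp2 and a left shift agree:
    for e in range(16):
        assert 2.0 ** e == float(1 << e)

Note that the bias subtraction (- 127) disappears here as well, matching the reference-program changes further down.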
@@ -288,7 +287,7 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
         - Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
         """
         B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
-        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
+        B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)

         bx = T.get_block_binding(0)
         T.copy(B_shared, B_local)
@@ -300,8 +299,9 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
                     Scale[
                         bx * block_N + i, k * block_K // scale_size + j //
                         scale_size],  # Scale is the exponential part, within the representation of uint8
-                    dtype=in_dtype,
-                )
+                    dtype=out_dtype,
+                ) * T.shift_left(
+                    1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
             T.copy(B_dequantize_local, B_dequantize_shared)

     return simple_dequant_bf16_fp4
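
Together with the dtype fix, the simple path now decodes to out_dtype and multiplies by a power-of-two scale. A runnable scalar mirror, reusing f4_to_bf16_bits from the sketch above plus a hypothetical bf16-to-float helper:

    import struct

    def bf16_bits_to_float(bits: int) -> float:
        # bf16 occupies the top 16 bits of an fp32; pad with zeros and reinterpret.
        return struct.unpack(">f", struct.pack(">I", (bits & 0xFFFF) << 16))[0]

    def dequant_simple(f4: int, scale: int) -> float:
        # Decode the e2m1 nibble, then apply the 2**scale factor,
        # mirroring the ") * T.shift_left(1, Scale[...])" expression in the diff.
        return bf16_bits_to_float(f4_to_bf16_bits(f4)) * (1 << scale)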
@@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale):
     B = torch_convert_bit_twiddling(qB)
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
@@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale):
     B = torch_convert(qB)
     for i in range(B.shape[0]):
         for j in range(B.shape[1]):
-            B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
+            B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
     C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
     C = C.to(torch.__getattribute__(dtypeC))
     return C
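
Both reference programs now treat Scale as an unbiased per-group exponent: every 32-wide column group of B is scaled by 2**Scale. An equivalent vectorized sketch (assuming Scale has shape [N, K // 32], as the indexing implies):

    import torch

    def scale_reference(B: torch.Tensor, Scale: torch.Tensor) -> torch.Tensor:
        # Broadcast each group exponent across its 32 columns, then scale by 2**Scale.
        e = Scale.repeat_interleave(32, dim=1).to(torch.float)
        return B.to(torch.float) * torch.exp2(e)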
@@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):

     if tune:
         kernel = matmul(
-            m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size)
+            m,
+            n,
+            k,
+            "bfloat16",
+            "bfloat16",
+            "float32",
+            num_bits=4,
+            scale_size=scale_size,
+            fast_dequant=fast_dequant)
     else:
         kernel = matmul(
             m,
@@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
             block_K=128,
             num_stages=2,
             threads=256,
-            split=1)
+            split=1,
+            fast_dequant=fast_dequant)

     profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
