33 changes: 21 additions & 12 deletions examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
@@ -40,8 +40,8 @@ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale
# Exponential bias between f4 and bf16 is 2^(8-1) - 2^(2-1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponential part, within the representation of uint8
# To handle the overflow, we use the max function to limit the exponential part to 8 bits
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
# To handle the overflow, we may use the min function to limit the exponential part to 8 bits
# e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
Comment on lines +43 to +44
medium

Since the scale is now applied externally, this commented-out code is dead and can be removed to improve clarity. The scale parameter in the function signature (line 10) also becomes unused and could be removed in a follow-up refactoring.
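A quick way to see why the clamp is unreachable once the scale moves out of this helper (standalone Python check, assuming the 2-bit fp4 exponent field implied by the bias comment above; not part of the kernel):

# Without the externally applied scale, the biased exponent is at most
# 3 + 126 = 129, well inside the uint8 range, so the min-clamp can never fire.
F4_EXP_MAX = (1 << 2) - 1       # fp4 exponent field is 2 bits wide
BF16_F4_BIAS_DELTA = 126        # 2^(8-1) - 2^(2-1), as in the comment above
assert F4_EXP_MAX + BF16_F4_BIAS_DELTA == 129 <= (1 << 8) - 1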

m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret("bfloat16",
((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
@@ -218,7 +218,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
Scale_local_thread = T.alloc_local((1,), storage_dtype)
Scale_local_thread_exponent = T.alloc_local((1,), "float32")
Scale_local_thread_exponent = T.alloc_local((1,), out_dtype)

for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
@@ -231,8 +231,7 @@ def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared, Scale, k):
si = index_scale // (block_K // scale_size)
sj = index_scale % (block_K // scale_size)
Scale_local_thread[0] = Scale[bx * block_N + si, k * block_K // scale_size + sj]
Scale_local_thread_exponent[0] = T.exp2(
T.cast(Scale_local_thread[0] - 127, "float"))
Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))

Comment on lines +234 to 235
🛠️ Refactor suggestion

T.shift_left(1, Scale) returns an integer, not BF16; negative exponents and overflow aren’t handled.

  • Type: T.shift_left with literal 1 yields an integer; assigning it into a BF16 buffer relies on implicit casting (often not allowed in TIR).
  • Semantics: bit-shift cannot represent negative exponents; if Scale encodes signed values (common), 2^Scale for negative Scale will be wrong. For large positive Scale (>=31), the integer overflows before cast.

Robust fix: build the exact BF16 2^scale by constructing the BF16 bit-pattern (handles negative exponents and avoids integer overflow). Replace the assignment with:

-                Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+                # Treat stored scale as signed int8 exponent offset from 0, then form a BF16 value 2^scale.
+                scale_i8 = tir.reinterpret("int8", Scale_local_thread[0])
+                e_i16 = T.cast(scale_i8, "int16") + tir.const(127, "int16")   # BF16 bias
+                # Clamp to avoid subnormals/inf; adjust limits as desired (0..254 keeps finite, non-subnormal range)
+                e_i16 = T.min(T.max(e_i16, tir.const(0, "int16")), tir.const(254, "int16"))
+                e_u16 = T.cast(e_i16, "uint16")
+                Scale_local_thread_exponent[0] = tir.reinterpret(
+                    out_dtype,
+                    T.shift_left(e_u16, tir.const(7, "uint16")),  # mantissa = 0
+                )

If you truly want the shift approach for performance and can guarantee non-negative, small exponents, at minimum add explicit casts to satisfy the type-checker and document the precondition:

-                Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+                # Precondition: Scale in [0, 30]
+                Scale_local_thread_exponent[0] = T.cast(
+                    T.shift_left(tir.const(1, "int32"), T.cast(Scale_local_thread[0], "int32")),
+                    out_dtype,
+                )
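As a quick sanity check of the bit-pattern construction outside TIR, a plain-Python sketch (stdlib only; it assumes the stored scale is a signed exponent, as treated above) reproduces 2**scale exactly over the clamped range:

import struct

def pow2_bf16(scale: int) -> float:
    # Biased exponent, clamped into the finite, normal bf16 range [0, 254].
    e = min(max(scale + 127, 0), 254)
    bf16_bits = e << 7                          # sign = 0, mantissa = 0
    # bf16 occupies the top 16 bits of an IEEE float32 with the same value.
    return struct.unpack("<f", struct.pack("<I", bf16_bits << 16))[0]

for s in (-10, -1, 0, 1, 10, 30, 60):
    assert pow2_bf16(s) == 2.0 ** s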

# Then, dequant.
T.call_extern(
@@ -288,7 +287,7 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
- Mutates B_dequantize_shared by storing the dequantized BF16 fragment.
"""
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, in_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)

bx = T.get_block_binding(0)
T.copy(B_shared, B_local)
@@ -300,8 +299,9 @@ def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared, Scale, k):
Scale[
bx * block_N + i, k * block_K // scale_size + j //
scale_size], # Scale is the exponential part, within the representation of uint8
dtype=in_dtype,
)
dtype=out_dtype,
) * T.shift_left(
1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
Comment on lines 299 to +304
medium

The expression to access the scale value is duplicated. To improve readability and avoid redundant calculations, you can store the scale value in a local variable before using it. For example:

scale_val = Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
    num_bits,
    B_local[i, j // num_elems_per_byte],
    j % num_elems_per_byte,
    scale_val,  # This argument is unused and can be removed if the function is refactored.
    dtype=out_dtype,
) * T.shift_left(1, scale_val)

T.copy(B_dequantize_local, B_dequantize_shared)
Comment on lines 300 to 305
🛠️ Refactor suggestion

The simple path repeats the same issues: integer-shift typing, negative exponents, and overflow.

... ) * T.shift_left(1, Scale[...]) multiplies BF16 by an integer expr, has the same negative-exponent problem, and may overflow for large exponents.

Apply the same bit-pattern approach inline to form a BF16 2^scale value:

-                ) * T.shift_left(
-                    1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
+                ) * tir.reinterpret(
+                    out_dtype,
+                    T.shift_left(
+                        T.cast(
+                            T.min(
+                                T.max(
+                                    T.cast(
+                                        tir.reinterpret(
+                                            "int8",
+                                            Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size],
+                                        ),
+                                        "int16",
+                                    ) + tir.const(127, "int16"),
+                                    tir.const(0, "int16"),
+                                ),
+                                tir.const(254, "int16"),
+                            ),
+                            "uint16",
+                        ),
+                        tir.const(7, "uint16"),
+                    ),
+                )

If you keep the shift optimization with restricted Scale ranges, use explicit casts and add an assert:

+                T.Assert(
+                    (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size] <= tir.const(30, "uint8")),
+                    "Scale must be <= 30 when using integer shift.",
+                )
-                ) * T.shift_left(
-                    1, (Scale[...]))
+                ) * T.cast(
+                    T.shift_left(
+                        tir.const(1, "int32"),
+                        T.cast(Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size], "int32"),
+                    ),
+                    out_dtype,
+                )
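For a concrete feel of the two failure modes called out above, a plain-Python illustration (not TIR semantics; the int32 limit mirrors the overflow concern):

INT32_MAX = (1 << 31) - 1
for s in (-3, -1, 0, 2, 31, 40):
    want = 2.0 ** s
    if s < 0:
        note = "needs a value in (0, 1); no integer shift can produce it"
    elif (1 << s) > INT32_MAX:
        note = "shift result exceeds int32 before the cast to bf16"
    else:
        note = "ok"
    print(f"scale={s:>3}: 2**scale = {want:<12g} {note}")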


return simple_dequant_bf16_fp4
@@ -374,7 +374,7 @@ def ref_program_twiddling(A, qB, Scale):
B = torch_convert_bit_twiddling(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -400,7 +400,7 @@ def ref_program_simple(A, qB, Scale):
B = torch_convert(qB)
for i in range(B.shape[0]):
for j in range(B.shape[1]):
B[i][j] = B[i][j] * (2**(Scale[i][j // 32] - 127))
B[i][j] = B[i][j] * (2**(Scale[i][j // 32]))
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C
@@ -427,7 +427,15 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):

if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, scale_size=scale_size)
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
scale_size=scale_size,
fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
@@ -443,7 +451,8 @@ def main(m=256, n=256, k=256, scale_size=32, fast_dequant=True, tune=False):
block_K=128,
num_stages=2,
threads=256,
split=1)
split=1,
fast_dequant=fast_dequant)

profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
