Merged

22 commits
d2293ad
[Dequant] Add bit-twiddling dequantize cuda for fp4-->bf16
tzj-fxz Aug 13, 2025
47cacb1
[Dequant] Add extern call and serial dequantization
tzj-fxz Aug 14, 2025
12e25f5
[Dequant] Parallel Dequant wait for fence debug.
tzj-fxz Aug 14, 2025
e60ab76
[Scale] Add scale matrix to mxfp4 gemm
tzj-fxz Aug 14, 2025
6d1fffd
[Remove] Remove fence-buggy example and some generated source cuda code
tzj-fxz Aug 15, 2025
0dea3db
[MXFP4] Update initial version of MXFP4 GEMM
tzj-fxz Aug 15, 2025
3aae12c
[Scale] Add scale to latest mxfp4 gemm
tzj-fxz Aug 15, 2025
0cb0b59
[Lint]
tzj-fxz Aug 15, 2025
5e575f9
[BugFix] Load Scale, disabe TMA to recover performance
tzj-fxz Aug 15, 2025
49e54af
[Lint]
tzj-fxz Aug 15, 2025
07752a9
[Lint]
tzj-fxz Aug 15, 2025
b86874b
[Scale] Use L2 to hold Scale and enable TMA will slightly boost perfo…
tzj-fxz Aug 18, 2025
09ed919
[Lint]
tzj-fxz Aug 18, 2025
e01f86a
Update example_dequant_gemm_bf16_fp4_hopper_serial.py
LeiWang1999 Aug 18, 2025
a6ae3c3
Remove deprecated dequantization examples for BF16 and MXFP4 in the d…
LeiWang1999 Aug 18, 2025
668100f
Refactor dequantization examples for improved readability and consist…
LeiWang1999 Aug 18, 2025
cc88146
Refactor index_to_coordinates usage in bitnet example and update dequ…
LeiWang1999 Aug 18, 2025
a51569d
lint fix
LeiWang1999 Aug 18, 2025
4def1e9
ci fix
LeiWang1999 Aug 18, 2025
0149177
Remove non-existent example
tzj-fxz Aug 18, 2025
c856ced
[BugFix] Add smem swizzle to recover performance of TMA
tzj-fxz Aug 18, 2025
f604ede
[BugFix] Enough reg for producer when threads=512
tzj-fxz Aug 18, 2025
@@ -9,7 +9,6 @@
from tvm import DataType
from tilelang.intrinsics.mma_layout import (
make_mma_swizzle_layout as make_swizzle_layout,)
from tilelang.intrinsics.utils import index_to_coordinates
import numpy as np

from tilelang.intrinsics.mma_macro_generator import (
@@ -200,7 +199,7 @@ def main(
index = (
i * threads * local_size_compressed +
thread_bindings * local_size_compressed + v)
vi, vj = index_to_coordinates(index, B_shared_shape)
vi, vj = T.index_to_coordinates(index, B_shared_shape)
B_local[v] = B_shared[vi, vj]

T.call_extern(
@@ -212,7 +211,7 @@

for v in T.vectorized(0, local_size):
index = (i * threads * local_size + thread_bindings * local_size + v)
vi, vj = index_to_coordinates(index, B_dequantize_shared_shape)
vi, vj = T.index_to_coordinates(index, B_dequantize_shared_shape)
B_dequantize_shared[vi, vj] = B_dequantize_local[v]

for ki in T.serial(0, (block_K // micro_size_k)):
245 changes: 245 additions & 0 deletions examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py
@@ -0,0 +1,245 @@
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from utils import torch_convert_bit_twiddling, torch_convert


def get_configs():
import itertools
iter_params = dict(
block_M=[64, 128, 256],
block_N=[64, 128, 256],
block_K=[128],
num_stages=[0, 2],
threads=[128, 256, 512],
split=[1, 2],
)
return [{
k: v for k, v in zip(iter_params, values)
} for values in itertools.product(*iter_params.values())]


@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(
out_idx=[-1],
pass_configs={
tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
},
)
def matmul(M,
N,
K,
in_dtype,
out_dtype,
accum_dtype,
source_format='uint',
num_bits=4,
fast_dequant=True,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1):
num_elems_per_byte = 8 // num_bits
storage_dtype = "uint8"

QK = K // num_elems_per_byte
Block_QK = block_K // num_elems_per_byte
A_shape = (M, K)
B_shape = (N, QK)
A_shared_shape = (block_M, block_K)
B_shared_shape = (block_N, Block_QK)
B_dequantize_shared_shape = (block_N, block_K)
assert K % (block_K * split) == 0

from tilelang.quantize import get_mxfp_intrin_group

# fast_dequant_bf16_fp4_twiddling
mxfp_intrin_info = get_mxfp_intrin_group(
out_dtype=in_dtype,
source_format=source_format,
source_bit=num_bits,
storage_dtype=storage_dtype,
use_twiddling=True,
)

import_source = mxfp_intrin_info["c_source"]
func_name = mxfp_intrin_info["func_name"]
assert import_source is not None, "mxfp_intrin_info is not found"
assert func_name is not None, "mxfp_intrin_info is not found"
import_source = import_source

def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]

# Some variables for dequantization in each thread
MAX_TRANSACTION_SIZE_BITS = 128
local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
local_compress_size = local_size // num_elems_per_byte

@T.macro
def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
# import fast_dequantize plugin
T.import_source(import_source)

tx = T.get_thread_binding()

B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
for i in T.serial(0, block_N * block_K // threads // local_size):
# First, load data from share memory to register.
# Prepare for dequant.
for v in T.vectorized(0, local_compress_size):
index = i * threads * local_compress_size + tx * local_compress_size + v
B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]

# Then, dequant.
T.call_extern(
func_name,
T.address_of(B_local_thread[0]),
T.address_of(B_dequantize_local_thread[0]),
1,
dtype=out_dtype,
)

# Finally, store the dequantized data to shared memory.
for v in T.vectorized(0, local_size):
index = i * threads * local_size + tx * local_size + v
B_dequantize_shared[index // block_K,
index % block_K] = B_dequantize_local_thread[v]

return fast_dequant_bf16_fp4_twiddling

def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
assert in_dtype in ["fp4"]
assert out_dtype in ["bfloat16"]

def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
scale: tir.PrimExpr, dtype: str):
assert nbit == 4
assert dtype == "bfloat16"
assert val.dtype == "uint8"
mask = tir.const((1 << nbit) - 1, "uint16")
f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
s = f4 >> tir.const(3, "uint16")
e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
# Exponent bias difference between bf16 and fp4 is (2^(8-1) - 1) - (2^(2-1) - 1) = 126
e_bf16 = e_f4 + tir.const(126, "uint16")
# Scale is the exponent part, within the representation of uint8
# To handle overflow, clamp the exponent to 8 bits with min
e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
m_f4 = f4 & tir.const(1, "uint16")
val_bf16 = tir.reinterpret(
"bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
| (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
return val_bf16

@T.macro
def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
T.copy(B_shared, B_local)
for i, j in T.Parallel(block_N, block_K):
B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
num_bits,
B_shared[i, j // num_elems_per_byte],
j % num_elems_per_byte,
0, # No scale for test
dtype=out_dtype,
)
T.copy(B_dequantize_local, B_dequantize_shared)

return simple_dequant_bf16_fp4

@T.prim_func
def main(
A: T.Tensor(A_shape, in_dtype),
B: T.Tensor(B_shape, storage_dtype),
C: T.Tensor((M, N), out_dtype),
Comment on lines +162 to +164
⚠️ Potential issue

Missing Scale tensor parameter

The main function signature doesn't include a Scale tensor parameter, but the mxfp4 variant has it. This inconsistency between the two examples suggests missing functionality.

The function is missing the Scale tensor parameter that's present in the mxfp4 variant. Add it to support proper FP4 dequantization with scaling:

     def main(
             A: T.Tensor(A_shape, in_dtype),
             B: T.Tensor(B_shape, storage_dtype),
+            Scale: T.Tensor(Scale_shape, storage_dtype),
             C: T.Tensor((M, N), out_dtype),
     ):

Also need to define Scale_shape earlier:

     B_shape = (N, QK)
+    Scale_shape = (N, K // 32)  # Assuming scale_size=32 as default
     A_shared_shape = (block_M, block_K)
📝 Committable suggestion

Suggested change
-            A: T.Tensor(A_shape, in_dtype),
-            B: T.Tensor(B_shape, storage_dtype),
-            C: T.Tensor((M, N), out_dtype),
+    B_shape = (N, QK)
+    Scale_shape = (N, K // 32)  # Assuming scale_size=32 as default
+    A_shared_shape = (block_M, block_K)
+    def main(
+            A: T.Tensor(A_shape, in_dtype),
+            B: T.Tensor(B_shape, storage_dtype),
+            Scale: T.Tensor(Scale_shape, storage_dtype),
+            C: T.Tensor((M, N), out_dtype),
+    ):
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py around lines
162-164, the function signature is missing the Scale tensor parameter present in
the mxfp4 variant; add a Scale entry to the signature (e.g., Scale:
T.Tensor(Scale_shape, storage_dtype or appropriate scale dtype) alongside A, B,
C) and ensure you define Scale_shape earlier in the file (matching the expected
scale dimensions for the FP4 dequantization) so both variants are consistent.
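
If that suggestion were adopted, a minimal sketch of the scaled simple-dequant path might look like the following. This is not part of the PR: Scale_shape, the group size of 32, and the per-group indexing are assumptions borrowed from the mxfp4 variant, and the block index bx and pipeline index k (or an already-sliced Scale tile) would have to be made visible to the macro.

    # Hypothetical sketch, not in this PR: thread a per-group exponent scale through main().
    Scale_shape = (N, K // 32)  # assumed group size of 32, uint8 exponents

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, storage_dtype),
            Scale: T.Tensor(Scale_shape, storage_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        ...
        # Inside the simple-dequant path the constant 0 would become the per-group exponent:
        B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
            num_bits,
            B_shared[i, j // num_elems_per_byte],
            j % num_elems_per_byte,
            Scale[bx * block_N + i, (k * block_K + j) // 32],  # assumes bx and k are in scope
            dtype=out_dtype,
        )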

):
with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
A_shared = T.alloc_shared(A_shared_shape, in_dtype)
B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)

C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
C_shared = T.alloc_shared((block_M, block_N), out_dtype)

T.annotate_layout({
C_shared: tilelang.layout.make_swizzled_layout(C_shared),
})

T.clear(C_local)
for k in T.Pipelined(K // block_K, num_stages=num_stages):
T.copy(A[by * block_M, k * block_K], A_shared)
T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)

if fast_dequant:
get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared)
else:
get_simple_dequant_func()(B_shared, B_dequantize_shared)

T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)

T.copy(C_local, C_shared)
T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])

return main


def ref_program_twiddling(A, qB):
dtypeC = "bfloat16"
B = torch_convert_bit_twiddling(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C


def ref_program_simple(A, qB):
dtypeC = "bfloat16"
B = torch_convert(qB)
C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
C = C.to(torch.__getattribute__(dtypeC))
return C


def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
total_flops = 2 * m * n * k
if tune:
kernel = matmul(
m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
else:
kernel = matmul(
m,
n,
k,
"bfloat16",
"bfloat16",
"float32",
num_bits=4,
fast_dequant=fast_dequant,
block_M=256,
block_N=128,
block_K=128,
num_stages=2,
threads=256,
split=1)
profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
if fast_dequant:
profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
else:
profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
latency = profiler.do_bench(warmup=500)
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
main(256, 256, 256, True)
main(256, 256, 256, False)
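
As a quick offline sanity check of the bit manipulation in _tir_u8_to_f4_to_bf16, the same decode can be written in plain Python/NumPy. The sketch below only mirrors the kernel helper's own bit logic (scale fixed to 0); the reference converters torch_convert and torch_convert_bit_twiddling live in the local utils module and are not shown here.

    import numpy as np

    def decode_fp4_nibble(byte, pos, scale=0):
        # pos = 0 selects the low nibble, pos = 1 the high nibble
        f4 = (byte >> (pos * 4)) & 0xF
        s = f4 >> 3                            # sign bit
        e_f4 = (f4 & 0x6) >> 1                 # 2-bit exponent
        e_bf16 = min(e_f4 + 126 + scale, 255)  # re-bias into the bf16 exponent field
        m_f4 = f4 & 0x1                        # 1-bit mantissa
        bits = (((s << 8) | e_bf16) << 7) | (m_f4 << 6)
        # bf16 is the top 16 bits of fp32, so shift up and reinterpret
        return np.frombuffer(np.uint32(bits << 16).tobytes(), dtype=np.float32)[0]

    # one packed uint8 byte holds two fp4 values (num_elems_per_byte == 2)
    print(decode_fp4_nibble(0x64, 0), decode_fp4_nibble(0x64, 1))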