[Feature] Low-bit twiddling dequantization and FP4 GEMM #725
Merged
Changes from all commits (22):
d2293ad  [Dequant] Add bit-twiddling dequantize cuda for fp4-->bf16 (tzj-fxz)
47cacb1  [Dequant] Add extern call and serial dequantization (tzj-fxz)
12e25f5  [Dequant] Parallel Dequant wait for fence debug. (tzj-fxz)
e60ab76  [Scale] Add scale matrix to mxfp4 gemm (tzj-fxz)
6d1fffd  [Remove] Remove fence-buggy example and some generated source cuda code (tzj-fxz)
0dea3db  [MXFP4] Update initial version of MXFP4 GEMM (tzj-fxz)
3aae12c  [Scale] Add scale to latest mxfp4 gemm (tzj-fxz)
0cb0b59  [Lint] (tzj-fxz)
5e575f9  [BugFix] Load Scale, disable TMA to recover performance (tzj-fxz)
49e54af  [Lint] (tzj-fxz)
07752a9  [Lint] (tzj-fxz)
b86874b  [Scale] Use L2 to hold Scale and enable TMA will slightly boost perfo… (tzj-fxz)
09ed919  [Lint] (tzj-fxz)
e01f86a  Update example_dequant_gemm_bf16_fp4_hopper_serial.py (LeiWang1999)
a6ae3c3  Remove deprecated dequantization examples for BF16 and MXFP4 in the d… (LeiWang1999)
668100f  Refactor dequantization examples for improved readability and consist… (LeiWang1999)
cc88146  Refactor index_to_coordinates usage in bitnet example and update dequ… (LeiWang1999)
a51569d  lint fix (LeiWang1999)
4def1e9  ci fix (LeiWang1999)
0149177  Remove non-existent example (tzj-fxz)
c856ced  [BugFix] Add smem swizzle to recover performance of TMA (tzj-fxz)
f604ede  [BugFix] Enough reg for producer when threads=512 (tzj-fxz)
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (245 additions, 0 deletions)
```python
import tilelang
import tilelang.language as T
from tilelang import tvm as tvm
from tvm import DataType
from tvm import tir
import torch
from utils import torch_convert_bit_twiddling, torch_convert


def get_configs():
    import itertools
    iter_params = dict(
        block_M=[64, 128, 256],
        block_N=[64, 128, 256],
        block_K=[128],
        num_stages=[0, 2],
        threads=[128, 256, 512],
        split=[1, 2],
    )
    return [{
        k: v for k, v in zip(iter_params, values)
    } for values in itertools.product(*iter_params.values())]
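
# get_configs() enumerates 3 * 3 * 1 * 2 * 3 * 2 = 108 tuning candidates;
# tilelang.autotune benchmarks them and keeps the fastest valid configuration.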

@tilelang.autotune(configs=get_configs(),)
@tilelang.jit(
    out_idx=[-1],
    pass_configs={
        tilelang.PassConfigKey.TL_DISABLE_TMA_LOWER: True,
        tilelang.PassConfigKey.TL_DISABLE_WARP_SPECIALIZED: True
    },
)
def matmul(M,
           N,
           K,
           in_dtype,
           out_dtype,
           accum_dtype,
           source_format='uint',
           num_bits=4,
           fast_dequant=True,
           block_M=256,
           block_N=128,
           block_K=128,
           num_stages=2,
           threads=256,
           split=1):
    num_elems_per_byte = 8 // num_bits
    storage_dtype = "uint8"

    QK = K // num_elems_per_byte
    Block_QK = block_K // num_elems_per_byte
    A_shape = (M, K)
    B_shape = (N, QK)
    A_shared_shape = (block_M, block_K)
    B_shared_shape = (block_N, Block_QK)
    B_dequantize_shared_shape = (block_N, block_K)
    assert K % (block_K * split) == 0
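
    # With num_bits=4 each uint8 packs num_elems_per_byte = 2 fp4 values, so B
    # is stored as (N, QK) = (N, K // 2) and expanded to a full
    # (block_N, block_K) bf16 tile in B_dequantize_shared before each GEMM.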

    from tilelang.quantize import get_mxfp_intrin_group

    # fast_dequant_bf16_fp4_twiddling
    mxfp_intrin_info = get_mxfp_intrin_group(
        out_dtype=in_dtype,
        source_format=source_format,
        source_bit=num_bits,
        storage_dtype=storage_dtype,
        use_twiddling=True,
    )

    import_source = mxfp_intrin_info["c_source"]
    func_name = mxfp_intrin_info["func_name"]
    assert import_source is not None, "c_source not found in mxfp_intrin_info"
    assert func_name is not None, "func_name not found in mxfp_intrin_info"
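
    # The intrin group carries a CUDA source string plus the name of its
    # device-side entry point; the source is imported inside the macro below
    # and invoked through T.call_extern.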

    def get_fast_dequant_twiddling_func(in_dtype="fp4", out_dtype="bfloat16"):
        assert in_dtype in ["fp4"]
        assert out_dtype in ["bfloat16"]

        # Per-thread constants for dequantization
        MAX_TRANSACTION_SIZE_BITS = 128
        local_size = MAX_TRANSACTION_SIZE_BITS // DataType(out_dtype).bits
        local_compress_size = local_size // num_elems_per_byte
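        # For bf16 outputs this gives 128 / 16 = 8 dequantized elements per
        # thread per transaction, produced from 8 / 2 = 4 packed uint8 bytes.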

        @T.macro
        def fast_dequant_bf16_fp4_twiddling(B_shared, B_dequantize_shared):
            # import the fast_dequantize plugin
            T.import_source(import_source)

            tx = T.get_thread_binding()

            B_local_thread = T.alloc_local((local_compress_size,), storage_dtype)
            B_dequantize_local_thread = T.alloc_local((local_size,), out_dtype)
            for i in T.serial(0, block_N * block_K // threads // local_size):
                # First, load packed data from shared memory into registers
                # to prepare for dequantization.
                for v in T.vectorized(0, local_compress_size):
                    index = i * threads * local_compress_size + tx * local_compress_size + v
                    B_local_thread[v] = B_shared[index // Block_QK, index % Block_QK]

                # Then dequantize via the extern twiddling routine.
                T.call_extern(
                    func_name,
                    T.address_of(B_local_thread[0]),
                    T.address_of(B_dequantize_local_thread[0]),
                    1,
                    dtype=out_dtype,
                )

                # Finally, store the dequantized data to shared memory.
                for v in T.vectorized(0, local_size):
                    index = i * threads * local_size + tx * local_size + v
                    B_dequantize_shared[index // block_K,
                                        index % block_K] = B_dequantize_local_thread[v]

        return fast_dequant_bf16_fp4_twiddling
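
    # The simple path below performs the same fp4 -> bf16 decode with plain TIR
    # bit arithmetic, one element at a time, and serves as the easy-to-verify
    # reference for the twiddling fast path.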

    def get_simple_dequant_func(in_dtype="fp4", out_dtype="bfloat16"):
        assert in_dtype in ["fp4"]
        assert out_dtype in ["bfloat16"]

        def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr,
                                  scale: tir.PrimExpr, dtype: str):
            assert nbit == 4
            assert dtype == "bfloat16"
            assert val.dtype == "uint8"
            mask = tir.const((1 << nbit) - 1, "uint16")
            f4 = (val >> (pos.astype("uint16") * tir.const(nbit, "uint16"))) & mask
            s = f4 >> tir.const(3, "uint16")
            e_f4 = (f4 & tir.const(6, "uint16")) >> tir.const(1, "uint16")
            # The exponent bias difference between bf16 and fp4 (e2m1) is
            # (2^7 - 1) - (2^1 - 1) = 126.
            e_bf16 = e_f4 + tir.const(126, "uint16")
            # Scale is added in the exponent domain, within the range of uint8;
            # T.min clamps the result so the 8-bit exponent field cannot overflow.
            e_bf16 = T.min(e_bf16 + scale, tir.const((1 << 8) - 1, "uint16"))
            m_f4 = f4 & tir.const(1, "uint16")
            val_bf16 = tir.reinterpret(
                "bfloat16", ((((s << tir.const(8, "uint16")) | e_bf16) << tir.const(7, "uint16"))
                             | (m_f4 << tir.const(6, "uint16"))).astype("uint16"))
            return val_bf16
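
        # fp4 (e2m1) packs [sign | 2-bit exponent | 1-bit mantissa]; the
        # reinterpret above assembles (s << 15) | (e_bf16 << 7) | (m_f4 << 6),
        # placing the lone fp4 mantissa bit in the top bit of bf16's 7-bit
        # mantissa field.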

        @T.macro
        def simple_dequant_bf16_fp4(B_shared, B_dequantize_shared):
            B_local = T.alloc_fragment(B_shared_shape, storage_dtype)
            B_dequantize_local = T.alloc_fragment(B_dequantize_shared_shape, out_dtype)
            T.copy(B_shared, B_local)
            for i, j in T.Parallel(block_N, block_K):
                B_dequantize_local[i, j] = _tir_u8_to_f4_to_bf16(
                    num_bits,
                    B_shared[i, j // num_elems_per_byte],
                    j % num_elems_per_byte,
                    0,  # No scale for test
                    dtype=out_dtype,
                )
            T.copy(B_dequantize_local, B_dequantize_shared)

        return simple_dequant_bf16_fp4
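
    # Kernel body: a software-pipelined loop of global -> shared copies,
    # in-kernel dequantization of B, and a tile-level GEMM with B transposed.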

    @T.prim_func
    def main(
            A: T.Tensor(A_shape, in_dtype),
            B: T.Tensor(B_shape, storage_dtype),
            C: T.Tensor((M, N), out_dtype),
    ):
        with T.Kernel(T.ceildiv(N, block_N), T.ceildiv(M, block_M), threads=threads) as (bx, by):
            A_shared = T.alloc_shared(A_shared_shape, in_dtype)
            B_shared = T.alloc_shared(B_shared_shape, storage_dtype)
            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)

            C_local = T.alloc_fragment((block_M, block_N), accum_dtype)
            C_shared = T.alloc_shared((block_M, block_N), out_dtype)

            T.annotate_layout({
                C_shared: tilelang.layout.make_swizzled_layout(C_shared),
            })

            T.clear(C_local)
            for k in T.Pipelined(K // block_K, num_stages=num_stages):
                T.copy(A[by * block_M, k * block_K], A_shared)
                T.copy(B[bx * block_N, k * block_K // num_elems_per_byte], B_shared)

                if fast_dequant:
                    get_fast_dequant_twiddling_func()(B_shared, B_dequantize_shared)
                else:
                    get_simple_dequant_func()(B_shared, B_dequantize_shared)

                T.gemm(A_shared, B_dequantize_shared, C_local, transpose_B=True)

            T.copy(C_local, C_shared)
            T.copy(C_shared, C[by * block_M:(by + 1) * block_M, bx * block_N:(bx + 1) * block_N])

    return main
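

# Host-side references: dequantize B with the matching utils helper, then
# compare the kernel output against a plain fp32 matmul cast to bf16.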
def ref_program_twiddling(A, qB):
    dtypeC = "bfloat16"
    B = torch_convert_bit_twiddling(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))
    return C


def ref_program_simple(A, qB):
    dtypeC = "bfloat16"
    B = torch_convert(qB)
    C = torch.matmul(A.to(torch.float), B.T.to(torch.float))
    C = C.to(torch.__getattribute__(dtypeC))
    return C

def main(m=256, n=256, k=256, fast_dequant=True, tune=False):
    total_flops = 2 * m * n * k
    if tune:
        kernel = matmul(
            m, n, k, "bfloat16", "bfloat16", "float32", num_bits=4, fast_dequant=fast_dequant)
    else:
        kernel = matmul(
            m,
            n,
            k,
            "bfloat16",
            "bfloat16",
            "float32",
            num_bits=4,
            fast_dequant=fast_dequant,
            block_M=256,
            block_N=128,
            block_K=128,
            num_stages=2,
            threads=256,
            split=1)
    profiler = kernel.get_profiler(tilelang.TensorSupplyType.Auto)
    if fast_dequant:
        profiler.assert_allclose(ref_program_twiddling, rtol=0.01, atol=0.01)
    else:
        profiler.assert_allclose(ref_program_simple, rtol=0.01, atol=0.01)
    latency = profiler.do_bench(warmup=500)
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


if __name__ == "__main__":
    main(256, 256, 256, True)
    main(256, 256, 256, False)
```
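
The reference helpers `torch_convert` and `torch_convert_bit_twiddling` come from a local `utils` module that is not part of this diff. As a rough sketch of what the plain (non-twiddling) decode has to do, the snippet below mirrors the bit arithmetic of `_tir_u8_to_f4_to_bf16` in PyTorch with `scale = 0`; the function name and element ordering are assumptions for illustration, not the actual `utils` implementation:

```python
import torch


def torch_convert_sketch(qB: torch.Tensor) -> torch.Tensor:
    """Hypothetical stand-in for utils.torch_convert: unpack an (N, K // 2)
    uint8 tensor holding two fp4 (e2m1) values per byte into an (N, K)
    bfloat16 tensor, mirroring _tir_u8_to_f4_to_bf16 with scale = 0."""
    assert qB.dtype == torch.uint8
    q = qB.to(torch.int32)
    # Low nibble first, matching pos = j % num_elems_per_byte in the kernel.
    nibbles = torch.stack((q & 0xF, (q >> 4) & 0xF), dim=-1).flatten(-2)
    s = (nibbles >> 3) & 0x1  # sign bit
    e = (nibbles >> 1) & 0x3  # 2-bit exponent
    m = nibbles & 0x1         # 1-bit mantissa
    sign = 1.0 - 2.0 * s      # +1.0 when s == 0, -1.0 when s == 1
    # Numeric form of the kernel's (s << 15) | ((e + 126) << 7) | (m << 6):
    # the biased exponent e + 126 means a scale of 2^(e - 1), and the lone
    # mantissa bit contributes a factor of (1 + 0.5 * m). Like the TIR helper,
    # there is no zero/subnormal case, so nibble 0b0000 decodes to +0.5.
    value = sign * (2.0 ** (e - 1)) * (1.0 + 0.5 * m)
    return value.to(torch.bfloat16)
```

Since `assert_allclose` validates the kernel against these helpers, the real `torch_convert_bit_twiddling` must additionally reproduce whatever element ordering the CUDA twiddling routine uses, which this sketch does not attempt.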
Review comment: Missing Scale tensor parameter

The `main` prim_func takes only A, B, and C, while the mxfp4 variant of this example also takes a Scale tensor. This inconsistency between the two examples suggests missing functionality: without it, the example cannot apply proper scaling during FP4 dequantization. Suggested fix:
```diff
 def main(
         A: T.Tensor(A_shape, in_dtype),
         B: T.Tensor(B_shape, storage_dtype),
+        Scale: T.Tensor(Scale_shape, storage_dtype),
         C: T.Tensor((M, N), out_dtype),
 ):
```

`Scale_shape` also needs to be defined earlier:

```diff
 B_shape = (N, QK)
+Scale_shape = (N, K // 32)  # Assuming scale_size=32 as default
 A_shared_shape = (block_M, block_K)
```