[Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE #811
Conversation
- Added a new example for grouped matrix multiplication with experts in `example_dequant_groupgemm_bf16_mxfp4_hopper.py`.
- Improved dequantization logic in existing examples by replacing nested loops with vectorized operations for better performance (sketched below).
- Updated the `torch_convert_bit_twiddling` function in `utils.py` to utilize parallel processing, enhancing efficiency and clarity in the conversion process.

Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
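For readers skimming the diff, "replacing nested loops with vectorized operations" for the per-column-group scaling amounts to something like the following; this is a minimal sketch with illustrative names and shapes, not the examples' exact code:

```python
import torch

def scale_columns_loop(B: torch.Tensor, Scale: torch.Tensor, scale_size: int) -> torch.Tensor:
    # Loop-based style: one Python-level multiply per element of B.
    out = B.clone()
    for i in range(B.shape[0]):
        for j in range(B.shape[1]):
            out[i, j] = B[i, j] * (2.0 ** float(Scale[i, j // scale_size]))
    return out

def scale_columns_vectorized(B: torch.Tensor, Scale: torch.Tensor, scale_size: int) -> torch.Tensor:
    # Vectorized equivalent: build the per-column group index once and broadcast.
    cols = torch.arange(B.shape[1], device=B.device) // scale_size
    return B * (2.0 ** Scale[:, cols].to(torch.float32)).to(B.dtype)
```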
Walkthrough

Adds a new grouped MoE FP4→BF16 dequantized GEMM example and a CUDA test, vectorizes per-column-group scaling and bit-twiddling in example utilities, swaps a quantize export, and updates two builtin docstrings. No public function signatures changed except the quantize exports.

Changes
Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Main as example_dequant_groupedgemm.main
  participant Data as get_data()
  participant Builder as JIT Builder
  participant Kernel as matmul (TileLang)
  participant Ref as ref_moe (PyTorch)
  participant Prof as tilelang.profiler
  User->>Main: invoke (M,N,K,topk,E,fast_dequant,with_bias)
  Main->>Data: generate A, qB, Scale, Bias, topk_weights, ids, padding
  Main->>Builder: build/compile kernel (fast/simple dequant)
  Builder-->>Kernel: compiled kernel
  Main->>Kernel: launch kernel with inputs
  Kernel->>Kernel: tiled K-loop → dequant → GEMM → apply bias/topk
  Kernel-->>Main: C (grouped, weighted)
  Main->>Ref: compute reference C_ref
  Ref-->>Main: C_ref
  Main->>Main: validate via assert_similar
  Main->>Prof: benchmark
  Prof-->>Main: latency stats
  Main-->>User: results & validation
sequenceDiagram
  autonumber
  participant Tile as Kernel Block
  participant A_s as A_shared
  participant B_p as qB_packed
  participant Deq as Dequant Macro
  participant B_d as B_bf16_tile
  participant MMA as GEMM/Acc
  Tile->>A_s: load A tile
  Tile->>B_p: load packed FP4 tile
  alt fast_dequant
    Tile->>Deq: call external twiddling intrinsic
  else simple_dequant
    Tile->>Deq: call `_tir_u8_to_f4_to_bf16` path
  end
  Deq-->>B_d: BF16 tile (per-block Scale applied)
  MMA->>MMA: compute tiled GEMM (apply Bias/topk)
  MMA-->>Tile: write back results
Estimated code review effort: 🎯 4 (Complex) | ⏱️ ~60 minutes

Possibly related PRs
Pre-merge checks and finishing touches: ✅ Passed checks (3 passed)
👋 Hi! Thank you for contributing to the TileLang project. Please remember to run  We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
TODO

cc @tzj-fxz
Add the similarity-based metric to check correctness, since the large values will somewhat magnify the mismatch.
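A minimal sketch of such a similarity-based check, modeled on the `calc_sim`/`assert_similar` helpers referenced later in this review; the exact formula and tolerance used in `examples/gdn/utils.py` may differ:

```python
import torch

def calc_sim(x: torch.Tensor, y: torch.Tensor) -> float:
    # Scale-aware similarity: 1.0 when x == y, smaller as the tensors diverge.
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    if denom == 0:
        return 1.0
    return float(2 * (x * y).sum() / denom)

def assert_similar(x, y, eps: float = 1e-8, name: str = "tensor") -> None:
    # Fails only when the similarity-derived diff exceeds the tolerance,
    # which is gentler on large-magnitude outputs than exact comparison.
    diff = 1.0 - calc_sim(x, y)
    if not (0 <= diff <= eps):
        raise AssertionError(f"{name} mismatch: diff={diff}")
```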
…_hopper.py and utils.py

- Removed unnecessary cache disabling and manual seed setting in the example.
- Simplified nested loops into parallelized operations for better readability and performance.
- Updated the assertion function in utils.py to print detailed error messages.
- Adjusted tensor sizes in examples
- Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️  Outside diff range comments (1)
tilelang/language/builtin.py (1)
340-346: sync_global is not a grid-wide barrier; docstring is misleading + stray print

This issues a memory fence, not a cross-block rendezvous. Also, the print will spam logs during codegen.
Apply:
```diff
-def sync_global():
-    """Synchronize all threads in the entire grid.
-    """
+def sync_global():
+    """Device/global memory fence for participating threads.
+    Note: this is NOT a grid-wide barrier and does not synchronize blocks."""
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()
-    print(tx, ty, tz, ex, ey, ez)
     args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez]
     return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))
```
🧹 Nitpick comments (21)
tilelang/language/builtin.py (1)
358-358: Fix CI: add trailing newline

Ruff W292: No newline at end of file.
Please add a trailing newline at EOF to satisfy the linter.
examples/dequantize_gemm/utils.py (3)
27-47: Avoid 4x expansion with repeat_interleave; use stack+gather to cut memory/ops

The current approach replicates data 4x and then masks via nested where. A stack(...).gather(...) pattern is leaner and keeps shapes tight.
I can draft a vectorized gather-based variant if you want it in this PR.
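For illustration, the general shape of such a pattern; this is a sketch of the idea only, and the nibble layout plus bit formulas of the real `torch_convert_bit_twiddling` are assumed here, not copied from it:

```python
import torch

def decode_nibbles_stacked(packed: torch.Tensor) -> torch.Tensor:
    # packed: (N, K//2) uint8, assumed to hold two 4-bit codes per byte.
    lo = packed & 0x0F            # candidate decode for the even output columns
    hi = packed >> 4              # candidate decode for the odd output columns
    # (N, K//2, 2) -> (N, K): each candidate is computed once on the compact
    # shape and interleaved by a cheap reshape, instead of a repeat_interleave(4)
    # intermediate followed by nested torch.where selection.
    return torch.stack((lo, hi), dim=-1).reshape(packed.shape[0], -1)
```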
17-19: Prefer explicit exceptions over bare asserts for user-facing validation

Asserts can be stripped with -O and give poor error messages.

```diff
-    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    if tensor.dim() != 2 or tensor.dtype != torch.uint8:
+        raise TypeError("tensor must be 2-D and dtype=torch.uint8")
-    assert K % 2 == 0, "Number of columns must be even"
+    if K % 2 != 0:
+        raise ValueError("Number of columns must be even")
```
117-152: Duplication with examples/gdn/utils.py + unused arg ‘data’ + missing newline
- calc_sim/assert_similar/print_red_warning duplicate examples/gdn/utils.py; consider importing to avoid drift.
- Ruff ARG001: ‘data’ is unused in assert_similar.
- CI: add trailing newline at EOF (W292).
Apply:
```diff
-def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
@@
-    diff = (1. - sim).item()
-    print(f'{diff=}')
+    diff = (1. - sim).item()
+    if data:
+        print(f'{name} {data} diff={diff}')
+    else:
+        print(f'{name} diff={diff}')
@@
-        print_red_warning(f'{name} Error: {diff=}')
+        print_red_warning(f'{name} Error: diff={diff} {data}')
         if raise_assert:
             raise AssertionError
```

And consider replacing these local definitions with:

```python
# from examples.gdn.utils import print_red_warning, calc_sim, assert_similar
```

Don't forget to add a trailing newline at EOF.
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
26-26: Import path switch looks good

Moving to tilelang.quantize aligns with the new exports.
Optional: hoist this import to module scope to avoid per-JIT import overhead.
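Hoisting amounts to moving the import out of the JIT-compiled builder body; a minimal sketch of the pattern (the surrounding builder is simplified to a placeholder, only the import path comes from this PR):

```python
# Before: imported inside the kernel builder, paying the import-machinery cost
# (module cache lookups) every time the builder runs under JIT.
def matmul_builder_before():
    from tilelang.quantize import _tir_packed_to_unsigned_convert
    return _tir_packed_to_unsigned_convert

# After: hoisted to module scope, resolved once at module import time.
from tilelang.quantize import _tir_packed_to_unsigned_convert

def matmul_builder_after():
    return _tir_packed_to_unsigned_convert
```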
examples/dequantize_gemm/test_example_dequantize_gemm.py (1)
33-37: Import the heavy example inside the test to avoid module-import side effects

Keeps CPU-only envs safer and reduces import-time work when the test is skipped.

```diff
-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
-    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    import example_dequant_groupedgemm_bf16_mxfp4_hopper
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
```

examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (7)
13-32: Docstring claims applying `scale` inside the helper, but the code applies it outside.

The description says the exponent is adjusted by `scale` and clamped, yet the function never uses the `scale` parameter (the clamped addition is commented out). This can confuse future readers and risks double-scaling if someone later "fixes" the discrepancy.

Apply one of:

- Align the docs to reality (scale applied outside via 2^Scale).
- Or actually use `scale` in the exponent (then remove the external multiply). The former is lower risk for this PR.
46-49: Unused `scale` argument in `_tir_u8_to_f4_to_bf16`.

`scale` is passed through but not used. Keep it only if you intend to re-enable the in-helper exponent adjustment soon; otherwise drop it to avoid a misleading API.

```diff
-    def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
-                              dtype: str):
+    def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
@@
-                        Scale[
-                            bx * block_N + i, k * block_K // scale_size + j //
-                            scale_size],  # Scale is the exponential part, within the representation of uint8
-                        dtype=out_dtype,
-                    ) * T.shift_left(
-                        1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
+                        dtype=out_dtype,
+                    ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
```

Note: if you keep the parameter for API compatibility, at least mark it unused in the docstring.
238-239: Scale semantics mismatch with docstrings (2^Scale vs 2^(Scale - 127)).

Both fast and simple paths multiply by 2^Scale (`T.shift_left(1, Scale)`), while several docstrings mention 2^(Scale - 127). Please normalize the documentation and comments to the implemented behavior.
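A tiny illustration of the gap between the two conventions (the values are arbitrary, chosen to stay within float32 range):

```python
import torch

S = torch.tensor([1, 64, 127], dtype=torch.uint8).to(torch.float32)
implemented = torch.pow(2.0, S)          # what the kernels compute: 2**Scale
documented  = torch.pow(2.0, S - 127.0)  # what some docstrings describe: 2**(Scale - 127)
# The two differ by a constant factor of 2**127, so only one convention can match
# the reference computation; the docs should state whichever the code actually uses.
```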
391-395: Vectorized scaling: derive `scale_size` from shapes and use robust casting.

Avoid relying on a module/global `scale_size`. Also ensure the pow result is in a floating type to prevent dtype surprises across devices.

```diff
-    B = torch_convert_bit_twiddling(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert_bit_twiddling(qB)
+    group_cols = Scale.shape[1]  # = K // scale_size
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))
```
414-418: Same scaling concerns in the bias variant.

Mirror the fix from the non-bias path to avoid the global `scale_size` reliance and dtype pitfalls.

```diff
-    B = torch_convert_bit_twiddling(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert_bit_twiddling(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))
```
438-442: Simple path scaling: same shape/dtype robustness.

```diff
-    B = torch_convert(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))
```
466-470: Simple+bias path scaling: mirror the robust version.

```diff
-    B = torch_convert(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))
```

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (8)
129-196: Fast dequant macro: good structure; confirm the exponent type cast.

`Scale_local_thread_exponent[0] = T.shift_left(1, Scale_local_thread[0])` shifts an int and stores into BF16. An explicit cast improves clarity and avoids backend-dependent implicit casts.

```diff
-    Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+    Scale_local_thread_exponent[0] = T.cast(out_dtype, T.shift_left(1, Scale_local_thread[0]))
```

Please confirm TileLang lowers the implicit cast as expected on your target SMs; otherwise keep the explicit cast.
203-221: Simple dequant macro: same scale semantics mismatch as elsewhere.

Docs mention 2^(Scale - 127), but the implementation uses 2^Scale. Align the comments or change the math consistently across the repo.
277-279: Bias initialization uses a parallel fill; OK, but consider T.copy when with_bias=True.

You already copy Bias into `Bias_shared`. You could initialize `C_local` with a tiled copy or a vectorized add after the GEMM to reduce extra writes. Non-blocking.
306-346: Reference MoE does O(padding_M) dequantizations of the huge B; precompute per expert.

You dequantize and rescale `B` inside the per-token loop, repeating the work for all tokens of the same expert. This will explode runtime for realistic shapes.

```diff
-    # Iterate over sorted_token_ids
-    for idx in range(len(sorted_token_ids)):  # padding_M
+    # Precompute per-expert dequantized/scaled weights once
+    unique_eids = torch.unique(expert_ids).tolist()
+    deqB = {}
+    for eid in unique_eids:
+        Be = torch_convert_bit_twiddling(qB[eid])
+        group_cols = Scale.shape[2]  # = K // scale_size
+        scale_size_eff = Be.shape[1] // group_cols
+        col_groups = torch.arange(Be.shape[1], device=Be.device) // scale_size_eff
+        sf = (2.0**Scale[eid][:, col_groups].to(torch.float32)).to(Be.dtype)
+        deqB[eid] = Be * sf
+
+    # Iterate over sorted_token_ids
+    for idx in range(len(sorted_token_ids)):  # padding_M
         token_id = sorted_token_ids[idx]
         if token_id == -1:
             continue
-        expert_id = expert_ids[idx // block_M]
+        expert_id = int(expert_ids[idx // block_M].item())
@@
-        B = torch_convert_bit_twiddling(qB[expert_id])  # shape: (N, K)
-        B *= 2**(
-            Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
-                torch.bfloat16))
+        B = deqB[expert_id]
```
389-389: Fix fullwidth parentheses in comment.

Replace the fullwidth parentheses in the `(padding_M,)` comment with ASCII parentheses to satisfy Ruff RUF003 and avoid ambiguity.

```diff
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
```
399-431: Debug prints: gate them behind a flag or logger.

Unconditional prints (`sorted_token_ids`, `expert_ids`) will spam stdout in benchmarks/tests.

```diff
-    print(f'{sorted_token_ids=}')
-    print(f'{expert_ids=}')
+    if bool(int(os.getenv("TL_DEBUG_MOE", "0"))):
+        print(f'{sorted_token_ids=}')
+        print(f'{expert_ids=}')
```

(Remember to `import os`.)
109-150: CI: trailing whitespace on blank lines (W293).

Ruff flagged blank lines with spaces around these regions. Strip the trailing whitespace to unblock CI.
Run: ruff --fix examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
464-464: CI: add newline at EOF (W292).

Add a trailing newline to satisfy linters.

```diff
-    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
\ No newline at end of file
+    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
+
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (7)
- examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (4 hunks)
- examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1 hunks)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/dequantize_gemm/test_example_dequantize_gemm.py (2 hunks)
- examples/dequantize_gemm/utils.py (2 hunks)
- tilelang/language/builtin.py (1 hunks)
- tilelang/quantize/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
tilelang/quantize/quantization.py (1)
_tir_packed_to_unsigned_convert(258-266)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (9)
examples/dequantize_gemm/utils.py (1)
torch_convert_bit_twiddling(4-56)
tilelang/quantize/mxfp.py (1)
get_mxfp_intrin_group(51-107)
tilelang/language/__init__.py (2)
import_source(198-200)
annotate_layout(103-141)
tilelang/language/ast/ir.py (1)
func_name(206-214)
tilelang/language/allocate.py (3)
alloc_local(39-50)
alloc_fragment(53-64)
alloc_shared(21-36)
tilelang/language/tir/op.py (3)
shift_left(2768-2784)
call_extern(172-194)
address_of(463-479)
tilelang/language/parallel.py (1)
Parallel(8-28)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/pipeline.py (1)
Pipelined(9-46)
tilelang/language/builtin.py (1)
tilelang/language/tir/op.py (1)
tvm_storage_sync(533-546)
examples/dequantize_gemm/test_example_dequantize_gemm.py (2)
tilelang/testing/__init__.py (1)
requires_cuda_compute_version_ge(95-96)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)
main(225-301)
main(398-456)
tilelang/quantize/__init__.py (3)
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)
_tir_u8_to_f4_to_bf16(10-49)
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
_tir_u8_to_f4_to_bf16(218-259)
tilelang/quantize/quantization.py (1)
_tir_u8_to_f4_to_bf16(30-73)
examples/dequantize_gemm/utils.py (1)
examples/gdn/utils.py (3)
print_red_warning(4-5)
calc_sim(8-15)
assert_similar(18-40)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
389-389: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
389-389: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
tilelang/quantize/__init__.py
8-8: Unused noqa directive (non-enabled: F401)
Remove unused noqa directive
(RUF100)
examples/dequantize_gemm/utils.py
131-131: Unused function argument: data
(ARG001)
🪛 GitHub Actions: CI Test on AMD
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
[error] 109-109: Ruff: W293 Blank line contains whitespace.
[error] 115-115: Ruff: W293 Blank line contains whitespace.
[error] 136-136: Ruff: W293 Blank line contains whitespace.
[error] 143-143: Ruff: W293 Blank line contains whitespace.
[error] 150-150: Ruff: W293 Blank line contains whitespace.
[error] 464-464: Ruff: W292 No newline at end of file.
tilelang/language/builtin.py
[error] 358-358: Ruff: W292 No newline at end of file.
examples/dequantize_gemm/utils.py
[error] 152-152: Ruff: W292 No newline at end of file.
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: format-check
🔇 Additional comments (2)
tilelang/language/builtin.py (1)
334-336: Docstring fix from "warp" to "block" is correct

Matches CUDA semantics of the shared-memory barrier (aka __syncthreads).
tilelang/quantize/__init__.py (1)
1-9: Clean up exports and fix linter warning
- _tir_packed_to_unsigned_convert_with_zeros is still defined at tilelang/quantize/quantization.py:269 and still exported from tilelang/quantize/__init__.py; update the PR description or remove the export/definition if removal was intended.
- Remove the unnecessary "# noqa: F401" on _tir_u8_to_f4_to_bf16 in tilelang/quantize/__init__.py to address RUF100; verify there are no external consumers before deleting the export.
        
          
```python
                for i, j in T.Parallel(block_M, block_K):
                    if sorted_token_ids_shared[i] != -1:
                        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
                if fast_dequant:
```
Uninitialized rows of A_shared for padded tokens.
Rows for -1 tokens aren’t written but still participate in GEMM. Zero them to avoid undefined reads.
-                for i, j in T.Parallel(block_M, block_K):
-                    if sorted_token_ids_shared[i] != -1:
-                        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
+                for i, j in T.Parallel(block_M, block_K):
+                    if sorted_token_ids_shared[i] != -1:
+                        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
+                    else:
+                        A_shared[i, j] = T.const(0, in_dtype)
```python
    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
    bf16_new = bf16_bf16 * (2.0**126)

    return bf16_new
```
Return dtype likely float32; violates docstring contract (bf16 expected)
Multiplying a bf16 tensor by a float literal promotes to float32 in PyTorch; the function then returns f32.
Apply:
-    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
-    bf16_new = bf16_bf16 * (2.0**126)
-
-    return bf16_new
+    # Keep result in bf16 as documented
+    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
+    return (bf16_bf16 * scale)
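If the promotion behavior is in doubt, it can be checked directly; this is a standalone sketch, and the tensors inside utils.py may already carry a different dtype at this point in the function:

```python
import torch

x = torch.ones(4, dtype=torch.bfloat16)
# Inspect whether multiplying by a Python float literal changes the dtype.
print((x * (2.0 ** 126)).dtype)
# Multiplying by a bf16 scale tensor keeps the result in bf16.
scale = torch.tensor(2.0 ** 126, dtype=torch.bfloat16)
print((x * scale).dtype)
```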
🤖 Prompt for AI Agents
In examples/dequantize_gemm/utils.py around lines 53 to 56, the multiplication
by a float literal promotes the bf16 tensor to float32 so the function returns
f32 instead of bf16 as the docstring promises; after computing bf16_new keep the
numeric scaling but cast the result back to bfloat16 (e.g., use
.to(torch.bfloat16) or .bfloat16()) before returning so the return dtype matches
the docstring.
…dequant_groupedgemm_bf16_mxfp4_hopper.py

- Added a new function `get_configs()` to generate hyperparameter configurations for tuning.
- Updated the `matmul` function to utilize autotuning with the new configurations.
- Improved kernel performance via vectorization and threadblock swizzle.
- Enhanced the main function to support the new autotuning inputs and updated parameters for better performance.
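For context, generating such a configuration grid for the autotuner typically looks something like the sketch below; the knob names match the tile parameters discussed in this review, but the value ranges are assumptions, not the actual contents of `get_configs()`:

```python
import itertools

def get_configs():
    # Hypothetical tuning space; the real get_configs() in the example may
    # expose different parameters and ranges.
    block_M = [64, 128]
    block_N = [128, 256]
    block_K = [64, 128]
    num_stages = [2, 3]
    threads = [128, 256]
    keys = ("block_M", "block_N", "block_K", "num_stages", "threads")
    return [dict(zip(keys, values))
            for values in itertools.product(block_M, block_N, block_K, num_stages, threads)]
```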
Actionable comments posted: 2
♻️ Duplicate comments (2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)
295-299: Initialize topk_weights_shared for padded rows.

Uninitialized values for padded rows are later used in a multiply. Set them to 0.

```diff
 for i in T.Parallel(block_M):
     if sorted_token_ids_shared[i] != -1:
         topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
+    else:
+        topk_weights_shared[i] = T.const(0, out_dtype)
```
313-321: Zero A_shared for padded rows to avoid undefined reads in GEMM.

Rows for -1 tokens are skipped and left uninitialized but still participate in GEMM.

```diff
 for k in T.Pipelined(K // block_K, num_stages=num_stages):
     for copy_i in T.serial(block_M * block_K // threads // 16):
         base = copy_i * threads * 16 + tx * 16
         if sorted_token_ids_shared[base // block_K] != -1:
             for copy_j in T.vectorized(16):
-                A_shared[base // block_K, base % block_K +
-                         copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
-                                     k * block_K + base % block_K + copy_j]
+                A_shared[base // block_K, base % block_K + copy_j] = \
+                    A[sorted_token_ids_shared[base // block_K] // topk,
+                      k * block_K + base % block_K + copy_j]
+        else:
+            for copy_j in T.vectorized(16):
+                A_shared[base // block_K, base % block_K + copy_j] = T.const(0, in_dtype)
```
🧹 Nitpick comments (4)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (4)
170-173: Doc fix: the scale factor description mismatches the implementation.

The code applies 2**Scale; the doc says 2^(Scale - 127). Update the doc to match the behavior.

```diff
-    - Loads the corresponding per-block scale entry, interprets it as an exponent bias
-      (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
+    - Loads the corresponding per-block scale entry, interprets it as a base-2 exponent
+      (applies 2**Scale), and multiplies the dequantized BF16 fragment by that factor.
```
91-96: Doc fix: default tile sizes mismatch the function signature.

Defaults here should read 128, 256, 128.

```diff
-        block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
+        block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 128, 256, 128).
```
287-287: Swizzle panel size likely should be the warp size (32).

The use_swizzle docs say "panel size is the number of threads in a warp". Consider 32 instead of 10.

```diff
-            T.use_swizzle(10)
+            T.use_swizzle(32)
```
427-427: Ruff RUF003: replace fullwidth parentheses in the comment.

```diff
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/dequantize_gemm/test_example_dequantize_gemm.py (2 hunks)
- tilelang/language/builtin.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/dequantize_gemm/test_example_dequantize_gemm.py
- tilelang/language/builtin.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (9)
examples/dequantize_gemm/utils.py (1)
torch_convert_bit_twiddling(4-56)
tilelang/autotuner/capture.py (1)
set_autotune_inputs(100-118)
tilelang/quantize/mxfp.py (1)
get_mxfp_intrin_group(51-107)
tilelang/language/__init__.py (3)
import_source(198-200)
annotate_layout(103-141)
use_swizzle(94-100)
tilelang/language/allocate.py (3)
alloc_local(39-50)
alloc_fragment(53-64)
alloc_shared(21-36)
tilelang/language/parallel.py (1)
Parallel(8-28)
tilelang/language/copy.py (1)
copy(84-152)
tilelang/language/fill.py (1)
clear(24-48)
tilelang/language/pipeline.py (1)
Pipelined(9-46)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
427-427: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
427-427: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)
228-251: Possible double scaling in the simple_dequant path; verify parity with the reference.

simple_dequant passes Scale to _tir_u8_to_f4_to_bf16 and multiplies again by 2**Scale. Confirm that the intrinsic does not already incorporate scaling; otherwise you're scaling twice. Please run with fast_dequant=False and compare to ref_moe.
```python
                T.call_extern(
                    func_name,
                    T.address_of(B_local_thread[0]),
                    T.address_of(B_dequantize_local_thread[0]),
                    1,
                    dtype=out_dtype,
                )
```
Fix incorrect T.call_extern signature (will raise at build time).
call_extern expects dtype first, then func_name. Passing func_name positionally and dtype as a kw arg results in “multiple values for argument 'dtype'”.
Apply this diff:
-                T.call_extern(
-                    func_name,
-                    T.address_of(B_local_thread[0]),
-                    T.address_of(B_dequantize_local_thread[0]),
-                    1,
-                    dtype=out_dtype,
-                )
+                T.call_extern(
+                    "void",
+                    func_name,
+                    T.address_of(B_local_thread[0]),
+                    T.address_of(B_dequantize_local_thread[0]),
+                    1,
+                )
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
around lines 209 to 215, T.call_extern is called with func_name as the first
positional arg and dtype passed as a keyword, which conflicts with the expected
signature (dtype first, then func_name) and causes a build-time “multiple values
for argument 'dtype'” error; fix by reordering the call to pass the dtype as the
first argument and the function name second (remove the dtype keyword argument),
keeping the remaining arguments (addresses and integer) unchanged.
        
          
        
Actionable comments posted: 0
♻️ Duplicate comments (4)
examples/dequantize_gemm/utils.py (1)
53-56: Return dtype promoted to float32; violates the docstring (should return bfloat16).

Multiplying bf16 by a float literal upcasts to float32. Keep the scale as bf16 to preserve the documented return type.
Apply:
```diff
-    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
-    bf16_new = bf16_bf16 * (2.0**126)
-
-    return bf16_new
+    # Keep result as bfloat16 as documented
+    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
+    return bf16_bf16 * scale
```

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (3)
295-300: Uninitialized topk_weights_shared for padded rows.

When sorted_token_ids_shared[i] == -1, the weight remains undefined but is later used in a multiply.

Apply:

```diff
 for i in T.Parallel(block_M):
     if sorted_token_ids_shared[i] != -1:
         topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
+    else:
+        topk_weights_shared[i] = T.const(0, out_dtype)
```
313-322: Uninitialized rows of A_shared for padded tokens.

Rows for -1 tokens are left untouched and still participate in GEMM.

Apply:

```diff
 for copy_i in T.serial(block_M * block_K // threads // 16):
     base = copy_i * threads * 16 + tx * 16
     if sorted_token_ids_shared[base // block_K] != -1:
         for copy_j in T.vectorized(16):
             A_shared[base // block_K, base % block_K + copy_j] = \
                 A[sorted_token_ids_shared[base // block_K] // topk,
                   k * block_K + base % block_K + copy_j]
+    else:
+        for copy_j in T.vectorized(16):
+            A_shared[base // block_K, base % block_K + copy_j] = T.const(0, in_dtype)
```
209-215: Incorrect T.call_extern signature; will raise at runtime.

The API is call_extern(dtype, func_name, *args, ...). You're passing func_name first and dtype as a kwarg.

Apply:

```diff
-                T.call_extern(
-                    func_name,
-                    T.address_of(B_local_thread[0]),
-                    T.address_of(B_dequantize_local_thread[0]),
-                    1,
-                    dtype=out_dtype,
-                )
+                T.call_extern(
+                    "void",
+                    func_name,
+                    T.address_of(B_local_thread[0]),
+                    T.address_of(B_dequantize_local_thread[0]),
+                    1,
+                )
```
🧹 Nitpick comments (4)
examples/dequantize_gemm/utils.py (2)
131-153: The `data` arg is unused in assert_similar; address Ruff ARG001 and improve pass logging.

Either remove the arg or use it on success.
Apply:
```diff
-    if not (0 <= diff <= eps):
+    if not (0 <= diff <= eps):
         print_red_warning(f'{name} Error: {diff=}')
         if raise_assert:
             raise AssertionError
+    else:
+        if data:
+            print(f"{name} {data} passed")
```
21-31: Memory-heavy repeat_interleave; consider reshape+stack to avoid 4x duplication.

Using repeat_interleave(4) creates a wide intermediate. You can compute res{0..3} on (N, K//2) and torch.stack(..., dim=-1) then view(N, -1), or gather with a (K//2, 4) position index, to reduce peak memory; a sketch of the stack+view variant follows.
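A sketch of that stack+view shape flow; the bit-level extractions below are placeholders, not the real decodings from torch_convert_bit_twiddling:

```python
import torch

def decode_four_per_column(packed: torch.Tensor) -> torch.Tensor:
    # packed: (N, C) uint8. Four per-column results are computed on the compact
    # (N, C) shape; placeholder 2-bit field extractions stand in for the real
    # bit-twiddling formulas.
    res0 = packed & 0x03
    res1 = (packed >> 2) & 0x03
    res2 = (packed >> 4) & 0x03
    res3 = (packed >> 6) & 0x03
    # (N, C, 4) -> (N, 4*C): interleave with a reshape instead of building a
    # repeat_interleave(4)-sized intermediate and selecting with nested where.
    return torch.stack((res0, res1, res2, res3), dim=-1).reshape(packed.shape[0], -1)
```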
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)

134-134: Redundant self-assignment.

`import_source = import_source` is a no-op; remove it to avoid confusion.

```diff
-    import_source = import_source
```
433-433: Ambiguous fullwidth parentheses in comment (RUF003).

Replace the fullwidth parentheses in the `(padding_M,)` comment with ASCII parentheses.

```diff
-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (2)
- examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
- examples/dequantize_gemm/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (4)
examples/dequantize_gemm/utils.py (2)
torch_convert_bit_twiddling(4-56)
assert_similar(131-152)
tilelang/autotuner/capture.py (1)
set_autotune_inputs(100-118)
tilelang/quantize/mxfp.py (1)
get_mxfp_intrin_group(51-107)
tilelang/language/__init__.py (2)
import_source(198-200)
use_swizzle(94-100)
examples/dequantize_gemm/utils.py (1)
examples/gdn/utils.py (3)
print_red_warning(4-5)
calc_sim(8-15)
assert_similar(18-40)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
433-433: Comment contains ambiguous ( (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?
(RUF003)
433-433: Comment contains ambiguous ) (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?
(RUF003)
examples/dequantize_gemm/utils.py
131-131: Unused function argument: data
(ARG001)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: bot-task
- GitHub Check: format-check
- GitHub Check: format-check
🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)
271-271: Confirm the dtype for B_dequantize_shared.

The dequant output buffer is allocated with in_dtype; typically this should be out_dtype. If in_dtype != out_dtype, the types will mismatch GEMM expectations.

Potential change:

```diff
-            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
+            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, out_dtype)
```

Please confirm the intended types for the A/B tiles and T.gemm.
This pull request refactors and optimizes the dequantization and bit-twiddling routines in examples/dequantize_gemm, focusing on improving performance and code clarity. The main changes include replacing slow nested loops with efficient vectorized PyTorch operations for scaling and bit manipulation, as well as adding comprehensive tests to validate correctness and measure performance.

Performance and correctness improvements:

- Vectorized the scaling of B in all reference dequantization functions (ref_program_twiddling, ref_program_twiddling_with_bias, ref_program_simple, and ref_program_simple_with_bias), significantly improving performance. [1] [2] [3] [4]
- Updated torch_convert_bit_twiddling to use parallel, vectorized PyTorch operators for decoding and scaling, removing the inner Python loop and enabling efficient CUDA execution.

Testing and validation:

- Added test code in utils.py that benchmarks and validates the outputs of torch_convert_bit_twiddling and torch_convert_bit_twiddling_parallel, ensuring both implementations produce matching results within a small tolerance.

Documentation:

- Updated the docstring of torch_convert_bit_twiddling to clarify that the implementation is now parallel and uses torch operators.
- Added a new example for grouped matrix multiplication with experts in example_dequant_groupgemm_bf16_mxfp4_hopper.py.
- Updated the torch_convert_bit_twiddling function in utils.py to utilize parallel processing, enhancing efficiency and clarity in the conversion process.

This pull request focuses on improving the performance and clarity of the dequantization utilities and reference implementations in examples/dequantize_gemm, as well as clarifying some docstrings in the tilelang library. The main changes include vectorizing previously loop-based scaling operations, introducing a parallel implementation for bit-twiddling conversion, and adding test code for validation and benchmarking.

Performance improvements and vectorization

- Vectorized the scaling in the reference implementations (ref_program_twiddling, ref_program_twiddling_with_bias, ref_program_simple, and ref_program_simple_with_bias), significantly improving efficiency and readability. [1] [2] [3] [4]
- Refactored torch_convert_bit_twiddling into a fully parallel implementation using torch operators, eliminating Python loops and enabling efficient execution on CUDA devices.

Testing and validation

- Added test code in utils.py that benchmarks and validates the outputs of torch_convert_bit_twiddling and torch_convert_bit_twiddling_parallel, ensuring correctness and performance.

Documentation improvements

- Updated the docstring of torch_convert_bit_twiddling to clarify that it is now a parallel implementation.
- Clarified the docstrings of sync_threads and sync_global in tilelang/language/builtin.py to better reflect their scope and behavior.

Summary by CodeRabbit
New Features
Refactor
Utilities
Documentation
Tests