
Conversation

@Rachmanino (Collaborator) commented Sep 12, 2025

This pull request refactors and optimizes the dequantization and bit-twiddling routines in examples/dequantize_gemm, focusing on improving performance and code clarity. The main changes include replacing slow nested loops with efficient vectorized PyTorch operations for scaling and bit manipulation, as well as adding comprehensive tests to validate correctness and measure performance.

Performance and correctness improvements:

  • Replaced nested for-loops with vectorized PyTorch operations for scaling the matrix B in all reference dequantization functions (ref_program_twiddling, ref_program_twiddling_with_bias, ref_program_simple, and ref_program_simple_with_bias), significantly improving performance. [1] [2] [3] [4]
  • Fully rewrote torch_convert_bit_twiddling to use parallel, vectorized PyTorch operators for decoding and scaling, removing the inner Python loop and enabling efficient CUDA execution.
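For reference, a minimal sketch of the vectorized scaling pattern (the helper name and shapes are illustrative; the PR applies the same broadcasted indexing inline in each reference function):

import torch

def scale_columns_vectorized(B: torch.Tensor, Scale: torch.Tensor, scale_size: int) -> torch.Tensor:
    """Multiply each column group of B by 2**Scale without Python loops."""
    # Map every column j to its scale group j // scale_size, gather the
    # per-(row, group) exponents, and broadcast-multiply in one pass.
    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size
    return B * (2.0 ** Scale[:, col_groups].to(B.dtype))

The loop version this replaces iterated over every (i, j) element in Python, which is drastically slower on CUDA tensors.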

Testing and validation:

  • Added a main test block to utils.py that benchmarks and validates the outputs of torch_convert_bit_twiddling and torch_convert_bit_twiddling_parallel, ensuring both implementations produce matching results within a small tolerance.
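A minimal sketch of what such a block can look like (sizes, tolerances, and the timing approach are illustrative, not the exact code added in the PR; torch is assumed to be imported at the top of utils.py):

if __name__ == "__main__":
    N, K = 1024, 1024  # illustrative sizes
    qB = torch.randint(0, 256, (N, K // 2), dtype=torch.uint8, device="cuda")

    # Correctness: both converters must agree within tolerance.
    out_a = torch_convert_bit_twiddling(qB)
    out_b = torch_convert_bit_twiddling_parallel(qB)
    torch.testing.assert_close(out_a.float(), out_b.float(), rtol=1e-3, atol=1e-3)

    # Performance: time each converter with CUDA events.
    for fn in (torch_convert_bit_twiddling, torch_convert_bit_twiddling_parallel):
        start = torch.cuda.Event(enable_timing=True)
        end = torch.cuda.Event(enable_timing=True)
        torch.cuda.synchronize()
        start.record()
        fn(qB)
        end.record()
        torch.cuda.synchronize()
        print(f"{fn.__name__}: {start.elapsed_time(end):.3f} ms")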

Documentation:

  • Updated the docstring for torch_convert_bit_twiddling to clarify that the implementation is now parallel and uses torch operators.
  • Clarified the descriptions of sync_threads and sync_global in tilelang/language/builtin.py to better reflect their scope and behavior.

Summary by CodeRabbit

  • New Features

    • Added a grouped MoE dequantized GEMM example with top-k routing, optional bias, autotuning and benchmarking.
  • Refactor

    • Vectorized dequantization and scaling across examples for faster, equivalent results.
    • Updated quantize exports to include a new dequantization intrinsic.
  • Utilities

    • Added debugging utilities for tensor similarity checks and colored warnings.
  • Documentation

    • Clarified synchronization semantics in docstrings (block vs. grid).
  • Tests

    • Added a CUDA-only test for the new grouped GEMM example.

Rachmanino and others added 2 commits September 12, 2025 14:59
- Added a new example for grouped matrix multiplication with experts in `example_dequant_groupgemm_bf16_mxfp4_hopper.py`.
- Improved dequantization logic in existing examples by replacing nested loops with vectorized operations for better performance.
- Updated `torch_convert_bit_twiddling` function in `utils.py` to utilize parallel processing, enhancing efficiency and clarity in the conversion process.
Co-authored-by: Zhengju Tang <97930865+tzj-fxz@users.noreply.github.com>
@coderabbitai coderabbitai bot (Contributor) commented Sep 12, 2025

Walkthrough

Adds a new grouped MoE FP4→BF16 dequantized GEMM example and a CUDA test, vectorizes per-column-group scaling and bit‑twiddling in example utilities, swaps a quantize export, and updates two builtin docstrings. No public function signatures changed except the quantize exports.

Changes

Cohort / File(s) Summary of edits
New grouped MoE GEMM example
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
Added TileLang JIT/autotuned grouped MoE GEMM with FP4→BF16 dequant (fast intrinsic + fallback), reference PyTorch impl, synthetic data generator, autotune/benchmarking, validation, and CLI entrypoint.
New CUDA test
examples/dequantize_gemm/test_example_dequantize_gemm.py
Added test_example_dequant_groupedgemm_bf16_mxfp4_hopper() gated by CUDA and compute-capability >= 9.0 to run the new example.
Vectorized scaling in MXFP4 example
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
Replaced nested-loop elementwise scaling of B with a broadcasted vectorized operation across four reference variants; semantics preserved, no signature changes.
Import path adjustment
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py
Moved _tir_packed_to_unsigned_convert import from bitblas.quantization to tilelang.quantize; no logic changes.
Vectorized bit‑twiddling & test helpers
examples/dequantize_gemm/utils.py
Rewrote torch_convert_bit_twiddling to use vectorized tensor ops with input validation; removed inner-loop converter; added print_red_warning, calc_sim, and assert_similar helpers.
Quantize export change
tilelang/quantize/__init__.py
Updated public exports: removed _tir_packed_to_unsigned_convert_with_zeros, added _tir_u8_to_f4_to_bf16.
Builtin docstring updates
tilelang/language/builtin.py
Docstring clarifications: sync_threads now says "block", sync_global now says "grid"; no functional changes.

Sequence Diagram(s)

sequenceDiagram
  autonumber
  participant User
  participant Main as example_dequant_groupedgemm.main
  participant Data as get_data()
  participant Builder as JIT Builder
  participant Kernel as matmul (TileLang)
  participant Ref as ref_moe (PyTorch)
  participant Prof as tilelang.profiler

  User->>Main: invoke (M,N,K,topk,E,fast_dequant,with_bias)
  Main->>Data: generate A, qB, Scale, Bias, topk_weights, ids, padding
  Main->>Builder: build/compile kernel (fast/simple dequant)
  Builder-->>Kernel: compiled kernel
  Main->>Kernel: launch kernel with inputs
  Kernel->>Kernel: tiled K-loop → dequant → GEMM → apply bias/topk
  Kernel-->>Main: C (grouped, weighted)
  Main->>Ref: compute reference C_ref
  Ref-->>Main: C_ref
  Main->>Main: validate via assert_similar
  Main->>Prof: benchmark
  Prof-->>Main: latency stats
  Main-->>User: results & validation
sequenceDiagram
  autonumber
  participant Tile as Kernel Block
  participant A_s as A_shared
  participant B_p as qB_packed
  participant Deq as Dequant Macro
  participant B_d as B_bf16_tile
  participant MMA as GEMM/Acc

  Tile->>A_s: load A tile
  Tile->>B_p: load packed FP4 tile
  alt fast_dequant
    Tile->>Deq: call external twiddling intrinsic
  else simple_dequant
    Tile->>Deq: call `_tir_u8_to_f4_to_bf16` path
  end
  Deq-->>B_d: BF16 tile (per-block Scale applied)
  MMA->>MMA: compute tiled GEMM (apply Bias/topk)
  MMA-->>Tile: write back results

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~60 minutes

Suggested reviewers

  • LeiWang1999

Poem

I twitch my whiskers at kernels anew,
FP4 bytes become BF16 through a hop or two,
Experts tiled and routers queued,
Broadcasted scales — no loops to feud,
CUDA hums, the rabbit stamps a review. 🥕🐰

Pre-merge checks and finishing touches

✅ Passed checks (3 passed)
  • Description Check: ✅ Passed. Check skipped: CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title clearly highlights the primary change — adding an MXFP4 grouped GEMM example for FusedMoE — which matches the large new example file and the stated PR objectives; it is concise, specific, and useful for a teammate scanning history.
  • Docstring Coverage: ✅ Passed. Docstring coverage is 100.00%, which is sufficient; the required threshold is 80.00%.

📜 Recent review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 1b85cb6 and 4eeaf53.

📒 Files selected for processing (1)
  • examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
  • examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check


@github-actions commented

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run bash format.sh in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work!

🚀

@Rachmanino (Collaborator, Author) commented Sep 12, 2025

TODO

  • Correctness
  • Tune performance (TMA..)
  • Parallelize all related torch functions (later)

@Rachmanino (Collaborator, Author) commented

cc @tzj-fxz

@Rachmanino Rachmanino changed the title [Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE and reimplement previous functions for acceleration [Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE and parallelize previous functions for acceleration Sep 12, 2025
@tzj-fxz tzj-fxz self-requested a review September 12, 2025 15:47
@tzj-fxz (Contributor) commented Sep 14, 2025

Added a similarity-based metric to check correctness, since for large values an absolute-error comparison would somewhat magnify the mismatch.
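For context, a sketch of the kind of similarity metric meant here (mirroring the calc_sim helper in examples/gdn/utils.py that this PR's assert_similar follows; the in-repo formula may differ slightly):

import torch

def calc_sim(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # Normalized correlation: 2*<x, y> / (|x|^2 + |y|^2), which equals 1
    # iff x == y. diff = 1 - sim is then a scale-aware relative error, so
    # large-magnitude outputs don't dominate the way they would under a
    # plain absolute-error check.
    x, y = x.double(), y.double()
    denom = (x * x + y * y).sum()
    if denom == 0:
        return torch.tensor(1.0)
    return 2 * (x * y).sum() / denom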

…_hopper.py and utils.py

- Removed unnecessary cache disabling and manual seed setting in the example.
- Simplified nested loops into parallelized operations for better readability and performance.
- Updated the assertion function in utils.py to print detailed error messages.
- Adjusted tensor sizes in examples
Rachmanino and others added 4 commits September 15, 2025 11:13
- Changed the import statement for `_tir_packed_to_unsigned_convert` from `bitblas.quantization` to `tilelang.quantize` to reflect the new module structure.
@Rachmanino Rachmanino marked this pull request as ready for review September 15, 2025 11:38
@Rachmanino Rachmanino changed the title [Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE and parallelize previous functions for acceleration [Enhancement] Add a MXFP4 grouped GEMM example for FusedMoE Sep 15, 2025
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 3

Caution

Some comments are outside the diff and can’t be posted inline due to platform limitations.

⚠️ Outside diff range comments (1)
tilelang/language/builtin.py (1)

340-346: sync_global is not a grid-wide barrier; docstring is misleading + stray print

This issues a memory fence, not a cross-block rendezvous. Also, the print will spam logs during codegen.

Apply:

-def sync_global():
-    """Synchronize all threads in the entire grid.
-    """
+def sync_global():
+    """Device/global memory fence for participating threads.
+    Note: this is NOT a grid-wide barrier and does not synchronize blocks."""
     tx, ty, tz = get_thread_bindings()
     ex, ey, ez = get_block_extents()
-    print(tx, ty, tz, ex, ey, ez)
     args = ["global", tx == 0 and ty == 0 and tz == 0, ex * ey * ez]
     return evaluate(tir.Call("handle", "tir.tvm_storage_sync", args))
🧹 Nitpick comments (21)
tilelang/language/builtin.py (1)

358-358: Fix CI: add trailing newline

Ruff W292: No newline at end of file.

Please add a trailing newline at EOF to satisfy the linter.

examples/dequantize_gemm/utils.py (3)

27-47: Avoid 4x expansion with repeat_interleave; use stack+gather to cut mem/ops

Current approach replicates data 4x then masks via nested where. A stack(...).gather(...) pattern is leaner and keeps shapes tight.

I can draft a vectorized gather-based variant if you want it in this PR.
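For concreteness, a sketch of the stack-based pattern under simplifying assumptions (a plain nibble decode stands in for the real FP4 bit-twiddling expressions, and only two output positions per byte are shown):

import torch

def convert_stacked(qB: torch.Tensor) -> torch.Tensor:
    # Decode each output position on the compact (N, K//2) layout, then
    # interleave via stack+view instead of materializing a replicated
    # input with repeat_interleave and selecting through nested where().
    lo = (qB & 0x0F).to(torch.bfloat16)  # placeholder decode, position 0
    hi = (qB >> 4).to(torch.bfloat16)    # placeholder decode, position 1
    return torch.stack((lo, hi), dim=-1).view(qB.shape[0], -1)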


17-19: Prefer explicit exceptions over bare asserts for user-facing validation

Asserts can be stripped with -O and give poor error messages.

-    assert tensor.dim() == 2 and tensor.dtype == torch.uint8
+    if tensor.dim() != 2 or tensor.dtype != torch.uint8:
+        raise TypeError("tensor must be 2-D and dtype=torch.uint8")
-    assert K % 2 == 0, "Number of columns must be even"
+    if K % 2 != 0:
+        raise ValueError("Number of columns must be even")

117-152: Duplication with examples/gdn/utils.py + unused arg ‘data’ + missing newline

  • calc_sim/assert_similar/print_red_warning duplicate examples/gdn/utils.py; consider importing to avoid drift.
  • Ruff ARG001: ‘data’ is unused in assert_similar.
  • CI: add trailing newline at EOF (W292).

Apply:

-def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
+def assert_similar(x, y, eps=1e-8, name="tensor", data="", raise_assert=True):
@@
-    diff = (1. - sim).item()
-    print(f'{diff=}')
+    diff = (1. - sim).item()
+    if data:
+        print(f'{name} {data} diff={diff}')
+    else:
+        print(f'{name} diff={diff}')
@@
-        print_red_warning(f'{name} Error: {diff=}')
+        print_red_warning(f'{name} Error: diff={diff} {data}')
         if raise_assert:
             raise AssertionError
And consider replacing these local definitions with:

from examples.gdn.utils import print_red_warning, calc_sim, assert_similar

Don’t forget to add a trailing newline at EOF.

examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)

26-26: Import path switch looks good

Moving to tilelang.quantize aligns with the new exports.

Optional: hoist this import to module scope to avoid per-JIT import overhead.

examples/dequantize_gemm/test_example_dequantize_gemm.py (1)

33-37: Import heavy example inside the test to avoid module-import side effects

Keeps CPU-only envs safer and reduces import-time work when the test is skipped.

-@tilelang.testing.requires_cuda
-@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
-def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
-    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
+@tilelang.testing.requires_cuda
+@tilelang.testing.requires_cuda_compute_version_ge(9, 0)
+def test_example_dequant_groupedgemm_bf16_mxfp4_hopper():
+    import example_dequant_groupedgemm_bf16_mxfp4_hopper
+    example_dequant_groupedgemm_bf16_mxfp4_hopper.main()
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (7)

13-32: Docstring claims applying scale inside helper, but code applies it outside.

The description says the exponent is adjusted by scale and clamped, yet the function never uses the scale parameter (the clamped addition is commented out). This can confuse future readers and risks double‑scaling if someone later “fixes” the discrepancy.

Apply one of:

  • Align docs to reality (scale applied outside via 2^Scale).
  • Or actually use scale in the exponent (then remove external multiply). The former is lower risk for this PR.

46-49: Unused scale argument in _tir_u8_to_f4_to_bf16.

scale is passed through but not used. Keep it only if you intend to re‑enable in‑helper exponent adjust soon; otherwise drop it to avoid misleading API.

- def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, scale: tir.PrimExpr,
-                           dtype: str):
+ def _tir_u8_to_f4_to_bf16(nbit: int, val: tir.PrimExpr, pos: tir.PrimExpr, dtype: str):
@@
-                    Scale[
-                        bx * block_N + i, k * block_K // scale_size + j //
-                        scale_size],  # Scale is the exponential part, within the representation of uint8
-                    dtype=out_dtype,
-                ) * T.shift_left(
-                    1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))
+                    dtype=out_dtype,
+                ) * T.shift_left(1, (Scale[bx * block_N + i, k * block_K // scale_size + j // scale_size]))

Note: If you keep the parameter for API compatibility, at least mark it unused in the docstring.


238-239: Scale semantics mismatch with docstrings (2^(Scale) vs 2^(Scale-127)).

Both fast and simple paths multiply by 2^Scale (T.shift_left(1, Scale)), while several docstrings mention 2^(Scale - 127). Please normalize the documentation and comments to the implemented behavior.


391-395: Vectorized scaling: derive scale_size from shapes and use robust casting.

Avoid relying on a module/global scale_size. Also ensure the pow result is in a floating type to prevent dtype surprises across devices.

-    B = torch_convert_bit_twiddling(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert_bit_twiddling(qB)
+    group_cols = Scale.shape[1]  # = K // scale_size
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))

414-418: Same scaling concerns in bias variant.

Mirror the fix from the non‑bias path to avoid global scale_size reliance and dtype pitfalls.

-    B = torch_convert_bit_twiddling(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert_bit_twiddling(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))

438-442: Simple path scaling: same shape/dtype robustness.

-    B = torch_convert(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))

466-470: Simple+bias path scaling: mirror the robust version.

-    B = torch_convert(qB)
-    B *= 2**(Scale[:, (torch.arange(B.shape[1], device=B.device) // scale_size)])
+    B = torch_convert(qB)
+    group_cols = Scale.shape[1]
+    scale_size_eff = B.shape[1] // group_cols
+    col_groups = torch.arange(B.shape[1], device=B.device) // scale_size_eff
+    scale = Scale[:, col_groups].to(torch.float32)
+    B.mul_((2.0**scale).to(B.dtype))
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (8)

129-196: Fast dequant macro: good structure; confirm exponent type cast.

Scale_local_thread_exponent[0] = T.shift_left(1, Scale_local_thread[0]) shifts an int and stores into BF16. An explicit cast improves clarity and avoids backend‑dependent implicit casts.

-                Scale_local_thread_exponent[0] = T.shift_left(1, (Scale_local_thread[0]))
+                Scale_local_thread_exponent[0] = T.cast(out_dtype, T.shift_left(1, Scale_local_thread[0]))

Please confirm TileLang lowers the implicit cast as expected on your target SMs; otherwise keep the explicit cast.


203-221: Simple dequant macro: same scale semantics mismatch as elsewhere.

Docs mention 2^(Scale - 127), but implementation uses 2^Scale. Align comments or change math consistently across the repo.


277-279: Bias initialization uses a parallel fill; ok but consider T.copy when with_bias=True.

You already copy Bias into Bias_shared. You could initialize C_local with a tiled copy or vectorized add after GEMM to reduce extra writes. Non‑blocking.


306-346: Reference MoE is O(padding_M) dequantizations of huge B — precompute per expert.

You dequantize and rescale B inside the per‑token loop, repeating work for all tokens of the same expert. This will explode runtime for realistic shapes.

-    # Iterate over sorted_token_ids
-    for idx in range(len(sorted_token_ids)):  # padding_M
+    # Precompute per-expert dequantized/scaled weights once
+    unique_eids = torch.unique(expert_ids).tolist()
+    deqB = {}
+    for eid in unique_eids:
+        Be = torch_convert_bit_twiddling(qB[eid])
+        group_cols = Scale.shape[2]  # = K // scale_size
+        scale_size_eff = Be.shape[1] // group_cols
+        col_groups = torch.arange(Be.shape[1], device=Be.device) // scale_size_eff
+        sf = (2.0**Scale[eid][:, col_groups].to(torch.float32)).to(Be.dtype)
+        deqB[eid] = Be * sf
+
+    # Iterate over sorted_token_ids
+    for idx in range(len(sorted_token_ids)):  # padding_M
         token_id = sorted_token_ids[idx]
         if token_id == -1:
             continue
-        expert_id = expert_ids[idx // block_M]
+        expert_id = int(expert_ids[idx // block_M].item())
@@
-        B = torch_convert_bit_twiddling(qB[expert_id])  # shape: (N, K)
-        B *= 2**(
-            Scale[expert_id][:, (torch.arange(B.shape[1], device=B.device) // scale_size)].to(
-                torch.bfloat16))
+        B = deqB[expert_id]

389-389: Fix fullwidth parentheses in comment.

Replace （padding_M,） with ASCII (padding_M,) to satisfy Ruff RUF003 and avoid ambiguity.

-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)

399-431: Debug prints: gate behind a flag or logger.

Unconditional prints (sorted_token_ids, expert_ids) will spam stdout in benchmarks/tests.

-    print(f'{sorted_token_ids=}')
-    print(f'{expert_ids=}')
+    if bool(int(os.getenv("TL_DEBUG_MOE", "0"))):
+        print(f'{sorted_token_ids=}')
+        print(f'{expert_ids=}')

(Remember to import os.)


109-150: CI: trailing whitespace on blank lines (W293).

Ruff flagged blank lines with spaces around these regions. Strip trailing whitespace to unblock CI.

Run: ruff --fix examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py


464-464: CI: add newline at EOF (W292).

Add a trailing newline to satisfy linters.

-    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
\ No newline at end of file
+    main(M, N, K, scale_size, fast_dequant=True, with_bias=True, topk=topk, E=E)
+
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between f0d6669 and 7717d16.

📒 Files selected for processing (7)
  • examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (4 hunks)
  • examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1 hunks)
  • examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
  • examples/dequantize_gemm/test_example_dequantize_gemm.py (2 hunks)
  • examples/dequantize_gemm/utils.py (2 hunks)
  • tilelang/language/builtin.py (1 hunks)
  • tilelang/quantize/__init__.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (6)
examples/dequantize_gemm/example_dequant_gemm_fine_grained.py (1)
tilelang/quantize/quantization.py (1)
  • _tir_packed_to_unsigned_convert (258-266)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (9)
examples/dequantize_gemm/utils.py (1)
  • torch_convert_bit_twiddling (4-56)
tilelang/quantize/mxfp.py (1)
  • get_mxfp_intrin_group (51-107)
tilelang/language/__init__.py (2)
  • import_source (198-200)
  • annotate_layout (103-141)
tilelang/language/ast/ir.py (1)
  • func_name (206-214)
tilelang/language/allocate.py (3)
  • alloc_local (39-50)
  • alloc_fragment (53-64)
  • alloc_shared (21-36)
tilelang/language/tir/op.py (3)
  • shift_left (2768-2784)
  • call_extern (172-194)
  • address_of (463-479)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
tilelang/language/builtin.py (1)
tilelang/language/tir/op.py (1)
  • tvm_storage_sync (533-546)
examples/dequantize_gemm/test_example_dequantize_gemm.py (2)
tilelang/testing/__init__.py (1)
  • requires_cuda_compute_version_ge (95-96)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)
  • main (225-301)
  • main (398-456)
tilelang/quantize/__init__.py (3)
examples/dequantize_gemm/example_dequant_gemm_bf16_mxfp4_hopper.py (1)
  • _tir_u8_to_f4_to_bf16 (10-49)
examples/dequantize_gemm/example_dequant_gemm_bf16_fp4_hopper.py (1)
  • _tir_u8_to_f4_to_bf16 (218-259)
tilelang/quantize/quantization.py (1)
  • _tir_u8_to_f4_to_bf16 (30-73)
examples/dequantize_gemm/utils.py (1)
examples/gdn/utils.py (3)
  • print_red_warning (4-5)
  • calc_sim (8-15)
  • assert_similar (18-40)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

389-389: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


389-389: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

tilelang/quantize/__init__.py

8-8: Unused noqa directive (non-enabled: F401)

Remove unused noqa directive

(RUF100)

examples/dequantize_gemm/utils.py

131-131: Unused function argument: data

(ARG001)

🪛 GitHub Actions: CI Test on AMD
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

[error] 109-109: Ruff: W293 Blank line contains whitespace.


[error] 115-115: Ruff: W293 Blank line contains whitespace.


[error] 136-136: Ruff: W293 Blank line contains whitespace.


[error] 143-143: Ruff: W293 Blank line contains whitespace.


[error] 150-150: Ruff: W293 Blank line contains whitespace.


[error] 464-464: Ruff: W292 No newline at end of file.

tilelang/language/builtin.py

[error] 358-358: Ruff: W292 No newline at end of file.

examples/dequantize_gemm/utils.py

[error] 152-152: Ruff: W292 No newline at end of file.

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: format-check
🔇 Additional comments (2)
tilelang/language/builtin.py (1)

334-336: Docstring fix from “warp” to “block” is correct

Matches CUDA semantics of shared-memory barrier (aka __syncthreads).

tilelang/quantize/__init__.py (1)

1-9: Clean up exports and fix linter warning

  • _tir_packed_to_unsigned_convert_with_zeros is still defined at tilelang/quantize/quantization.py:269 and still exported from tilelang/quantize/init.py — update the PR description or remove the export/definition if removal was intended.
  • Remove the unnecessary "# noqa: F401" on _tir_u8_to_f4_to_bf16 in tilelang/quantize/init.py to address RUF100; verify there are no external consumers before deleting the export.

Comment on lines 283 to 286
for i, j in T.Parallel(block_M, block_K):
    if sorted_token_ids_shared[i] != -1:
        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
if fast_dequant:

⚠️ Potential issue

Uninitialized rows of A_shared for padded tokens.

Rows for -1 tokens aren’t written but still participate in GEMM. Zero them to avoid undefined reads.

-                for i, j in T.Parallel(block_M, block_K):
-                    if sorted_token_ids_shared[i] != -1:
-                        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
+                for i, j in T.Parallel(block_M, block_K):
+                    if sorted_token_ids_shared[i] != -1:
+                        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
+                    else:
+                        A_shared[i, j] = T.const(0, in_dtype)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for i, j in T.Parallel(block_M, block_K):
    if sorted_token_ids_shared[i] != -1:
        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
if fast_dequant:
for i, j in T.Parallel(block_M, block_K):
    if sorted_token_ids_shared[i] != -1:
        A_shared[i, j] = A[sorted_token_ids_shared[i] // topk, k * block_K + j]
    else:
        A_shared[i, j] = T.const(0, in_dtype)
if fast_dequant:

Comment on lines +53 to +56
    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
    bf16_new = bf16_bf16 * (2.0**126)

    return bf16_new

⚠️ Potential issue

Return dtype likely float32; violates docstring contract (bf16 expected)

Multiplying a bf16 tensor by a float literal promotes to float32 in PyTorch; the function then returns f32.

Apply:

-    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
-    bf16_new = bf16_bf16 * (2.0**126)
-
-    return bf16_new
+    # Keep result in bf16 as documented
+    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
+    return (bf16_bf16 * scale)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
    bf16_new = bf16_bf16 * (2.0**126)

    return bf16_new
    # Keep result in bf16 as documented
    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
    return (bf16_bf16 * scale)
🤖 Prompt for AI Agents
In examples/dequantize_gemm/utils.py around lines 53 to 56, the multiplication
by a float literal promotes the bf16 tensor to float32 so the function returns
f32 instead of bf16 as the docstring promises; after computing bf16_new keep the
numeric scaling but cast the result back to bfloat16 (e.g., use
.to(torch.bfloat16) or .bfloat16()) before returning so the return dtype matches
the docstring.

Rachmanino and others added 5 commits September 16, 2025 02:16
…dequant_groupedgemm_bf16_mxfp4_hopper.py

- Added a new function `get_configs()` to generate hyperparameter configurations for tuning.
- Updated the `matmul` function to utilize autotuning with the new configurations.
- Improve kernel performance via vectorization and threadblock swizzle.
- Enhanced the main function to support the new autotuning inputs and updated parameters for better performance.
@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 2

♻️ Duplicate comments (2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)

295-299: Initialize topk_weights_shared for padded rows.

Uninitialized values for padded rows are later used in a multiply. Set them to 0.

             for i in T.Parallel(block_M):
                 if sorted_token_ids_shared[i] != -1:
                     topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
+                else:
+                    topk_weights_shared[i] = T.const(0, out_dtype)

313-321: Zero A_shared for padded rows to avoid undefined reads in GEMM.

Rows for -1 tokens are skipped and left uninitialized but still participate in GEMM.

             for k in T.Pipelined(K // block_K, num_stages=num_stages):
                 for copy_i in T.serial(block_M * block_K // threads // 16):
                     base = copy_i * threads * 16 + tx * 16
                     if sorted_token_ids_shared[base // block_K] != -1:
                         for copy_j in T.vectorized(16):
-                            A_shared[base // block_K, base % block_K +
-                                     copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
-                                                 k * block_K + base % block_K + copy_j]
+                            A_shared[base // block_K, base % block_K + copy_j] = \
+                                A[sorted_token_ids_shared[base // block_K] // topk,
+                                  k * block_K + base % block_K + copy_j]
+                    else:
+                        for copy_j in T.vectorized(16):
+                            A_shared[base // block_K, base % block_K + copy_j] = T.const(0, in_dtype)
🧹 Nitpick comments (4)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (4)

170-173: Doc fix: scale factor description mismatches implementation.

The code applies 2**Scale; doc says 2^(Scale - 127). Update the doc to match behavior.

-            - Loads the corresponding per-block scale entry, interprets it as an exponent bias
-              (applies 2^(Scale - 127)), and multiplies the dequantized BF16 fragment by that factor.
+            - Loads the corresponding per-block scale entry, interprets it as a base-2 exponent
+              (applies 2**Scale), and multiplies the dequantized BF16 fragment by that factor.

91-96: Doc fix: default tile sizes mismatch the function signature.

Defaults here should read 128, 256, 128.

-        block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 256, 128, 128).
+        block_M, block_N, block_K (int, optional): tile sizes for M, N, and K dimensions (defaults 128, 256, 128).

287-287: Swizzle panel size likely should be warp size (32).

use_swizzle docs say “panel size is the number of threads in a warp”. Consider 32 instead of 10.

-            T.use_swizzle(10)
+            T.use_swizzle(32)

427-427: Ruff RUF003: replace fullwidth parentheses in comment.

-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between 7717d16 and be2f79a.

📒 Files selected for processing (3)
  • examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
  • examples/dequantize_gemm/test_example_dequantize_gemm.py (2 hunks)
  • tilelang/language/builtin.py (1 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
  • examples/dequantize_gemm/test_example_dequantize_gemm.py
  • tilelang/language/builtin.py
🧰 Additional context used
🧬 Code graph analysis (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (9)
examples/dequantize_gemm/utils.py (1)
  • torch_convert_bit_twiddling (4-56)
tilelang/autotuner/capture.py (1)
  • set_autotune_inputs (100-118)
tilelang/quantize/mxfp.py (1)
  • get_mxfp_intrin_group (51-107)
tilelang/language/__init__.py (3)
  • import_source (198-200)
  • annotate_layout (103-141)
  • use_swizzle (94-100)
tilelang/language/allocate.py (3)
  • alloc_local (39-50)
  • alloc_fragment (53-64)
  • alloc_shared (21-36)
tilelang/language/parallel.py (1)
  • Parallel (8-28)
tilelang/language/copy.py (1)
  • copy (84-152)
tilelang/language/fill.py (1)
  • clear (24-48)
tilelang/language/pipeline.py (1)
  • Pipelined (9-46)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

427-427: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


427-427: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)

228-251: Possible double scaling in simple_dequant path; verify parity with ref.

simple_dequant passes Scale to _tir_u8_to_f4_to_bf16 and multiplies again by 2**Scale. Confirm that the intrin does not already incorporate scaling; otherwise you’re scaling twice. Please run with fast_dequant=False and compare to ref_moe.
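One quick check (arguments mirror the example's main() signature as used in its __main__ block; the values are whatever that block already defines):

# If the intrinsic already folds in Scale, the simple path would be off by a
# factor of 2**Scale per group and the example's assert_similar would fire.
main(M, N, K, scale_size, fast_dequant=False, with_bias=False, topk=topk, E=E)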

Comment on lines +209 to +215
                T.call_extern(
                    func_name,
                    T.address_of(B_local_thread[0]),
                    T.address_of(B_dequantize_local_thread[0]),
                    1,
                    dtype=out_dtype,
                )

⚠️ Potential issue

Fix incorrect T.call_extern signature (will raise at build time).

call_extern expects dtype first, then func_name. Passing func_name positionally and dtype as a kw arg results in “multiple values for argument 'dtype'”.

Apply this diff:

-                T.call_extern(
-                    func_name,
-                    T.address_of(B_local_thread[0]),
-                    T.address_of(B_dequantize_local_thread[0]),
-                    1,
-                    dtype=out_dtype,
-                )
+                T.call_extern(
+                    "void",
+                    func_name,
+                    T.address_of(B_local_thread[0]),
+                    T.address_of(B_dequantize_local_thread[0]),
+                    1,
+                )
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
                T.call_extern(
                    func_name,
                    T.address_of(B_local_thread[0]),
                    T.address_of(B_dequantize_local_thread[0]),
                    1,
                    dtype=out_dtype,
                )
                T.call_extern(
                    "void",
                    func_name,
                    T.address_of(B_local_thread[0]),
                    T.address_of(B_dequantize_local_thread[0]),
                    1,
                )
🤖 Prompt for AI Agents
In examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py
around lines 209 to 215, T.call_extern is called with func_name as the first
positional arg and dtype passed as a keyword, which conflicts with the expected
signature (dtype first, then func_name) and causes a build-time “multiple values
for argument 'dtype'” error; fix by reordering the call to pass the dtype as the
first argument and the function name second (remove the dtype keyword argument),
keeping the remaining arguments (addresses and integer) unchanged.

@coderabbitai coderabbitai bot (Contributor) left a comment

Actionable comments posted: 0

♻️ Duplicate comments (4)
examples/dequantize_gemm/utils.py (1)

53-56: Return dtype promoted to float32; violates docstring (should return bfloat16).

Multiplying bf16 by a float literal upcasts to float32. Keep the scale as bf16 to preserve the documented return type.

Apply:

-    # Avoid integer overflow by using a float32 multiplier for the exponent scaling
-    bf16_new = bf16_bf16 * (2.0**126)
-
-    return bf16_new
+    # Keep result as bfloat16 as documented
+    scale = torch.tensor(2.0**126, dtype=torch.bfloat16, device=tensor.device)
+    return bf16_bf16 * scale
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (3)

295-300: Uninitialized topk_weights_shared for padded rows.

When sorted_token_ids_shared[i] == -1, the weight remains undefined but is later used in a multiply.

Apply:

             for i in T.Parallel(block_M):
                 if sorted_token_ids_shared[i] != -1:
                     topk_weights_shared[i] = topk_weights[sorted_token_ids_shared[i]]
+                else:
+                    topk_weights_shared[i] = T.const(0, out_dtype)

313-322: Uninitialized rows of A_shared for padded tokens.

Rows for -1 tokens are left untouched and still participate in GEMM.

Apply:

                 for copy_i in T.serial(block_M * block_K // threads // 16):
                     base = copy_i * threads * 16 + tx * 16
                     if sorted_token_ids_shared[base // block_K] != -1:
                         for copy_j in T.vectorized(16):
                             A_shared[base // block_K, base % block_K +
                                      copy_j] = A[sorted_token_ids_shared[base // block_K] // topk,
                                                  k * block_K + base % block_K + copy_j]
+                    else:
+                        for copy_j in T.vectorized(16):
+                            A_shared[base // block_K, base % block_K + copy_j] = T.const(0, in_dtype)

209-215: Incorrect T.call_extern signature; will raise at runtime.

API is call_extern(dtype, func_name, *args, ...). You're passing func_name first and dtype as a kwarg.

Apply:

-                T.call_extern(
-                    func_name,
-                    T.address_of(B_local_thread[0]),
-                    T.address_of(B_dequantize_local_thread[0]),
-                    1,
-                    dtype=out_dtype,
-                )
+                T.call_extern(
+                    "void",
+                    func_name,
+                    T.address_of(B_local_thread[0]),
+                    T.address_of(B_dequantize_local_thread[0]),
+                    1,
+                )
🧹 Nitpick comments (4)
examples/dequantize_gemm/utils.py (2)

131-153: data arg is unused in assert_similar; address Ruff ARG001 and improve pass logging.

Either remove the arg or use it on success.

Apply:

-    if not (0 <= diff <= eps):
+    if not (0 <= diff <= eps):
         print_red_warning(f'{name} Error: {diff=}')
         if raise_assert:
             raise AssertionError
+    else:
+        if data:
+            print(f"{name} {data} passed")

21-31: Memory-heavy repeat_interleave; consider reshape+stack to avoid 4× duplication.

Using repeat_interleave(4) creates a wide intermediate. You can compute res{0..3} on (N, K//2) and torch.stack(..., dim=-1) then view(N, -1); or gather with a (K//2, 4) position index to reduce peak memory.

examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (2)

134-134: Redundant self-assignment.

import_source = import_source is a no-op; remove to avoid confusion.

-    import_source = import_source

433-433: Ambiguous fullwidth parentheses in comment (RUF003).

Replace （padding_M,） with ASCII (padding_M,).

-    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # （padding_M,）
+    expert_ids = torch.tensor(expert_ids, dtype=torch.int32, device='cuda')  # (padding_M,)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between be2f79a and 1b85cb6.

📒 Files selected for processing (2)
  • examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1 hunks)
  • examples/dequantize_gemm/utils.py (2 hunks)
🧰 Additional context used
🧬 Code graph analysis (2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (4)
examples/dequantize_gemm/utils.py (2)
  • torch_convert_bit_twiddling (4-56)
  • assert_similar (131-152)
tilelang/autotuner/capture.py (1)
  • set_autotune_inputs (100-118)
tilelang/quantize/mxfp.py (1)
  • get_mxfp_intrin_group (51-107)
tilelang/language/__init__.py (2)
  • import_source (198-200)
  • use_swizzle (94-100)
examples/dequantize_gemm/utils.py (1)
examples/gdn/utils.py (3)
  • print_red_warning (4-5)
  • calc_sim (8-15)
  • assert_similar (18-40)
🪛 Ruff (0.12.2)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py

433-433: Comment contains ambiguous （ (FULLWIDTH LEFT PARENTHESIS). Did you mean ( (LEFT PARENTHESIS)?

(RUF003)


433-433: Comment contains ambiguous ） (FULLWIDTH RIGHT PARENTHESIS). Did you mean ) (RIGHT PARENTHESIS)?

(RUF003)

examples/dequantize_gemm/utils.py

131-131: Unused function argument: data

(ARG001)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
  • GitHub Check: bot-task
  • GitHub Check: format-check
  • GitHub Check: format-check
🔇 Additional comments (1)
examples/dequantize_gemm/example_dequant_groupedgemm_bf16_mxfp4_hopper.py (1)

271-271: Confirm dtype for B_dequantize_shared.

Dequant output buffer is allocated with in_dtype; typically this should be out_dtype. If in_dtype != out_dtype, types will mismatch GEMM expectations.

Potential change:

-            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, in_dtype)
+            B_dequantize_shared = T.alloc_shared(B_dequantize_shared_shape, out_dtype)

Please confirm intended types for A/B tiles and T.gemm.

@tzj-fxz tzj-fxz requested a review from LeiWang1999 September 17, 2025 02:15
@LeiWang1999 LeiWang1999 merged commit 8554cb0 into tile-ai:main Sep 17, 2025
6 checks passed