
Conversation

@chengyupku (Contributor) commented Oct 20, 2025

This pull request refactors the variable-length attention example scripts to improve modularity and maintainability. The main changes are the extraction of shared utility functions into a new varlen_utils.py file and updating the example scripts to use these utilities. Additionally, the default configuration for the examples has been updated to use larger, more realistic model sizes.

Utility extraction and code reuse:

  • Created a new utility module examples/flash_attention/varlen_utils.py containing the generate_random_padding_mask and generate_qkv functions, which handle input padding and mask generation for variable-length attention. This reduces code duplication and centralizes shared logic; a usage sketch follows this list.
  • Updated examples/flash_attention/example_mha_fwd_varlen.py to import generate_random_padding_mask and generate_qkv from varlen_utils.py, removing their previous inline implementations.
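
For orientation, here is a minimal usage sketch of the shared utilities, assuming their signatures follow the flash-attn test helpers they are derived from; the argument order, the mode keyword, and the exact return-tuple layout are assumptions rather than facts taken from this diff.

import torch
from varlen_utils import generate_random_padding_mask, generate_qkv

batch, heads, seqlen, dim = 8, 64, 2048, 128
dtype, device = torch.float16, "cuda"
q = torch.randn(batch, seqlen, heads, dim, dtype=dtype, device=device)
k = torch.randn(batch, seqlen, heads, dim, dtype=dtype, device=device)
v = torch.randn(batch, seqlen, heads, dim, dtype=dtype, device=device)

# Boolean [batch, seqlen] masks with a random number of valid tokens per sequence.
query_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")
key_padding_mask = generate_random_padding_mask(seqlen, batch, device, mode="random")

# Unpad into packed [total_tokens, heads, dim] tensors plus cu_seqlens metadata,
# and obtain a function that re-pads the kernel output to [batch, seqlen, heads, dim].
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
 max_seqlen_q, max_seqlen_k, q, k, v, output_pad_fn, *_) = generate_qkv(
     q, k, v, query_padding_mask, key_padding_mask)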

Configuration and usability improvements:

  • Increased the default values for batch size, number of heads, sequence length, and dimension in the main function and CLI arguments of example_mha_fwd_varlen.py to better match typical large-scale model settings (see the sketch after this list).
  • Improved output messaging in example_mha_fwd_varlen.py to clearly indicate when all checks have passed.
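
For reference, the new defaults correspond to CLI arguments roughly like the following; the flag names are assumed for illustration (only the values 8, 64, 2048, and 128, plus the --is_causal flag mentioned later in the review, come from this PR).

import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--batch", type=int, default=8, help="batch size")
parser.add_argument("--heads", type=int, default=64, help="number of attention heads")
parser.add_argument("--seq_len", type=int, default=2048, help="maximum sequence length")
parser.add_argument("--dim", type=int, default=128, help="head dimension")
parser.add_argument("--is_causal", action="store_true", help="enable causal masking")
args = parser.parse_args()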

New example added:

  • Added a new example script examples/flash_attention/example_gqa_fwd_varlen.py demonstrating variable-length forward attention with grouped query attention (GQA), using the shared utilities for input preparation and padding.

Summary by CodeRabbit

  • New Features
    • Added example implementations for variable-length flash attention kernels with padding mask support
    • Introduced utility functions for processing padded and unpadded query-key-value tensors
    • Updated example configurations with larger default parameters for improved demonstrations

@github-actions

👋 Hi! Thank you for contributing to the TileLang project.

Please remember to run pre-commit run --all-files in the root directory of the project to ensure your changes are properly linted and formatted. This will help ensure your contribution passes the format check.

We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀

@coderabbitai coderabbitai bot (Contributor) commented Oct 20, 2025

Walkthrough

This PR introduces variable-length flash attention support by adding a new utility module for padding-aware QKV processing and two new example implementations: a grouped query attention kernel and an updated multi-head attention example. The changes refactor shared utilities and increase default configuration parameters.

Changes

  • New flash attention examples (examples/flash_attention/example_gqa_fwd_varlen.py): Adds example script implementing grouped query attention with var-length flash attention kernel. Includes attention_ref reference implementation, flashattn Tile-lang JIT kernel with multi-stage tiled GEMMs and causal masking support, and benchmarking logic with argparse configuration.
  • Variable-length utilities (examples/flash_attention/varlen_utils.py): Introduces padding-aware QKV processing utilities: generate_random_padding_mask for creating padding masks with multiple modes, and generate_qkv for handling packed/unpacked QKV variants with unpadding and re-padding functions.
  • Updated MHA example (examples/flash_attention/example_mha_fwd_varlen.py): Refactors to import utilities from varlen_utils instead of local definitions. Increases default configuration (batch: 2→8, heads: 16→64, seq_len: 256→2048, dim: 32→128). Updates success message.

Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

The PR introduces two new files with substantive logic (new kernel implementation and padding utilities), alongside updates to an existing example. While individual pieces follow consistent patterns, the heterogeneous nature of changes—spanning kernel-level implementations, utility functions, and configuration updates—requires varied reasoning per file.

Suggested reviewers

  • LeiWang1999

Poem

🐰 Hops of attention align so fine,
With padding masks in perfect line,
Variable lengths now dance and play,
Through Tile-lang kernels, night and day,
Grouped queries bloom—efficiency's delight! ✨

Pre-merge checks

❌ Failed checks (1 warning)
  • Docstring Coverage: ⚠️ Warning. Docstring coverage is 22.22%, below the required threshold of 80.00%. Resolution: run @coderabbitai generate docstrings to improve coverage.
✅ Passed checks (2 passed)
  • Description Check: ✅ Passed. Check skipped; CodeRabbit’s high-level summary is enabled.
  • Title Check: ✅ Passed. The title "[Example] Update GQA varlen fwd and MHA varlen fwd" directly references the main example files modified and added in this pull request; it is specific and avoids vague language. However, it does not capture the refactoring aspect, namely the extraction of shared utilities into the new varlen_utils.py module, nor convey that example_gqa_fwd_varlen.py is a new addition, so it only partially represents the changeset.


@coderabbitai coderabbitai bot left a comment

Actionable comments posted: 3

🧹 Nitpick comments (6)
examples/flash_attention/varlen_utils.py (3)

4-4: Use package‑relative import for sibling module

Absolute import will break when running from repo root. Prefer relative import.

-from bert_padding import pad_input, unpad_input
+from .bert_padding import pad_input, unpad_input

Follow-up: add empty __init__.py files under examples/ and examples/flash_attention/ to enable package imports.


7-18: Minor: make dtype/device explicit in mask generation

Avoid implicit dtype promotion on some backends; keep consistency with lengths dtype.

-    padding_mask = (
-        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device, dtype=lengths.dtype), "s -> b s",
+               b=batch_size) < lengths)

20-36: Return contract is opaque; consider a typed container

Long tuple is error‑prone. Suggest a NamedTuple/dataclass to self‑document fields (shapes, dtypes).

I can send a small patch introducing VarLenQKV struct and update call sites.
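
A minimal sketch of what such a container might look like; the field names mirror the values used at the call sites in these examples, and the exact set of fields is an assumption.

from typing import Callable, NamedTuple

import torch


class VarLenQKV(NamedTuple):
    # Packed (unpadded) tensors of shape [total_tokens, heads, head_dim].
    q_unpad: torch.Tensor
    k_unpad: torch.Tensor
    v_unpad: torch.Tensor
    # Cumulative sequence-length offsets, shape [batch + 1], typically int32.
    cu_seqlens_q: torch.Tensor
    cu_seqlens_k: torch.Tensor
    max_seqlen_q: int
    max_seqlen_k: int
    # Re-pads a packed output back to [batch, seqlen, heads, head_dim].
    output_pad_fn: Callable[[torch.Tensor], torch.Tensor]

Call sites could then write qkv.cu_seqlens_q instead of unpacking a long positional tuple.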

examples/flash_attention/example_mha_fwd_varlen.py (1)

10-10: Make sibling imports robust

Use package‑relative import to work from repo root via python -m examples.flash_attention.example_mha_fwd_varlen.

-from varlen_utils import generate_random_padding_mask, generate_qkv
+from .varlen_utils import generate_random_padding_mask, generate_qkv

Add __init__.py under examples/ and examples/flash_attention/.

examples/flash_attention/example_gqa_fwd_varlen.py (2)

9-9: Use package‑relative import for sibling utils

Prevents import errors from repo root.

-from varlen_utils import generate_random_padding_mask, generate_qkv
+from .varlen_utils import generate_random_padding_mask, generate_qkv

Add __init__.py under examples/ and examples/flash_attention/.


194-201: Duplicate seeding; keep one

set_random_seed(0) is called twice.

-    tilelang.testing.set_random_seed(0)
-
-    causal = False
+    causal = False
     if causal:
         total_flops *= 0.5
-
-    tilelang.testing.set_random_seed(0)
+    tilelang.testing.set_random_seed(0)
📜 Review details

Configuration used: CodeRabbit UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b2acfc3 and 4874417.

📒 Files selected for processing (3)
  • examples/flash_attention/example_gqa_fwd_varlen.py (1 hunks)
  • examples/flash_attention/example_mha_fwd_varlen.py (3 hunks)
  • examples/flash_attention/varlen_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/flash_attention/varlen_utils.py (1)
examples/flash_attention/bert_padding.py (2)
  • pad_input (201-213)
  • unpad_input (100-124)
examples/flash_attention/example_gqa_fwd_varlen.py (4)
examples/flash_attention/varlen_utils.py (2)
  • generate_random_padding_mask (7-18)
  • generate_qkv (21-122)
tilelang/env.py (1)
  • disable_cache (267-268)
tilelang/transform/pass_config.py (1)
  • PassConfigKey (6-104)
tilelang/testing/__init__.py (1)
  • set_random_seed (30-35)
examples/flash_attention/example_mha_fwd_varlen.py (1)
examples/flash_attention/varlen_utils.py (2)
  • generate_random_padding_mask (7-18)
  • generate_qkv (21-122)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
  • GitHub Check: Quick Lint
🔇 Additional comments (5)
examples/flash_attention/example_mha_fwd_varlen.py (3)

211-211: Bigger defaults may OOM on smaller GPUs

Batch=8, heads=64, seq=2048, dim=128 allocates >1GB just for QKV. Consider documenting memory needs or lowering CLI defaults while keeping high‑end values in README.
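
A quick back-of-the-envelope check of that figure, assuming fp16 inputs and the new defaults:

# Memory estimate for the padded Q/K/V tensors alone (fp16, new defaults).
batch, heads, seqlen, dim = 8, 64, 2048, 128
bytes_per_elem = 2  # fp16
per_tensor = batch * heads * seqlen * dim * bytes_per_elem
print(per_tensor / 2**20)      # 256.0 MiB per tensor
print(3 * per_tensor / 2**20)  # 768.0 MiB for padded Q, K and V
# The unpadded copies, the reference output and the kernel output push the
# working set past 1 GiB, so smaller GPUs may run out of memory.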


283-283: LGTM: clearer success message

User‑friendly and unambiguous.


288-291: LGTM: argparse defaults match main() signature

Consistent with increased workload.

examples/flash_attention/example_gqa_fwd_varlen.py (2)

14-45: Reference attention matches GQA head expansion; good

Logic for kv head replication and masking looks correct for varlen.

Please run both scripts via python -m examples.flash_attention.example_gqa_fwd_varlen after adding __init__.py to ensure imports and assertions pass end‑to‑end.
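
For readers unfamiliar with the pattern, the kv-head replication in such a reference implementation typically looks like the sketch below. This is a generic GQA illustration, not a copy of the file's code; the function name and shapes are assumptions.

import torch
from einops import repeat


def expand_kv_heads(kv: torch.Tensor, num_q_heads: int) -> torch.Tensor:
    # kv has shape [batch, seqlen, num_kv_heads, head_dim]; each kv head is
    # replicated so every group of query heads attends to its own copy.
    group_size = num_q_heads // kv.shape[2]
    return repeat(kv, "b s h d -> b s (h g) d", g=group_size)


# Example: 64 query heads sharing 8 kv heads (group size 8).
k = torch.randn(2, 128, 8, 64)
assert expand_kv_heads(k, num_q_heads=64).shape == (2, 128, 64, 64)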


11-11: No action required—API is properly exported

The verification confirms that tilelang.disable_cache() is correctly accessible at the top level. The symbol is defined in tilelang/env.py (line 197) and explicitly re-exported in tilelang/__init__.py (line 74), making it available as a package-level API.

Likely an incorrect or invalid review comment.

Comment on lines +196 to +199
    causal = False
    if causal:
        total_flops *= 0.5


⚠️ Potential issue | 🟡 Minor

Use is_causal for FLOPs adjustment

Currently tied to local causal=False, so TFlops is wrong when --is_causal is set.

-    causal = False
-    if causal:
+    causal = is_causal
+    if is_causal:
         total_flops *= 0.5
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    causal = False
-    if causal:
+    causal = is_causal
+    if is_causal:
         total_flops *= 0.5
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 196 to 199,
the FLOPs adjustment currently uses a hardcoded local variable `causal=False`;
replace that check with the parsed flag `is_causal` (or the existing CLI/config
variable used in this script) so that when `--is_causal` is set the code halves
total_flops correctly; simply remove the local `causal` variable and change the
if condition to `if is_causal:` (or equivalent) so total_flops *= 0.5 runs based
on the actual runtime flag.

Comment on lines +259 to +263
    latency = do_bench(
        lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
    print("Tile-lang: {:.2f} ms".format(latency))
    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))


⚠️ Potential issue | 🟠 Major

Fix TFLOPs calculation (ms → s conversion missing; off by 1e6)

do_bench prints ms; convert to seconds and then to TFLOPs.

-    print("Tile-lang: {:.2f} ms".format(latency))
-    print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+    print("Tile-lang: {:.2f} ms".format(latency))
+    # TFLOPs = FLOPs / (seconds) / 1e12; latency is ms
+    tflops = total_flops / (latency * 1e-3) / 1e12
+    print("Tile-lang: {:.2f} TFLOPs".format(tflops))
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 259 to 263,
the TFLOPs print uses latency in milliseconds without converting to seconds and
uses the wrong scaling factor; convert latency to seconds (e.g., latency_s =
latency / 1000.0) and compute TFLOPs as total_flops / latency_s * 1e-12, and update the print to use
that value so the units are correct.

Comment on lines +62 to +69
    if qkvpacked:
        assert (query_padding_mask == key_padding_mask).all()
        assert nheads == nheads_k
        qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
        qkv = torch.stack([q, k, v], dim=2)
        if query_padding_mask is not None:
            dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
        else:

⚠️ Potential issue | 🔴 Critical

Fix qkvpacked mask assertion to handle None and use torch.equal

Current assertion fails when either mask is None (bool has no .all()). Also be explicit about equal seqlens when stacking q/k/v.

Apply:

-    if qkvpacked:
-        assert (query_padding_mask == key_padding_mask).all()
-        assert nheads == nheads_k
+    if qkvpacked:
+        assert query_padding_mask is not None and key_padding_mask is not None, \
+            "qkvpacked requires both query and key padding masks."
+        assert torch.equal(query_padding_mask, key_padding_mask), \
+            "qkvpacked requires identical query/key masks."
+        assert nheads == nheads_k, "qkvpacked requires same #heads for q/k/v."
+        assert seqlen_q == seqlen_k, "qkvpacked requires equal sequence lengths."
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
-    if qkvpacked:
-        assert (query_padding_mask == key_padding_mask).all()
-        assert nheads == nheads_k
+    if qkvpacked:
+        assert query_padding_mask is not None and key_padding_mask is not None, \
+            "qkvpacked requires both query and key padding masks."
+        assert torch.equal(query_padding_mask, key_padding_mask), \
+            "qkvpacked requires identical query/key masks."
+        assert nheads == nheads_k, "qkvpacked requires same #heads for q/k/v."
+        assert seqlen_q == seqlen_k, "qkvpacked requires equal sequence lengths."
         qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
         qkv = torch.stack([q, k, v], dim=2)
         if query_padding_mask is not None:
             dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
         else:
