[Example] Update GQA varlen fwd and MHA varlen fwd #1071
Conversation
👋 Hi! Thank you for contributing to the TileLang project. We appreciate you taking this step! Our team will review your contribution, and we look forward to your awesome work! 🚀
Walkthrough

This PR introduces variable-length flash attention support by adding a new utility module for padding-aware QKV processing and two examples: a new grouped query attention kernel and an updated multi-head attention example. The changes refactor shared utilities and increase default configuration parameters.
Estimated code review effort

🎯 4 (Complex) | ⏱️ ~50 minutes

The PR introduces two new files with substantive logic (a new kernel implementation and padding utilities), alongside updates to an existing example. While individual pieces follow consistent patterns, the heterogeneous nature of the changes, spanning kernel-level implementations, utility functions, and configuration updates, requires varied reasoning per file.
Pre-merge checks and finishing touches

❌ Failed checks (1 warning) · ✅ Passed checks (2 passed)
Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.
Actionable comments posted: 3
🧹 Nitpick comments (6)
examples/flash_attention/varlen_utils.py (3)
4-4: Use package-relative import for sibling module

Absolute import will break when running from the repo root. Prefer a relative import.
```diff
-from bert_padding import pad_input, unpad_input
+from .bert_padding import pad_input, unpad_input
```

Follow-up: add empty `__init__.py` files under examples/ and examples/flash_attention/ to enable package imports.
7-18: Minor: make dtype/device explicit in mask generation

Avoid implicit dtype promotion on some backends; keep consistency with the lengths dtype.
```diff
-    padding_mask = (
-        repeat(torch.arange(max_seqlen, device=device), "s -> b s", b=batch_size) < lengths)
+    padding_mask = (
+        repeat(torch.arange(max_seqlen, device=device, dtype=lengths.dtype), "s -> b s",
+               b=batch_size) < lengths)
```
20-36: Return contract is opaque; consider a typed container

The long tuple is error-prone. Suggest a NamedTuple/dataclass to self-document fields (shapes, dtypes).
I can send a small patch introducing a VarLenQKV struct and updating the call sites.
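For illustration, a minimal sketch of such a container (the field names, shapes, and dtypes here are hypothetical, inferred from how the current tuple is consumed in the examples):

```python
from typing import Callable, NamedTuple

import torch


class VarLenQKV(NamedTuple):
    """Hypothetical typed return for generate_qkv; fields mirror the current tuple."""
    q_unpad: torch.Tensor        # (total_q, nheads, head_dim), padding rows removed
    k_unpad: torch.Tensor        # (total_k, nheads_k, head_dim)
    v_unpad: torch.Tensor        # (total_k, nheads_k, head_dim)
    cu_seqlens_q: torch.Tensor   # (batch + 1,), int32 cumulative sequence lengths
    cu_seqlens_k: torch.Tensor   # (batch + 1,), int32 cumulative sequence lengths
    max_seqlen_q: int
    max_seqlen_k: int
    output_pad_fn: Callable[[torch.Tensor], torch.Tensor]  # re-pads an unpadded output
```

Call sites could then read `result.cu_seqlens_q` instead of relying on positional unpacking, which also makes future additions non-breaking.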
examples/flash_attention/example_mha_fwd_varlen.py (1)
10-10: Make sibling imports robust

Use a package-relative import so the example works from the repo root via `python -m examples.flash_attention.example_mha_fwd_varlen`.

```diff
-from varlen_utils import generate_random_padding_mask, generate_qkv
+from .varlen_utils import generate_random_padding_mask, generate_qkv
```

Add `__init__.py` under examples/ and examples/flash_attention/.

examples/flash_attention/example_gqa_fwd_varlen.py (2)
9-9: Use package-relative import for sibling utils

Prevents import errors when running from the repo root.
```diff
-from varlen_utils import generate_random_padding_mask, generate_qkv
+from .varlen_utils import generate_random_padding_mask, generate_qkv
```

Add `__init__.py` under examples/ and examples/flash_attention/.
194-201: Duplicate seeding; keep one
`set_random_seed(0)` is called twice.

```diff
-    tilelang.testing.set_random_seed(0)
-
-    causal = False
+    causal = False
     if causal:
         total_flops *= 0.5
-
-    tilelang.testing.set_random_seed(0)
+    tilelang.testing.set_random_seed(0)
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (3)
- examples/flash_attention/example_gqa_fwd_varlen.py (1 hunks)
- examples/flash_attention/example_mha_fwd_varlen.py (3 hunks)
- examples/flash_attention/varlen_utils.py (1 hunks)
🧰 Additional context used
🧬 Code graph analysis (3)
examples/flash_attention/varlen_utils.py (1)
examples/flash_attention/bert_padding.py (2)
- pad_input (201-213)
- unpad_input (100-124)
examples/flash_attention/example_gqa_fwd_varlen.py (4)
examples/flash_attention/varlen_utils.py (2)
- generate_random_padding_mask (7-18)
- generate_qkv (21-122)

tilelang/env.py (1)
- disable_cache (267-268)

tilelang/transform/pass_config.py (1)
- PassConfigKey (6-104)

tilelang/testing/__init__.py (1)
- set_random_seed (30-35)
examples/flash_attention/example_mha_fwd_varlen.py (1)
examples/flash_attention/varlen_utils.py (2)
- generate_random_padding_mask (7-18)
- generate_qkv (21-122)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: Quick Lint
🔇 Additional comments (5)
examples/flash_attention/example_mha_fwd_varlen.py (3)
211-211: Bigger defaults may OOM on smaller GPUs

Batch=8, heads=64, seq=2048, dim=128 allocates >1 GB just for QKV (rough numbers below). Consider documenting the memory needs or lowering the CLI defaults while keeping the high-end values in the README.
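Back-of-the-envelope arithmetic supporting the estimate above (assuming fp16 and counting only the padded Q/K/V tensors; the unpadded copies and outputs produced by the example add to this):

```python
# Rough QKV memory estimate for the new defaults (fp16 assumed; padded tensors only).
batch, heads, seqlen, head_dim = 8, 64, 2048, 128
bytes_per_elem = 2  # fp16
per_tensor = batch * seqlen * heads * head_dim * bytes_per_elem  # 268,435,456 B ≈ 256 MiB
total_qkv = 3 * per_tensor                                       # Q + K + V ≈ 768 MiB
print(f"padded QKV ≈ {total_qkv / 2**20:.0f} MiB")  # unpadded copies and outputs push this past 1 GiB
```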
283-283: LGTM: clearer success message

User-friendly and unambiguous.

288-291: LGTM: argparse defaults match main() signature

Consistent with the increased workload.
examples/flash_attention/example_gqa_fwd_varlen.py (2)
14-45: Reference attention matches GQA head expansion; good

The logic for KV head replication and masking looks correct for varlen.
Please run both scripts via `python -m examples.flash_attention.example_gqa_fwd_varlen` after adding `__init__.py`, to ensure imports and assertions pass end-to-end.
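For context, the head-expansion step being checked typically looks like the following sketch (not the example's exact code; the tensor layout is assumed to be (batch, seqlen, heads, head_dim)):

```python
import torch


def expand_kv_heads(k: torch.Tensor, v: torch.Tensor, nheads_q: int):
    """Replicate KV heads so a GQA reference can reuse standard MHA attention."""
    groups = nheads_q // k.shape[2]  # nheads_q must be a multiple of nheads_k
    # repeat_interleave along the head axis: KV head i serves query heads [i*groups, (i+1)*groups)
    return k.repeat_interleave(groups, dim=2), v.repeat_interleave(groups, dim=2)
```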
11-11: No action required: the API is properly exported

Verification confirms that `tilelang.disable_cache()` is correctly accessible at the top level. The symbol is defined in `tilelang/env.py` (line 197) and explicitly re-exported in `tilelang/__init__.py` (line 74), making it available as a package-level API.

Likely an incorrect or invalid review comment.
```python
causal = False
if causal:
    total_flops *= 0.5
```
Use is_causal for FLOPs adjustment
Currently tied to local causal=False, so TFlops is wrong when --is_causal is set.
```diff
-causal = False
-if causal:
+causal = is_causal
+if is_causal:
     total_flops *= 0.5
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
causal = is_causal
if is_causal:
    total_flops *= 0.5
```
🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 196 to 199,
the FLOPs adjustment currently uses a hardcoded local variable `causal=False`;
replace that check with the parsed flag `is_causal` (or the existing CLI/config
variable used in this script) so that when `--is_causal` is set the code halves
total_flops correctly; simply remove the local `causal` variable and change the
if condition to `if is_causal:` (or equivalent) so total_flops *= 0.5 runs based
on the actual runtime flag.
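For context on the 0.5 factor (a common convention, not taken from this PR's code): forward-attention FLOPs are usually estimated as total_flops ≈ 4 · batch · heads · seqlen_q · seqlen_k · head_dim, i.e. two GEMMs at two FLOPs per multiply-add, and a causal mask computes only about half of the score matrix, hence halving the estimate when is_causal is set.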
```python
latency = do_bench(
    lambda: kernel(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k, max_seqlen_q))
print("Tile-lang: {:.2f} ms".format(latency))
print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
```
Make the ms → s conversion explicit in the TFLOPs calculation

do_bench reports latency in milliseconds; `total_flops / latency * 1e-9` is only correct because of that implicit ms scaling, so spell the conversion out to keep the units unambiguous.

```diff
-print("Tile-lang: {:.2f} ms".format(latency))
-print("Tile-lang: {:.2f} TFlops".format(total_flops / latency * 1e-9))
+print("Tile-lang: {:.2f} ms".format(latency))
+# TFLOPs = FLOPs / seconds / 1e12; latency is in ms
+tflops = total_flops / (latency * 1e-3) / 1e12
+print("Tile-lang: {:.2f} TFLOPs".format(tflops))
```

🤖 Prompt for AI Agents
In examples/flash_attention/example_gqa_fwd_varlen.py around lines 259 to 263,
the TFLOPs print uses latency in milliseconds with an implicit scaling factor;
make the units explicit by computing latency_s = latency / 1000.0 and tflops =
total_flops / latency_s / 1e12 (equivalently total_flops / latency * 1e-9 when
latency is in ms), and update the print to use that value so the units are
unambiguous.
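A quick sanity check with made-up numbers: 1e12 FLOPs finished in latency = 5 ms is 1e12 / (5e-3 s) / 1e12 = 200 TFLOPS, which matches total_flops / latency * 1e-9 = 1e12 / 5 * 1e-9 = 200 with latency expressed in ms.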
```python
if qkvpacked:
    assert (query_padding_mask == key_padding_mask).all()
    assert nheads == nheads_k
    qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
    qkv = torch.stack([q, k, v], dim=2)
    if query_padding_mask is not None:
        dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
    else:
```
Fix qkvpacked mask assertion to handle None and use torch.equal
Current assertion fails when either mask is None (bool has no .all()). Also be explicit about equal seqlens when stacking q/k/v.
Apply:
```diff
-if qkvpacked:
-    assert (query_padding_mask == key_padding_mask).all()
-    assert nheads == nheads_k
+if qkvpacked:
+    assert query_padding_mask is not None and key_padding_mask is not None, \
+        "qkvpacked requires both query and key padding masks."
+    assert torch.equal(query_padding_mask, key_padding_mask), \
+        "qkvpacked requires identical query/key masks."
+    assert nheads == nheads_k, "qkvpacked requires same #heads for q/k/v."
+    assert seqlen_q == seqlen_k, "qkvpacked requires equal sequence lengths."
```

📝 Committable suggestion
‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.
```python
if qkvpacked:
    assert query_padding_mask is not None and key_padding_mask is not None, \
        "qkvpacked requires both query and key padding masks."
    assert torch.equal(query_padding_mask, key_padding_mask), \
        "qkvpacked requires identical query/key masks."
    assert nheads == nheads_k, "qkvpacked requires same #heads for q/k/v."
    assert seqlen_q == seqlen_k, "qkvpacked requires equal sequence lengths."
    qkv_unpad = torch.stack([q_unpad, k_unpad, v_unpad], dim=1)
    qkv = torch.stack([q, k, v], dim=2)
    if query_padding_mask is not None:
        dqkv_pad_fn = lambda dqkv_unpad: pad_input(dqkv_unpad, indices_q, batch_size, seqlen_q)
    else:
```
This pull request refactors the variable-length attention example scripts to improve modularity and maintainability. The main changes are the extraction of shared utility functions into a new `varlen_utils.py` file and updating the example scripts to use these utilities. Additionally, the default configuration for the examples has been updated to use larger, more realistic model sizes.

Utility extraction and code reuse:

- Added `examples/flash_attention/varlen_utils.py` containing the `generate_random_padding_mask` and `generate_qkv` functions, which handle input padding and mask generation for variable-length attention. This reduces code duplication and centralizes shared logic (a minimal usage sketch follows below).
- Updated `examples/flash_attention/example_mha_fwd_varlen.py` to import `generate_random_padding_mask` and `generate_qkv` from `varlen_utils.py`, removing their previous inline implementations.
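As a rough usage sketch (the exact signature and return-tuple order of `generate_qkv` are assumed here to mirror the flash-attn test utilities this module is modeled on, so treat the unpacking as illustrative rather than authoritative):

```python
import torch

from varlen_utils import generate_random_padding_mask, generate_qkv

batch, seqlen, heads, head_dim = 2, 256, 8, 64
device = "cuda"
dtype = torch.float16
q = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
k = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)
v = torch.randn(batch, seqlen, heads, head_dim, device=device, dtype=dtype)

# Draw random per-sequence lengths, then strip padding into the packed (varlen) layout.
query_padding_mask = generate_random_padding_mask(seqlen, batch, device)
key_padding_mask = generate_random_padding_mask(seqlen, batch, device)
(q_unpad, k_unpad, v_unpad, cu_seqlens_q, cu_seqlens_k,
 max_seqlen_q, max_seqlen_k, *rest) = generate_qkv(q, k, v, query_padding_mask,
                                                   key_padding_mask)
# q_unpad/k_unpad/v_unpad plus cu_seqlens_* and max_seqlen_q are what the varlen kernels consume.
```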
Configuration and usability improvements:

- Increased the defaults in the `main` function and CLI arguments of `example_mha_fwd_varlen.py` to better match typical large-scale model settings. [1] [2]
- Updated the success message in `example_mha_fwd_varlen.py` to clearly indicate when all checks have passed.
New example added:

- Added `examples/flash_attention/example_gqa_fwd_varlen.py`, demonstrating variable-length forward attention with grouped query attention (GQA), using the shared utilities for input preparation and padding.