Conversation

@LoserCheems
Collaborator

Summary

  • This fix addresses inconsistencies in mask creation and handling across the different backends, improving performance and reducing memory usage.

Root Cause

  • The previous implementation had duplicated and ad-hoc mask logic, leading to inefficiencies and potential errors in gradient reporting.

Changes

  • Centralized mask creation through a shared utility, standardizing top-k/causal masking across the Python, CUDA, Triton, and Flex paths, and improved handling of attention bias and mask shapes (a minimal sketch of the idea follows below).
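
A rough sketch of what the shared utility does; the names and signature here are hypothetical, not the actual API, and only illustrate the combined causal + top-k behaviour built once and reused across backends:

import torch

def build_dynamic_mask(attn_bias, batch, num_heads, query_len, key_len,
                       keep_window_size, dtype, is_causal=True):
    # additive mask: 0 where attention is allowed, dtype-min where it is not
    min_val = torch.finfo(dtype).min
    mask = torch.zeros(batch, num_heads, query_len, key_len,
                       dtype=dtype, device=attn_bias.device)
    if is_causal:
        # align the causal diagonal for cached decoding (query_len <= key_len)
        causal = torch.ones(query_len, key_len, dtype=torch.bool,
                            device=attn_bias.device).tril(key_len - query_len)
        mask = mask.masked_fill(~causal, min_val)
    if key_len > keep_window_size:
        # keep only the top-k keys per query, ranked by the attention bias
        topk = (attn_bias + mask).topk(keep_window_size, dim=-1).indices
        keep = torch.zeros_like(mask, dtype=torch.bool).scatter(-1, topk, True)
        mask = mask.masked_fill(~keep, min_val)
    return mask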

Reproduction

  • The issue can be reproduced by running the equivalence and performance benchmarks with varying input shapes and observing discrepancies in memory usage and performance (see the example below).
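
For example (assumed invocation; the scripts referenced in this PR live under benchmarks/), running python benchmarks/forward_equivalence.py and python benchmarks/backward_equivalence.py exercises the forward and backward paths across the supported backends.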

Tests

  • Updated benchmarks to validate the new masking logic and ensure correct behavior across all supported backends.

Compatibility

  • Changes maintain backward compatibility, but users should verify their implementations against the new masking utility.

Checklist

  • Linked issue provided
  • Adds or updates tests
  • Updates docs if needed
  • No perf regressions

Replaces the ad‑hoc mask builder with a shared utility to centralize top‑k/causal masking across CUDA, Triton, and Flex paths. Passes explicit tensor shapes and dtype min to ensure correctness, and adds block‑sparse support (block size 64) for dynamic CUDA.

Handles optional masks safely when repeating KV for GQA. Updates benchmark bias to broadcast over queries (B,H,1,K) to reduce memory and match masking expectations.
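
A minimal sketch of the optional-mask handling, reusing the repeat_kv helper already used in the benchmarks; the None-guard is the point and the surrounding names are illustrative:

def maybe_repeat_kv(tensor, n_rep):
    # leave an absent mask/bias alone instead of failing on None
    return tensor if tensor is None else repeat_kv(tensor, n_rep)

key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = maybe_repeat_kv(attn_mask, num_queries_per_kv)   # may be None
attn_bias = maybe_repeat_kv(attn_bias, num_queries_per_kv)   # (B, H_kv, 1, K) -> (B, H, 1, K)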

Improves consistency, reduces duplication, and prepares for extensible masking strategies.

Removes an unused symbol to clean up imports and silence linter warnings.

Reduces clutter and avoids confusion with unreferenced utilities. Introduces no functional changes.

Replaces the local mask builder with a centralized utility to standardize top-k/causal masking across Python, CUDA, Triton, and Flex paths.

Passes explicit batch/query/key sizes and dtype min, repeats masks only when present, and skips masked_fill when unneeded.

Reduces duplication, improves consistency and maintainability, and streamlines GQA handling.
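
Roughly, the conditional masking step looks like this (variable names follow the benchmark style and are illustrative, not the exact utility API):

import torch

min_dtype = torch.finfo(attn_weights.dtype).min
if attn_mask is not None:
    # a boolean mask is assumed here, with True marking positions to drop;
    # when no mask is supplied the fill is skipped entirely
    attn_weights = attn_weights.masked_fill(attn_mask, min_dtype)
attn_weights = torch.softmax(attn_weights, dim=-1)
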
Replaces ad-hoc mask construction with a shared mask utility to unify top‑k/causal masking across Python, CUDA, Triton, and Flex paths. Reduces duplication and allows safely skipping masking when not required.

Fixes gradient reporting in the Flex path by returning grads w.r.t. the original input tensors.

Also clarifies shape handling and guards masked fill, improving robustness.
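
A rough sketch of the gradient-reporting idea, with benchmark-style names and an illustrative wrapper function: differentiate with respect to the original leaf tensors rather than the repeated per-head copies handed to the kernel, so reported gradients keep the original (batch, num_kv_heads, seq, head_dim) shapes.

key_states.grad = None
value_states.grad = None
out = flex_attention_path(                       # hypothetical wrapper around the Flex kernel
    query_states,
    repeat_kv(key_states, num_queries_per_kv),   # repeat_kv is differentiable,
    repeat_kv(value_states, num_queries_per_kv), # so grads flow back to the leaves
)
out.sum().backward()
dk, dv = key_states.grad, value_states.grad      # grads w.r.t. the original inputs
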
Replaces duplicated mask logic with a shared utility to standardize top‑k/causal masking and dtype handling across CUDA, Triton, and Flex backends.

Aligns attention bias to a broadcastable per‑query shape to cut memory and simplify kernel expectations.

Removes redundant KV/mask/bias repetition in the Triton path, repeats conditionally for Flex, and makes GQA fan‑out explicit for correctness and performance.

Updates forward/backward equivalence benchmarks to create attention bias with a singleton query dimension so it broadcasts across queries.

Aligns shapes with kernel expectations during cached decoding, reduces memory footprint, and prevents shape mismatches across CUDA, Triton, and Flex paths.
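
Roughly, the benchmark-side change looks like this (names are illustrative, and the bias is assumed to start at num_kv_heads before any GQA fan-out via repeat_kv):

import torch

# bias allocated once per (batch, kv_head, key) and broadcast over the query axis
attn_bias = torch.randn(
    batch_size, num_kv_heads, 1, key_len,
    device=device, dtype=dtype, requires_grad=True,
)
# previously a full (batch, num_kv_heads, query_len, key_len) tensor was allocated
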
Copilot AI review requested due to automatic review settings November 13, 2025 04:38
Copilot finished reviewing on behalf of LoserCheems November 13, 2025 04:43
Contributor

Copilot AI left a comment

Copilot reviewed 4 out of 4 changed files in this pull request and generated 5 comments.

Comments suppressed due to low confidence (1)

benchmarks/backward_equivalence.py:268

  • Missing repeat_kv calls for Triton implementation. The Triton function expects key_states, value_states, attn_mask, and attn_bias to have num_heads dimension (not num_kv_heads). The following code should be added before the transpose operations:
# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias_leaf, num_queries_per_kv)

This is consistent with the forward implementation in forward_equivalence.py lines 242-245.

    # Ensure correct data types and memory layout for Triton function
    query_states = query_states.transpose(1, 2)         # [batch, query_len, num_heads, head_dim]  
    key_states = key_states.transpose(1, 2)             # [batch, key_len, num_heads, head_dim]  
    value_states = value_states.transpose(1, 2)         # [batch, key_len, num_heads, head_dim]  

Comment on lines 280 to 283
# Ensure correct data types and memory layout for Triton function
query_states = query_states.transpose(1, 2).contiguous() # [batch, query_len, num_heads, head_dim]
key_states = key_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim]
value_states = value_states.transpose(1, 2).contiguous() # [batch, key_len, num_heads, head_dim]

Copilot AI Nov 13, 2025

Missing repeat_kv calls for Triton implementation. The Triton function expects key_states, value_states, attn_mask, and attn_bias to have num_heads dimension (not num_kv_heads). The following code should be added before the transpose operations:

# Repeat KV for multi-head attention (GQA support)
key_states = repeat_kv(key_states, num_queries_per_kv)
value_states = repeat_kv(value_states, num_queries_per_kv)
attn_mask = repeat_kv(attn_mask, num_queries_per_kv)
attn_bias = repeat_kv(attn_bias, num_queries_per_kv)

This is consistent with the forward implementation in forward_performance.py lines 273-276 and forward_equivalence.py lines 242-245.

batch_size, num_heads, query_len, _ = query_states.shape
_, num_kv_heads, key_len, _ = key_states.shape

num_queries_per_kv = num_heads // num_kv_heads

Copilot AI Nov 13, 2025

Variable num_queries_per_kv is not used.

Suggested change
num_queries_per_kv = num_heads // num_kv_heads

Comment on lines +187 to +188
num_queries_per_kv = num_heads // num_kv_heads

Copilot AI Nov 13, 2025

Variable num_queries_per_kv is not used.

Suggested change
num_queries_per_kv = num_heads // num_kv_heads

Comment on lines +161 to +162
num_queries_per_kv = num_heads // num_kv_heads

Copilot AI Nov 13, 2025

Variable num_queries_per_kv is not used.

Suggested change
num_queries_per_kv = num_heads // num_kv_heads

Comment on lines +183 to +184
num_queries_per_kv = num_heads // num_kv_heads

Copilot AI Nov 13, 2025

Variable num_queries_per_kv is not used.

Suggested change
num_queries_per_kv = num_heads // num_kv_heads

@LoserCheems LoserCheems merged commit ed981b0 into main Nov 13, 2025
7 checks passed