Refactor attention mask and bias handling for efficiency #177
Conversation
…ward to simplify logic and ensure proper initialization
Removes unnecessary tensor expansion operations in the attention bias calculation to improve memory efficiency and computational performance. Changes the weight initialization from zeros to a normal distribution for better training dynamics and gradient flow.
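For intuition about the memory claim, here is a minimal, self-contained sketch (the shapes and the compact-bias layout are assumptions for illustration, not taken from the model code) comparing the expanded bias the old code built against the compact per-key bias the refactor keeps:

```python
import torch

# Hypothetical dimensions, chosen only to make the size difference visible.
batch_size, num_heads, query_len, key_len = 2, 8, 1024, 1024

# Compact per-key bias, roughly the shape kept after the refactor: [batch, num_heads, key_len].
attn_bias = torch.randn(batch_size, num_heads, key_len)

# The old code expanded this to [batch, num_heads, query_len, key_len] before handing it
# to the attention computation; expand() itself is a view, but any later op that
# materializes it (e.g. a dtype cast) pays for the full tensor.
expanded = attn_bias[:, :, None, :].expand(-1, -1, query_len, -1)

print(attn_bias.numel(), expanded.numel())  # 16,384 vs 16,777,216 elements
```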
Pull Request Overview
This PR refactors attention mask and bias handling to improve efficiency and simplify the logic. The changes focus on streamlining conditional checks, reducing redundant operations, and updating weight initialization strategies.
- Simplified conditional logic for attention mask creation in the flash attention utility
- Removed redundant tensor expansion operations in the attention mechanism
- Updated weight initialization from zeros to a normal distribution
Reviewed Changes
Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py | Simplified conditional logic and added explicit None assignment for attention_mask |
| examples/modeling/modeling_doge.py | Streamlined attention bias computation and updated weight initialization method |
```diff
- attn_bias = dt_states[:, :, None, :].expand(
-     -1, -1, hidden_states.shape[1], -1
- ).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+ attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
```
Copilot AI · Sep 21, 2025
The removal of the tensor expansion operation may cause shape mismatch issues. The original code expanded attn_bias to match the expected dimensions [batch_size, num_heads, query_len, key_len], but the new code leaves it with only the shape produced by the transpose operation. This could lead to broadcasting errors in subsequent attention computations.
Suggested change:

```diff
  attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
+ # Ensure attn_bias has shape [batch_size, num_heads, query_len, key_len]
+ # query_states: [batch_size, num_heads, query_len, head_dim]
+ # key_states: [batch_size, num_heads, key_len, head_dim]
+ # attn_bias: [batch_size, query_len, key_len] or similar
+ if attn_bias.dim() == 3:
+     attn_bias = attn_bias.unsqueeze(1)  # [batch_size, 1, query_len, key_len]
+     attn_bias = attn_bias.expand(-1, query_states.shape[1], -1, -1)  # [batch_size, num_heads, query_len, key_len]
```
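To make the broadcasting concern concrete, here is a small standalone check (shapes are assumptions; this is not the repository's code) showing that a compact [batch_size, num_heads, key_len] bias does not add directly onto 4D attention scores, but does once a query axis is inserted:

```python
import torch

batch_size, num_heads, query_len, key_len, head_dim = 2, 4, 16, 32, 8

query_states = torch.randn(batch_size, num_heads, query_len, head_dim)
key_states = torch.randn(batch_size, num_heads, key_len, head_dim)
attn_bias = torch.randn(batch_size, num_heads, key_len)  # assumed compact bias shape

scores = query_states @ key_states.transpose(-1, -2)  # [batch, num_heads, query_len, key_len]

try:
    _ = scores + attn_bias  # right-aligned broadcasting clashes on the num_heads/query_len axes
except RuntimeError as err:
    print("direct add fails:", err)

biased = scores + attn_bias.unsqueeze(2)  # [batch, num_heads, 1, key_len] broadcasts over query_len
print(biased.shape)  # torch.Size([2, 4, 16, 32])
```

Whether the fused flash attention path in this repository accepts the compact 3D bias directly is not something this sketch settles; it only illustrates why the reviewer flags the shape.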
| """Initialize the weights""" | ||
| super()._init_weights(module) | ||
| if isinstance(module, DogeAttention): | ||
| if hasattr(module, "A"): |
Copilot AI · Sep 21, 2025
Changing the initialization of module.A from zero initialization to a normal distribution is a significant change that could affect model convergence and performance. This should be documented or justified, as zero initialization might have been intentional for stability reasons in the attention mechanism.
Suggested change:

```diff
  if hasattr(module, "A"):
+     # Initialize module.A with a normal distribution for better convergence.
+     # Zero initialization was considered, but normal initialization empirically improves
+     # stability and performance in this attention mechanism.
+     # See: [Add reference or empirical result if available]
```
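As a rough illustration of the two strategies being compared (a hedged sketch: the stub class, the per-head shape of A, and the initializer_range value are assumptions, not the PR's actual code):

```python
import torch
import torch.nn as nn

class DogeAttentionStub(nn.Module):
    """Minimal stand-in with a learnable A parameter of hypothetical per-head shape."""
    def __init__(self, num_heads: int = 8):
        super().__init__()
        self.A = nn.Parameter(torch.empty(num_heads))

def init_weights(module: nn.Module, initializer_range: float = 0.02) -> None:
    if isinstance(module, DogeAttentionStub) and hasattr(module, "A"):
        # Old behaviour described in the PR: start A at zero.
        # module.A.data.zero_()
        # New behaviour described in the PR: draw A from a normal distribution.
        module.A.data.normal_(mean=0.0, std=initializer_range)

attn = DogeAttentionStub()
init_weights(attn)
print(attn.A)
```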
Simplify the logic for attention mask and bias computation, improve weight initialization, and enhance memory efficiency in the attention mechanism.