
Conversation

@LoserCheems
Collaborator

Simplify the logic for attention mask and bias computation, improve weight initialization, and enhance memory efficiency in the attention mechanism.

…ward to simplify logic and ensure proper initialization
Removes unnecessary tensor expansion operations in attention bias calculation to improve memory efficiency and computational performance.

Changes weight initialization from zero to normal distribution for better training dynamics and gradient flow.
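The memory claim above can be illustrated with a small sketch (all shapes and names here are illustrative, not the actual `modeling_doge.py` code): instead of materializing the bias as a full `[batch, heads, query_len, key_len]` tensor up front with `expand`, the bias is kept at `[batch, heads, key_len]` and broadcast lazily against the attention scores.

```python
import torch
import torch.nn.functional as F

# Illustrative shapes (not taken from the PR).
batch_size, num_heads, query_len, key_len = 2, 4, 8, 8
dt_states = torch.randn(batch_size, key_len, num_heads)
A = torch.randn(num_heads)

# Before: explicit expand materializes a [B, H, Q, K] view up front.
bias_old = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)  # [B, H, K]
bias_old = bias_old[:, :, None, :].expand(-1, -1, query_len, -1)   # [B, H, Q, K]

# After: keep the bias at [B, H, K] and let broadcasting against the
# [B, H, Q, K] attention scores do the expansion lazily at add time.
bias_new = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)  # [B, H, K]

scores = torch.randn(batch_size, num_heads, query_len, key_len)
out_old = scores + bias_old
out_new = scores + bias_new[:, :, None, :]  # broadcast over query_len
```

Both paths produce identical results; the broadcast version simply avoids holding the expanded view before the addition.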
Copilot AI review requested due to automatic review settings September 21, 2025 09:38
Contributor

Copilot AI left a comment


Pull Request Overview

This PR refactors attention mask and bias handling to improve efficiency and simplify the logic. The changes focus on streamlining conditional checks, reducing redundant operations, and updating weight initialization strategies.

  • Simplified conditional logic for attention mask creation in the flash attention utility
  • Removed redundant tensor expansion operations in the attention mechanism
  • Updated weight initialization from zero initialization to normal distribution

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

Reviewed files:

  • `flash_dmattn/integrations/modeling_flash_dynamic_mask_attention_utils.py` — Simplified conditional logic and added explicit `None` assignment for `attention_mask`
  • `examples/modeling/modeling_doge.py` — Streamlined attention bias computation and updated weight initialization method


```diff
-attn_bias = dt_states[:, :, None, :].expand(
-    -1, -1, hidden_states.shape[1], -1
-).to(hidden_states.dtype)  # [batch_size, num_heads, query_len, key_len]
+attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
```

Copilot AI Sep 21, 2025


The removal of the tensor expansion operation may cause shape mismatch issues. The original code expanded attn_bias to match the expected dimensions [batch_size, num_heads, query_len, key_len], but now it only has dimensions from the transpose operation. This could lead to broadcasting errors in subsequent attention computations.

Suggested change

```diff
 attn_bias = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2).to(hidden_states.dtype)
+# Ensure attn_bias has shape [batch_size, num_heads, query_len, key_len]
+# query_states: [batch_size, num_heads, query_len, head_dim]
+# key_states: [batch_size, num_heads, key_len, head_dim]
+# attn_bias: [batch_size, query_len, key_len] or similar
+if attn_bias.dim() == 3:
+    attn_bias = attn_bias.unsqueeze(1)  # [batch_size, 1, query_len, key_len]
+attn_bias = attn_bias.expand(-1, query_states.shape[1], -1, -1)  # [batch_size, num_heads, query_len, key_len]
```
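The concern about broadcasting errors can be sanity-checked directly: in PyTorch, a `[B, H, 1, K]` bias broadcasts against `[B, H, Q, K]` attention scores, so the explicit `expand` in the suggestion is numerically equivalent to the implicit broadcast. A minimal check with illustrative shapes (not the actual model code):

```python
import torch

# Illustrative shapes: batch, heads, query_len, key_len.
B, H, Q, K = 2, 4, 8, 8
scores = torch.randn(B, H, Q, K)
bias = torch.randn(B, H, 1, K)

implicit = scores + bias                        # relies on broadcasting
explicit = scores + bias.expand(-1, -1, Q, -1)  # explicit expand, as suggested
```

Both additions yield element-for-element identical tensors, since `expand` only creates a broadcast view without copying data.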

"""Initialize the weights"""
super()._init_weights(module)
if isinstance(module, DogeAttention):
if hasattr(module, "A"):

Copilot AI Sep 21, 2025


Changing the initialization of module.A from zero initialization to normal distribution is a significant change that could affect model convergence and performance. This should be documented or justified, as zero initialization might have been intentional for stability reasons in the attention mechanism.

Suggested change

```diff
 if hasattr(module, "A"):
+    # Initialize module.A with a normal distribution for better convergence.
+    # Zero initialization was considered, but normal initialization empirically
+    # improves stability and performance in this attention mechanism.
+    # See: [Add reference or empirical result if available]
```
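For context, the change under discussion is the switch from zero to normal initialization for the attention parameter `A`. A minimal sketch of that pattern, where `ToyAttention` and `init_weights` are hypothetical stand-ins for `DogeAttention` and `_init_weights`, and `std=0.02` is an assumed value not taken from the PR:

```python
import torch
from torch import nn

class ToyAttention(nn.Module):
    """Hypothetical module holding a per-head decay parameter A."""
    def __init__(self, num_heads: int):
        super().__init__()
        self.A = nn.Parameter(torch.empty(num_heads))

def init_weights(module: nn.Module, std: float = 0.02) -> None:
    """Sketch of the new init path: normal draw instead of zeros."""
    if isinstance(module, ToyAttention) and hasattr(module, "A"):
        # Old behavior would have been: module.A.data.zero_()
        module.A.data.normal_(mean=0.0, std=std)

attn = ToyAttention(num_heads=8)
init_weights(attn)
```

With zero init, `exp(A * softplus(dt_states))` starts at exactly 1 for every head; a small normal draw breaks that symmetry at the cost of the stability the reviewer mentions.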

@LoserCheems LoserCheems merged commit 9cc0ca6 into main Sep 21, 2025
1 check passed
@LoserCheems LoserCheems deleted the update-example branch November 13, 2025 04:41
