
Fix Causality Handling in Flash Attention to Support Bidirectional Attention #39707


Open
lucaswychan wants to merge 1 commit into main

Conversation

lucaswychan

What does this PR do?

Refer to #39554

The original implementation of the flash_attention_forward function is restricted to performing causal attention and does not support bidirectional attention. This behavior stems from how the function handles causality:

  • The function relies on the Attention.is_causal attribute, which belongs to the Attention class in the model.
  • By default, Attention.is_causal is set to True, enforcing causal attention (where the model only attends to previous tokens in a sequence).
  • This attribute is never modified in the code, meaning the setting is effectively fixed.
  • Additionally, while the function pops the is_causal key from the keyword arguments (kwargs) passed to it, the popped value is never used; the function always defers to the hardcoded Attention.is_causal value.

As a result, even if a user attempts to pass is_causal=False through kwargs to enable bidirectional attention (where the model can attend to both previous and future tokens), the input is ignored. Consequently, the current setup makes it impossible to perform bidirectional attention when using flash attention.
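
To make the failure mode concrete, here is a minimal, hypothetical sketch of the pattern described above (simplified stand-in names, not the actual transformers source): the user-supplied is_causal is popped from kwargs but discarded, so the hardcoded module attribute always wins.

class Attention:
    is_causal = True  # set once by default and never modified


def flash_attention_forward(module, **kwargs):
    kwargs.pop("is_causal", None)   # the caller's value is thrown away
    return module.is_causal         # always defers to the hardcoded attribute


print(flash_attention_forward(Attention(), is_causal=False))  # prints True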

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a Github issue or the forum? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the
    documentation guidelines, and
    here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@ArthurZucker @Cyrilvallez @vasqu

@ArthurZucker (Collaborator) left a comment


BTW, I checked this: #35390. Let's make sure we don't break it!

@lucaswychan (Author) commented Jul 28, 2025

After checking #35390, it seems that it only pops is_causal from kwargs without saving or using it (which is the reason I created this PR). My approach is similar, but I first store the value of is_causal when popping it, and fall back to module.is_causal only when is_causal is None (i.e. is_causal was not passed in kwargs); a minimal sketch of this logic follows the outputs below. Below is sample code to check my fix:

import torch
from transformers import AutoTokenizer, AutoModel

model_name = "Qwen/Qwen2.5-7B"
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load the model with the FlashAttention-2 implementation.
model = AutoModel.from_pretrained(model_name, device_map=device, attn_implementation="flash_attention_2", torch_dtype="auto")

input_sentence = "I want bidirectional attention"
inputs = tokenizer([input_sentence], return_tensors="pt", padding=True).to(model.device)

# Causal (default) attention.
output_with_causal = model(**inputs, is_causal=True)
print(f'last hidden state with causal : {output_with_causal["last_hidden_state"]}')

# Bidirectional attention, enabled by passing is_causal=False through kwargs.
output_without_causal = model(**inputs, is_causal=False)
print(f'last hidden state without causal : {output_without_causal["last_hidden_state"]}')

and you should see output similar to the following:

last hidden state with causal : tensor([[[-2.1719,  2.5469, -6.6875,  ...,  2.9531, -0.2100,  2.2031],
         [-0.7305, -0.2656, -1.5156,  ..., -0.7461, -0.0952, -1.6641],
         [-0.4668,  1.8125,  0.2021,  ...,  0.8867, -1.1875, -1.0547],
         [ 1.8359,  2.3750, -1.5547,  ...,  0.6445, -1.3906, -2.0469],
         [ 4.1562, -3.6250, -2.9844,  ..., -2.1562, -1.2031, -0.3086]]],
       device='cuda:7', dtype=torch.bfloat16, grad_fn=<MulBackward0>)

last hidden state without causal : tensor([[[ 3.0156, -0.7070, -1.5234,  ..., -1.1875,  0.7188, -0.0464],
         [ 0.4629, -1.0781,  0.0135,  ..., -2.2656,  0.2412, -2.6250],
         [ 0.2334, -2.0469, -0.6680,  ..., -1.8906, -0.7305, -3.2969],
         [ 0.9961, -0.9844, -0.9805,  ..., -1.2500, -0.3555, -2.1094],
         [ 1.3125,  0.7539, -0.5781,  ..., -0.4824,  0.0216,  0.3633]]],
       device='cuda:7', dtype=torch.bfloat16, grad_fn=<MulBackward0>)

which shows that the output does indeed change when is_causal is altered.
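
For reference, the pop-and-fallback behaviour described above amounts to something like the following minimal, hypothetical sketch (simplified stand-in names, not the exact diff in this PR):

class Attention:
    is_causal = True  # module default, as before


def flash_attention_forward(module, **kwargs):
    is_causal = kwargs.pop("is_causal", None)  # keep the caller's value
    if is_causal is None:                      # not passed: fall back to the module default
        is_causal = module.is_causal
    return is_causal


print(flash_attention_forward(Attention()))                   # prints True  (falls back to module.is_causal)
print(flash_attention_forward(Attention(), is_causal=False))  # prints False (the caller's value wins)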

@ArthurZucker (Collaborator) left a comment


It is actually a nice feature to have!
Pinging @vasqu to make sure we are not forgetting anything!

@ArthurZucker (Collaborator)

run-slow: auto

@ArthurZucker (Collaborator)

As per his comment in #39554 (comment), we need to make sure the attention mask creation takes this into account imo!

Contributor

This comment contains run-slow, running the specified jobs:

models: ['models/auto']
quantizations: [] ...
