Fix Causality Handling in Flash Attention to Support Bidirectional Attention #39707
base: main
Conversation
BTW checked this: #35390 let's make sure we don't break it!
After checking #35390, it seems like it only pops `is_causal`. Running:

```python
import torch
from transformers import AutoTokenizer, AutoModel

model_name = "Qwen/Qwen2.5-7B"
device = torch.device("cuda")

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModel.from_pretrained(
    model_name,
    device_map=device,
    attn_implementation="flash_attention_2",
    torch_dtype="auto",
)

input_sentence = "I want bidirectional attention"
inputs = tokenizer([input_sentence], return_tensors="pt", padding=True).to(model.device)

output_with_causal = model(**inputs, is_causal=True)
print(f'last hidden state with causal : {output_with_causal["last_hidden_state"]}')

output_without_causal = model(**inputs, is_causal=False)
print(f'last hidden state without causal : {output_without_causal["last_hidden_state"]}')
```

you should see that the last hidden state is indeed changed by altering `is_causal`.
It is actually a nice feature to have!
Pinging @vasqu to make sure we are not forgetting anything!
run-slow: auto
#39554 (comment): as per his comment, we need to make sure the attention mask creation takes this into account, imo! A sketch of what I mean is below.
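For context, here is a minimal, purely illustrative sketch of why the mask path matters (the helper `expand_padding_mask` is hypothetical and is not the transformers mask-creation code): when `is_causal=False`, the 2D padding mask should be expanded into a full bidirectional 4D mask rather than being intersected with a lower-triangular causal mask.

```python
import torch

# Hypothetical helper, not the transformers implementation: expand a 2D padding
# mask (1 = real token, 0 = padding) into a per-query "allowed" mask, optionally
# intersected with a lower-triangular causal mask.
def expand_padding_mask(padding_mask: torch.Tensor, is_causal: bool) -> torch.Tensor:
    bsz, seq_len = padding_mask.shape
    # Broadcast the padding information over the query dimension.
    allowed = padding_mask[:, None, None, :].bool().expand(bsz, 1, seq_len, seq_len)
    if is_causal:
        # Additionally forbid attending to future positions.
        causal = torch.tril(torch.ones(seq_len, seq_len)).bool()
        allowed = allowed & causal
    return allowed

padding_mask = torch.tensor([[1, 1, 1, 0]])  # last position is padding
print(expand_padding_mask(padding_mask, is_causal=True)[0, 0].int())   # causal + padding
print(expand_padding_mask(padding_mask, is_causal=False)[0, 0].int())  # padding only (bidirectional)
```

The point is simply that the effective `is_causal` flag has to reach the mask-creation step as well, not only the flash attention kernel call.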
This comment contains run-slow, running the specified jobs: models: ['models/auto']
What does this PR do?
Refer to #39554
The original implementation of the `flash_attention_forward` function is restricted to performing causal attention and does not support bidirectional attention. This behavior stems from how the function handles causality:

- Causality is taken from the `Attention.is_causal` attribute, which belongs to the `Attention` class in the model. `Attention.is_causal` is set to `True`, enforcing causal attention (where the model only attends to previous tokens in a sequence).
- Although the function pops the `is_causal` key from the keyword arguments (`kwargs`) passed to it, this value is not used. Instead, it always defers to the hardcoded `Attention.is_causal` value.

As a result, even if a user attempts to pass `is_causal=False` through `kwargs` to enable bidirectional attention (where the model can attend to both previous and future tokens), the input is ignored. Consequently, the current setup makes it impossible to perform bidirectional attention when using flash attention.
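To make the intended behavior concrete, here is a minimal sketch of the precedence rule described above (illustrative only; `resolve_is_causal` is a hypothetical helper, not the actual `flash_attention_forward` code): an explicitly passed `is_causal` kwarg should take effect, while the module-level `Attention.is_causal` is only used as a fallback.

```python
from typing import Optional

# Hypothetical helper (not the actual transformers code): pick the effective
# causality flag, letting an explicit kwarg override the module default.
def resolve_is_causal(module_is_causal: bool, is_causal: Optional[bool] = None) -> bool:
    # Fall back to the hardcoded Attention.is_causal only when the caller
    # did not pass an explicit value.
    return module_is_causal if is_causal is None else is_causal


print(resolve_is_causal(module_is_causal=True))                   # True: default causal attention
print(resolve_is_causal(module_is_causal=True, is_causal=False))  # False: bidirectional attention
```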
Who can review?
Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@ArthurZucker @Cyrilvallez @vasqu