[qwen2-vl] fix FA2 inference #39121
Conversation
The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.
```python
attention_mask = None
if self.config._attn_implementation != "flash_attention_2":
    attention_mask = torch.full(
        [1, 1, seq_length, key_states.shape[-2]],
        torch.finfo(query_states.dtype).min,
        device=query_states.device,
        dtype=query_states.dtype,
    )
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
```
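For context, a minimal sketch of why no dense mask is needed on the FA2 path: the varlen kernel consumes packed ("ragged") tokens plus `cu_seqlens` directly. This assumes the flash-attn package's `flash_attn_varlen_func`; shapes and values below are illustrative, not the model's real sizes.

```python
# Minimal sketch, assuming flash-attn's varlen interface (requires a CUDA device
# and fp16/bf16 inputs). Two sequences of lengths 3 and 2 are packed together.
import torch
from flash_attn import flash_attn_varlen_func

total_tokens, num_heads, head_dim = 5, 4, 32  # illustrative sizes
q = torch.randn(total_tokens, num_heads, head_dim, dtype=torch.float16, device="cuda")
k, v = torch.randn_like(q), torch.randn_like(q)
cu_seqlens = torch.tensor([0, 3, 5], dtype=torch.int32, device="cuda")  # packed boundaries

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=3, max_seqlen_k=3,
    causal=False,  # vision attention is bidirectional within each block
)
```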
This probably also fixes #39067? Might be nice to change the flash attention integration path in the future to prioritize pos+cu_seq (even with an attention mask).
No, it doesn't; that issue existed before the refactor. I found it's related to FA2 using ragged inputs, which the SDPA/eager path doesn't support in the same way.
I will comment under #39067 when I have a clear fix. Let's keep this PR open until then; better to fix it once and for all.
LGTM, thanks (for bearing with me) :)
I am not sure I understand why we create the mask in the attention layer; we should make sure we do it at the model level, since it is the same for all layers.
I don't think there was a specific reason; it has been that way since the first model was released. In any case, Qwen-VL models seem to be designed around FA2 first, while the other attention implementations were added just because we wanted to support them. Thus the mask is prepared based on
@zucchini-nlp Hi, FA2 may not be working as expected. This is because we use the following code in transformers/src/transformers/modeling_flash_attention_utils.py (lines 423 to 426 at ed36f84).
However, there is an issue with the variable names here, which causes FA2 to not function correctly. After correcting the code to the following version, I tested the model accuracy and it worked as expected 🤗

```python
attn_output, _ = attention_interface(
    self,
    query_states,
    key_states,
    value_states,
    attention_mask=attention_mask,
    dropout=0.0,
    scaling=self.scaling,
    cu_seq_lens_k=cu_seqlens,  # pass cu seq lens for FA2
    cu_seq_lens_q=cu_seqlens,
    max_length_q=max_seqlen,
    max_length_k=max_seqlen,
    is_causal=False,
    **kwargs,
)
```
@JJJYmmm Yeah, thanks, the current state doesn't work yet. In fact, renaming alone doesn't help because the
Oh! Yes, I also modified the logic in
@zucchini-nlp LGTM 🫡
```python
is_fa2_with_position_ids = (
    position_ids is not None
    and query_states.shape[0] == 1
    and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
is_fa2_with_varlen_kwargs = all(
    kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)
```
I'm a bit unsure about this, since we would allow `cu_seq` and `max_seq` only, but on most models we also have RoPE, so it breaks those models silently if we mess up and don't pass `position_ids` (due to RoPE positions being bound to `position_ids` as well). We should imo add at least a warning when only the varlen kwargs are passed, to give some discretion here.
On another note, what do the integration tests use? Are they still working as expected 👀 seems a bit sus
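A hedged sketch of the warning suggested above; the helper name and placement are hypothetical, not existing transformers code, and the argument names mirror the snippet quoted earlier.

```python
import warnings

def warn_if_varlen_without_position_ids(is_fa2_with_varlen_kwargs, position_ids):
    # Hypothetical helper: only the varlen kwargs were supplied, but RoPE models
    # also need correctly flattened position_ids for packed sequences.
    if is_fa2_with_varlen_kwargs and position_ids is None:
        warnings.warn(
            "cu_seq_lens/max_length were passed without position_ids; RoPE models need "
            "flattened position_ids for packed sequences, otherwise outputs may be silently wrong."
        )
```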
The integration tests always use `position_ids`, which is the first case (`is_fa2_with_position_ids`); I just copied it from the existing code and moved it up here. The second case is added for Qwen only; afaik no other model passes pre-computed `cu_lens` from its attention layers.
In Qwen we don't need any position ids, because they are 3D and won't help at all in inferring `cu_lens`.
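To illustrate the point about Qwen's vision tower: the cumulative sequence lengths follow directly from the patch grid, so position ids are not needed to infer them. This is a sketch assuming a Qwen2-VL style `grid_thw` of patch counts; the tensor values are made up.

```python
import torch
import torch.nn.functional as F

# Each row of grid_thw is (frames, height, width) in patches; cu_seqlens for the
# vision tower can be derived from the grid alone.
grid_thw = torch.tensor([[1, 4, 6], [2, 2, 2]])  # hypothetical: one image, one 2-frame video
seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
cu_seqlens = F.pad(seqlens.cumsum(dim=0, dtype=torch.int32), (1, 0), value=0)
print(cu_seqlens)  # tensor([ 0, 24, 28, 32], dtype=torch.int32)
```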
I meant the general integration tests, e.g. `def test_small_model_integration_test_batch_flashatt2(self):`
As for why: my concern about the second case on Qwen is future model additions and general usability, not the validity of Qwen. For developers, `is_fa2_with_varlen_kwargs` indicates that this suffices for varlen; before, we (unintentionally) checked for the existence of (correct, flattened) position ids that RoPE models need when using varlen. Maybe #35941 helps as a reference for what I mean.
Imo, it would help to add at least comments noting that for varlen most models need correct flattened position ids (from e.g. a collator), especially RoPE models, which make up the majority of newer models.
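A minimal sketch of what "correct flattened position ids" from a collator could look like for packed sequences; the function below is illustrative only, not an existing transformers collator.

```python
import torch

def packed_position_ids(seq_lengths):
    # Positions restart at 0 for every packed sequence, which is what RoPE expects
    # on the varlen path.
    return torch.cat([torch.arange(length) for length in seq_lengths]).unsqueeze(0)

print(packed_position_ids([3, 2]))  # tensor([[0, 1, 2, 0, 1]])
```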
Hmm, am I right that users might be passing only `cu_lens` without correct `position_ids`? I believe it would be the users' responsibility to make sure RoPE is applied correctly, but I will add a comment in the code explaining it, sure.
In the slow integration tests we don't pass `position_ids`, not that I know of. For most LLMs the FA2-path integration tests fall back to inferring `cu_lens` from the mask, and in Qwen the `position_ids` are constructed on the fly during the forward call. The model requires adding RoPE deltas on top of the 3D positions, and I don't think users would be doing all of that manually.
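For reference, a sketch of the fallback mentioned above: deriving `cu_lens` from a 2D padding mask, similar in spirit to the unpadding helpers in `modeling_flash_attention_utils.py`. The toy mask is made up.

```python
import torch
import torch.nn.functional as F

attention_mask = torch.tensor([[1, 1, 1, 0], [1, 1, 0, 0]])  # batch of 2, right-padded
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
cu_seqlens = F.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
print(cu_seqlens)  # tensor([0, 3, 5], dtype=torch.int32)
```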
"Hmm, am I right that users might be passing only cu_lens without correct position_ids?" - Yes, not only users but possibly us as well because it's something that's harder to figure out when done wrong imo :D
Thanks again
@zucchini-nlp can we refactor the attention mask creation to be outside the attention layer? The cu seqlens are not different per layer!
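A hedged sketch of what moving the mask preparation out of the attention layer could look like; the helper name is hypothetical and simply mirrors the per-layer code quoted at the top of this thread.

```python
import torch

def prepare_vision_attention_mask(cu_seqlens, seq_length, dtype, device, attn_implementation):
    # Built once in the vision model's forward and passed to every block;
    # FA2 consumes cu_seqlens directly, so it gets no dense mask at all.
    if attn_implementation == "flash_attention_2":
        return None
    mask = torch.full((1, 1, seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype, device=device)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
    return mask
```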
Let's push!
* fix FA2
* update is causal flag and remove mask for FA2
* update for FA2 with varlen path
* how the tests were passing with different devices?
* add comment and ref to the PR
* move mask preparation to base pretrained model
* seq len is the first dim, not second
* fix copies to fix GLM4V
What does this PR do?
Fixes #39095 and sets `mask=None` for FA2; otherwise inference fails because FA2 expects a 2D mask.
Tested with `RUN_SLOW=1 pytest -k flash_attn` for all modified models.