[qwen2-vl] fix FA2 inference #39121


Merged: 9 commits into huggingface:main, Jul 1, 2025

Conversation

@zucchini-nlp (Member) commented Jun 30, 2025

What does this PR do?

Fixes #39095 and sets mask=None for FA2; otherwise inference fails, because FA2 expects a 2D mask.

Tested with RUN_SLOW=1 pytest -k flash_attn for all modified models
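
For illustration, the core pattern of the fix looks roughly like the sketch below (a minimal, self-contained sketch, not the PR's actual diff; build_attention_mask and its arguments are made up for this example). The 4D additive mask is only built for the SDPA/eager paths, while the FA2 path keeps attention_mask=None and works from cu_seqlens instead.

```python
import torch

def build_attention_mask(attn_implementation, cu_seqlens, seq_length,
                         dtype=torch.float32, device="cpu"):
    # FA2 expects either no mask or a 2D padding mask; a 4D float mask breaks it,
    # so the FA2 path returns None and works from cu_seqlens directly.
    if attn_implementation == "flash_attention_2":
        return None
    # SDPA/eager: block-diagonal additive mask, one block per packed sequence.
    mask = torch.full((1, 1, seq_length, seq_length), torch.finfo(dtype).min,
                      dtype=dtype, device=device)
    for i in range(1, len(cu_seqlens)):
        mask[..., cu_seqlens[i - 1]:cu_seqlens[i], cu_seqlens[i - 1]:cu_seqlens[i]] = 0
    return mask

# Two packed sequences of lengths 3 and 2 flattened into a single batch of 5 tokens.
cu = torch.tensor([0, 3, 5])
assert build_attention_mask("flash_attention_2", cu, 5) is None
assert build_attention_mask("sdpa", cu, 5).shape == (1, 1, 5, 5)
```

FA2 either takes no mask at all or a 2D padding mask of shape (batch, seq_len), so handing it the 4D float mask is what broke inference.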

@zucchini-nlp zucchini-nlp requested a review from Cyrilvallez June 30, 2025 09:43
@zucchini-nlp zucchini-nlp added the "for patch" label (Tag issues / labels that should be included in the next patch) on Jun 30, 2025

Comment on lines 638 to 647
attention_mask = None
if self.config._attn_implementation != "flash_attention_2":
    # Only SDPA/eager need the 4D additive float mask; the FA2 path keeps None
    # and relies on cu_seqlens instead.
    attention_mask = torch.full(
        [1, 1, seq_length, key_states.shape[-2]],
        torch.finfo(query_states.dtype).min,
        device=query_states.device,
        dtype=query_states.dtype,
    )
    # Zero out the block-diagonal so each packed sequence only attends to itself.
    for i in range(1, len(cu_seqlens)):
        attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0

Contributor:

This probably also fixes #39067? Might be nice to change the flash attention integration path in the future to prioritize pos+cu_seq (even with an attention mask).

Member Author:

No, it doesn't; that issue existed before the refactor, and I found it's related to FA2 using ragged inputs, which the SDPA/eager path doesn't support in the same way.

I will comment under #39067 when I have a clear fix. Let's keep this PR open until then; better to fix it once and for all.

@vasqu (Contributor) left a comment:

LGTM, thanks (for bearing with me) :)

@ArthurZucker (Collaborator) left a comment:

I am not sure I understand why we create the mask in the attention layer; we should make sure we do it at the model level, since it is the same for all layers.

@zucchini-nlp (Member Author):

I don't think there was a specific reason; it has been that way since the first model was released. In any case, the Qwen-VL models seem to have been designed FA2-first, while the other attention implementations were added just because we wanted to support them.

Thus the mask is prepared based on cu_seqlens

@JJJYmmm (Contributor) commented Jun 30, 2025

@zucchini-nlp Hi, FA2 may not be working as expected. This is because _flash_attention_forward uses the following argument names:

cu_seq_lens_q: Optional[torch.LongTensor] = None,
cu_seq_lens_k: Optional[torch.LongTensor] = None,
max_length_q: Optional[int] = None,
max_length_k: Optional[int] = None,

However, the variable names passed here don't match, which prevents FA2 from functioning correctly.

https://github.com/zucchini-nlp/transformers/blob/2cb7f21991478494c963ab8c4b52691d8d192b0d/src/transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py#L266-L269

After correcting the code to the following version, I tested the model accuracy and it worked as expected 🤗

        attn_output, _ = attention_interface(
            self,
            query_states,
            key_states,
            value_states,
            attention_mask=attention_mask,
            dropout=0.0,
            scaling=self.scaling,
            cu_seq_lens_k=cu_seqlens,  # pass cu seq lens for FA2
            cu_seq_lens_q=cu_seqlens,
            max_length_q=max_seqlen,
            max_length_k=max_seqlen,
            is_causal=False,
            **kwargs,
        )

@zucchini-nlp (Member Author):

@JJJYmmm yeah, thanks, the current state doesn't work yet. In fact, renaming alone doesn't help, because the varlen codepath needs position ids as well. I have a local version working and am now trying to find the cause of the linked issue. Even though it's not caused by the recent changes, we need to fix it.

@JJJYmmm (Contributor) commented Jun 30, 2025

Oh! Yes, I also modified the logic in _flash_attention_forward: when position_ids is None but cu_seqlens is not None, it now directly calls the varlen function. Looking forward to your changes!

@zucchini-nlp (Member Author):

@vasqu @JJJYmmm I updated the FA2 path so it uses flash_attn_with_varlen for pad-free inputs. Can you check once again? If everything is fine, I'll merge.
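
For reference, the varlen call for pad-free inputs looks roughly like the sketch below. This assumes the flash-attn package's flash_attn_varlen_func (exact keyword names can vary between flash-attn versions), a CUDA device, and fp16 tensors; varlen_attention is a hypothetical wrapper, not the actual _flash_attention_forward code.

```python
import torch
from flash_attn import flash_attn_varlen_func

def varlen_attention(q, k, v, cu_seqlens, max_seqlen, scaling=None):
    # q, k, v are pad-free / flattened: (total_tokens, num_heads, head_dim);
    # sequence boundaries are described by cu_seqlens instead of padding.
    return flash_attn_varlen_func(
        q, k, v,
        cu_seqlens_q=cu_seqlens,
        cu_seqlens_k=cu_seqlens,
        max_seqlen_q=max_seqlen,
        max_seqlen_k=max_seqlen,
        dropout_p=0.0,
        softmax_scale=scaling,
        causal=False,  # vision attention is bidirectional within each block
    )

# Two packed sequences of 3 and 5 tokens, 16 heads, head_dim 64.
cu = torch.tensor([0, 3, 8], dtype=torch.int32, device="cuda")
q = k = v = torch.randn(8, 16, 64, dtype=torch.float16, device="cuda")
out = varlen_attention(q, k, v, cu, max_seqlen=5)  # shape: (8, 16, 64)
```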

@JJJYmmm (Contributor) commented Jun 30, 2025

@zucchini-nlp LGTM 🫡

Comment on lines +517 to +524
# Case 1: batch flattened to one row with non-monotonic position ids -> packed sequences.
is_fa2_with_position_ids = (
    position_ids is not None
    and query_states.shape[0] == 1
    and (max_length_q is not None or (query_length != 1 and not (torch.diff(position_ids, dim=-1) >= 0).all()))
)
# Case 2: the caller passed pre-computed varlen kwargs (e.g. the Qwen2-VL vision tower).
is_fa2_with_varlen_kwargs = all(
    kwarg is not None for kwarg in (cu_seq_lens_q, cu_seq_lens_k, max_length_q, max_length_k)
)

Contributor:

I'm a bit unsure about this: we would allow cu_seq and max_seq only, but most models also use RoPE, so we would silently break those models if we mess up and don't pass position_ids (since RoPE positions are bound to position_ids as well). Imo we should at least add a warning when only the varlen kwargs are passed, to give some discretion here.

On another note, what do the integration tests use? Are they still working as expected? 👀 Seems a bit sus.

Member Author:

The integration tests always use position_ids, which hits the first check (is_fa2_with_position_ids); I just copied it from the existing code and moved it up here. The second case is added for Qwen only; afaik no other model passes pre-computed cu_lens from attention layers.

In Qwen we don't need any position ids, because they are 3D and won't help at all in inferring cu_lens.
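
For context, the vision tower's cu_seqlens come from the image grid rather than from position ids, roughly as in the simplified reconstruction below (not the exact modeling code; grid_thw is assumed to hold the per-image (temporal, height, width) patch counts).

```python
import torch
import torch.nn.functional as F

def cu_seqlens_from_grid(grid_thw: torch.Tensor) -> torch.Tensor:
    # One attention block per temporal frame of each image: h * w patches,
    # repeated t times, accumulated into boundary offsets.
    seq_lens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0])
    cu_seqlens = seq_lens.cumsum(dim=0, dtype=torch.int32)
    return F.pad(cu_seqlens, (1, 0), value=0)  # prepend the leading zero

# Two single-frame images with 2x2 and 3x2 patch grids.
grid_thw = torch.tensor([[1, 2, 2], [1, 3, 2]])
print(cu_seqlens_from_grid(grid_thw))  # tensor([ 0,  4, 10], dtype=torch.int32)
```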

@vasqu (Contributor), Jul 1, 2025:

I meant the general integration tests like e.g.

def test_small_model_integration_test_batch_flashatt2(self):

As for the why: my concern about the second case is future model additions and general usability, not the validity of Qwen. For developers, is_fa2_with_varlen_kwargs suggests that these kwargs alone suffice for varlen; before, we (unintentionally) checked for the existence of the (correctly flattened) position ids that RoPE models need when using varlen. Maybe #35941 helps as a reference for what I mean.

Imo, it would help to at least add comments saying that for varlen most models need correctly flattened position ids (e.g. from a collator), especially RoPE models, which make up the majority of newer models.
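
To make "correctly flattened position ids" concrete, here is a minimal sketch of what a padding-free collator (e.g. DataCollatorWithFlattening) is expected to provide; the helper below is only illustrative, not the collator's actual implementation.

```python
import torch

def flattened_position_ids(seq_lens):
    # RoPE needs positions that restart at 0 for each packed sequence;
    # a plain 0..total-1 range would give wrong rotary phases for every
    # sequence after the first one.
    return torch.cat([torch.arange(n) for n in seq_lens]).unsqueeze(0)

# Three sequences of lengths 3, 2 and 4 packed into a single row.
print(flattened_position_ids([3, 2, 4]))
# tensor([[0, 1, 2, 0, 1, 0, 1, 2, 3]])
```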

Member Author:

Hmm, am I right that users might be passing only cu_lens without correct position_ids? I believe it would be the user's responsibility to make sure RoPE is applied correctly, but I will add a comment in the code explaining it, sure.

In the slow integration tests we don't pass position_ids, not that I know of. For most LLMs the FA2-path integration tests fall back to inferring cu_lens from the mask, and in Qwen the position_ids are constructed on the fly during the forward call. The model requires adding rope deltas on top of the 3D positions, and I don't think users would be doing all of that manually.

@vasqu (Contributor), Jul 1, 2025:

"Hmm, am I right that users might be passing only cu_lens without correct position_ids?" - Yes, not only users but possibly us as well because it's something that's harder to figure out when done wrong imo :D

@vasqu (Contributor) left a comment:

Thanks again

@ArthurZucker (Collaborator):

@zucchini-nlp can we refactor the attention mask creation so it happens outside the attention layers? cu_seqlens are not different per layer!
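
For context, that is what ended up happening (see the "move mask preparation to base pretrained model" commit in the list below). In sketch form, it looks roughly like the hypothetical skeleton here; the class, its constructor arguments, and build_attention_mask are made up for illustration, reusing the kind of helper sketched earlier in the thread.

```python
import torch
from torch import nn

class VisionTower(nn.Module):
    """Hypothetical skeleton: cu_seqlens (and therefore the mask) are identical for
    every layer, so the mask is built once per forward and handed to each block,
    instead of being rebuilt inside every attention layer."""

    def __init__(self, config, blocks, build_attention_mask):
        super().__init__()
        self.config = config
        self.blocks = nn.ModuleList(blocks)
        # Plain callable, e.g. the build_attention_mask helper sketched earlier.
        self.build_attention_mask = build_attention_mask

    def forward(self, hidden_states, cu_seqlens):
        attention_mask = self.build_attention_mask(
            self.config._attn_implementation,
            cu_seqlens,
            hidden_states.shape[0],
            dtype=hidden_states.dtype,
            device=hidden_states.device,
        )
        for block in self.blocks:
            hidden_states = block(
                hidden_states, cu_seqlens=cu_seqlens, attention_mask=attention_mask
            )
        return hidden_states
```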

@ArthurZucker (Collaborator) left a comment:

Let's push!

@ArthurZucker (Collaborator):

[image attached]
Actually, on CUDA it makes roughly a 2x difference.

@zucchini-nlp zucchini-nlp enabled auto-merge (squash) July 1, 2025 10:14
@zucchini-nlp zucchini-nlp merged commit 7a25f8d into huggingface:main Jul 1, 2025
20 checks passed
Cyrilvallez pushed a commit that referenced this pull request Jul 4, 2025
* fix FA2

* update is causal flag and remove mask for FA2

* update for FA2 with varlen path

* how the tests were passing with different devices?

* add comment and ref to the PR

* move mask preparation to base pretrained model

* seq len is the first dim, not second

* fix copies to fix GLM4V
@zucchini-nlp zucchini-nlp mentioned this pull request Jul 18, 2025
Labels: for patch (Tag issues / labels that should be included in the next patch)

Successfully merging this pull request may close these issues.

Qwen2_5_VLVisionAttention with flash attention has no is_causal attribute
5 participants