-
-
Notifications
You must be signed in to change notification settings - Fork 7.8k
[VLM] Add MLA with pure RoPE support for deepseek-vl2 models #12729
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
+30
−6
Merged
Changes from all commits
Commits
Show all changes
2 commits
Select commit
Hold shift + click to select a range
File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could you say a bit about why you needed to wrap
rotary_embedding
when using pure_rope? Wondering if we could clean things up by always doing this reshape so that we could always callself.rotary_embedding
without the special cases for pure rope vs yarnUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wrapped the
rotary_embedding
to reshape with pure_rope because ifq_pe
andk_pe
have shape of[seq_len, num_heads, head_dim]
and passed to pure_rope directly, it will cause an illegal memory allocation onq_pe
when applyingflash_attention_varlen_func
:BTW, if we use
forward_native
for pure_rope without reshape, the error won't be encountered and it can also work with shape of[seq_len, num_heads, head_dim]
, so the issue isforward_cuda
specific. Perhaps we should add a shape check inRotaryEmbedding
'sforward_cuda
?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oh, seems that it's because the calculation for num_heads in rotary_embedding cuda ops is unsuitable for tensor with shape
[seq_len, num_heads, head_dim]
:vllm/csrc/pos_encoding_kernels.cu
Lines 124 to 136 in b3a0d01
Let's fix it in a separate PR to avoid blocking v0.7.2 release, especially it's on the kernel side and I need some time to build with compilation. :)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like a bug in the kernel -- I'll look into it tomorrow. In the meantime I like adding a shape check in forward_cuda if you have a good idea of what shapes are problematic
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nice find, sounds good to me!