
[VLM] Add MLA with pure RoPE support for deepseek-vl2 models #12729


Merged (2 commits) on Feb 5, 2025

30 changes: 26 additions & 4 deletions vllm/attention/backends/mla/utils.py
@@ -26,7 +26,8 @@
     apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
 from vllm.model_executor.layers.quantization.utils.quant_utils import (
     scaled_dequantize, scaled_quantize)
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
+from vllm.model_executor.layers.rotary_embedding import (
+    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
 
 try:
     from vllm.vllm_flash_attn import flash_attn_varlen_func
@@ -174,6 +175,8 @@ def __init__(
         self.v_head_dim = v_head_dim
 
         self.rotary_emb = rotary_emb
+        self.use_yarn_rope = isinstance(rotary_emb,
+                                        DeepseekScalingRotaryEmbedding)
         self.q_proj = q_proj
         self.kv_b_proj = kv_b_proj
         self.o_proj = o_proj
@@ -420,6 +423,24 @@ def _forward_decode(
     ) -> torch.Tensor:
         raise NotImplementedError
 
+    def apply_pure_rope(
+        self,
+        input_positions: torch.Tensor,
+        q_pe: torch.Tensor,
+        k_pe: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor]:
+        seq_len = input_positions.size(0)
+        ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
+
+        q_pe, k_pe = self.rotary_emb(
+            input_positions,
+            q_pe.reshape(seq_len, -1),
+            k_pe.reshape(seq_len, -1),
+        )
+        q_pe, k_pe = q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
+
+        return q_pe, k_pe

Comment on lines +426 to +442
Collaborator:

Could you say a bit about why you needed to wrap rotary_embedding when using pure_rope? Wondering if we could clean things up by always doing this reshape, so that we could always call self.rotary_embedding without the special cases for pure rope vs. yarn.
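
(For illustration only, not code from this PR: a rough sketch of the unification being suggested here, reusing the PR's names and assuming the yarn rope path could also accept the flattened layout.)

```python
# Illustration only (not code from this PR): always flatten the head dim
# around the rotary-embedding call so yarn rope and pure rope share one path.
# Assumes the yarn rope path can also accept the flattened layout.
def _apply_rope(self, input_positions, q_pe, k_pe):
    seq_len = input_positions.size(0)
    ori_q_pe_shape, ori_k_pe_shape = q_pe.shape, k_pe.shape
    q_pe, k_pe = self.rotary_emb(
        input_positions,
        q_pe.reshape(seq_len, -1),  # [seq_len, num_heads * qk_rope_head_dim]
        k_pe.reshape(seq_len, -1),  # [seq_len, qk_rope_head_dim]
    )
    return q_pe.view(ori_q_pe_shape), k_pe.view(ori_k_pe_shape)
```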

Collaborator (Author) @Isotr0py, Feb 5, 2025:

I wrapped rotary_embedding with the reshape for pure_rope because if q_pe and k_pe have shape [seq_len, num_heads, head_dim] and are passed to pure_rope directly, it causes an illegal memory access on q_pe when applying flash_attn_varlen_func:

[rank0]:   File "/home/zifeng/develop-projects/vllm/vllm/attention/backends/mla/utils.py", line 531, in _forward_prefill_flash
[rank0]:     attn_output = flash_attn_varlen_func(
[rank0]:                   ^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/zifeng/develop-projects/vllm/vllm/vllm_flash_attn/flash_attn_interface.py", line 172, in flash_attn_varlen_func
[rank0]:     out, softmax_lse = torch.ops._vllm_fa2_C.varlen_fwd(
[rank0]:                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]:   File "/home/zifeng/miniconda3/envs/vllm/lib/python3.12/site-packages/torch/_ops.py", line 1116, in __call__
[rank0]:     return self._op(*args, **(kwargs or {}))
[rank0]:            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

BTW, if we use forward_native for pure_rope without the reshape, the error isn't encountered and it also works with shape [seq_len, num_heads, head_dim], so the issue is forward_cuda-specific. Perhaps we should add a shape check in RotaryEmbedding's forward_cuda?
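
For illustration, a minimal sketch of what such a shape check might look like (hypothetical; the actual forward_cuda signature and body are elided, and `head_size` stands for the layer's rotary head size):

```python
import torch


def check_rope_layout(query: torch.Tensor, key: torch.Tensor,
                      head_size: int) -> None:
    # Sketch of a possible guard: the fused CUDA rope op derives num_heads
    # from query.size(-1) // head_size, so it expects the head dim to be
    # flattened into the last axis, e.g. [num_tokens, num_heads * head_size].
    for name, t in (("query", query), ("key", key)):
        if t.size(-1) % head_size != 0:
            raise ValueError(f"{name} last dim {t.size(-1)} is not a "
                             f"multiple of head_size {head_size}")
        if t.dim() == 3 and t.size(-1) == head_size and t.size(-2) > 1:
            # Looks like an unflattened [num_tokens, num_heads, head_size]
            # layout. (This heuristic can't tell it apart from a single-head
            # [batch, seq, head_size] input, so it would need refinement.)
            raise ValueError(f"{name} appears to use a per-head layout; "
                             "reshape to [num_tokens, num_heads * head_size] "
                             "before calling forward_cuda")
```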

Collaborator (Author):

Oh, it seems it's because the num_heads calculation in the rotary_embedding CUDA op is unsuitable for tensors with shape [seq_len, num_heads, head_dim]:

void rotary_embedding(
    torch::Tensor& positions,  // [batch_size, seq_len] or [num_tokens]
    torch::Tensor& query,      // [batch_size, seq_len, num_heads * head_size] or
                               // [num_tokens, num_heads * head_size]
    torch::Tensor& key,        // [batch_size, seq_len, num_kv_heads * head_size] or
                               // [num_tokens, num_kv_heads * head_size]
    int64_t head_size,
    torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
    bool is_neox) {
  int64_t num_tokens = query.numel() / query.size(-1);
  int rot_dim = cos_sin_cache.size(1);
  int num_heads = query.size(-1) / head_size;
  int num_kv_heads = key.size(-1) / head_size;

Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to build with compilation. :)
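
To make the mismatch concrete, here is a small PyTorch illustration that mirrors the arithmetic in the C++ above (sizes are arbitrary; this is not vLLM code):

```python
import torch

# Arbitrary sizes, chosen only to mirror the C++ arithmetic above.
seq_len, num_heads, head_size = 8, 16, 64
q_3d = torch.zeros(seq_len, num_heads, head_size)  # per-head layout
q_2d = q_3d.reshape(seq_len, -1)                   # flattened layout the op expects


def kernel_view(q: torch.Tensor) -> tuple[int, int]:
    # num_tokens = numel / size(-1); num_heads = size(-1) / head_size
    return q.numel() // q.size(-1), q.size(-1) // head_size


print(kernel_view(q_2d))  # (8, 16)  -> 8 tokens, 16 heads, as intended
print(kernel_view(q_3d))  # (128, 1) -> 128 "tokens" of a single head
# With the 3-D layout the op believes there are 128 tokens even though
# input_positions only has seq_len == 8 entries, which lines up with the
# illegal memory access in the traceback above.
```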

Collaborator:

Sounds like a bug in the kernel -- I'll look into it tomorrow. In the meantime, I like the idea of adding a shape check in forward_cuda if you have a good sense of which shapes are problematic.

Collaborator:

> Let's fix it in a separate PR to avoid blocking the v0.7.2 release, especially since it's on the kernel side and I need some time to build with compilation. :)

Nice find, sounds good to me!


     def forward(
         self,
         layer: AttentionLayer,
@@ -444,21 +465,22 @@ def forward(
         # Restore head dim (for rotary embedding)
         k_pe = k_pe.unsqueeze(1)
         assert hasattr(attn_metadata, "input_positions")
+        rope_fn = (self.rotary_emb
+                   if self.use_yarn_rope else self.apply_pure_rope)
 
         if is_decode:
             q_nope = self._q_proj_and_k_up_proj(hidden_states_or_q_c)
             q_pe = torch.matmul(hidden_states_or_q_c, self.W_QR)\
                 .view(-1, self.num_heads, self.qk_rope_head_dim)
-            q_pe, k_pe = \
-                self.rotary_emb(attn_metadata.input_positions, q_pe, k_pe)
+            q_pe, k_pe = rope_fn(attn_metadata.input_positions, q_pe, k_pe)
         else:
             assert is_prefill
             q = self.q_proj(hidden_states_or_q_c)[0]\
                 .view(-1, self.num_heads, self.qk_head_dim)
 
             # TODO(lucas): there must be a nicer way to write this line
             q[..., self.qk_nope_head_dim:], k_pe = \
-                self.rotary_emb(
+                rope_fn(
                     attn_metadata.input_positions,
                     q[..., self.qk_nope_head_dim:], k_pe)

3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_v2.py
@@ -414,7 +414,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,

3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_v3.py
@@ -422,7 +422,8 @@ def __init__(
                                         quant_config=quant_config,
                                         prefix=f"{prefix}.o_proj")
 
-        rope_scaling["rope_type"] = 'deepseek_yarn'
+        if rope_scaling:
+            rope_scaling["rope_type"] = 'deepseek_yarn'
         self.rotary_emb = get_rope(qk_rope_head_dim,
                                    rotary_dim=qk_rope_head_dim,
                                    max_position=max_position_embeddings,