
[Perf] Reduce MLA CPU overheads in V1 #14384


Merged
9 changes: 7 additions & 2 deletions vllm/model_executor/layers/rotary_embedding.py
@@ -161,8 +161,13 @@ def forward_cuda(
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                           dtype=query.dtype)
Comment on lines -164 to -165
Collaborator


Do we actually know what this line of code is for?

Collaborator Author


No :/ it doesn't appear to be called, but I just didn't want to create a behavior change in case there is a model that needs it. I can pull it out completely and we can see if we get reports of breakages.

# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
        self.cos_sin_cache.dtype != query.dtype:
    self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                               dtype=query.dtype)

# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
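For context on this first change, here is a minimal standalone sketch (not part of the PR; the `Rope` class and tensor sizes are made up) of why reassigning `self.cos_sin_cache` on every call is expensive: `nn.Module.__setattr__` walks the module's parameter/buffer/submodule dicts on every assignment, while the guarded version normally only pays for a cheap device/dtype comparison.

```python
# Standalone sketch, assuming a toy module: compares unconditional reassignment
# of a tensor attribute (old behavior) against the guarded version (new behavior).
import time

import torch
import torch.nn as nn


class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        self.cos_sin_cache = torch.randn(4096, 128)

    def forward_unconditional(self, query: torch.Tensor) -> None:
        # Old behavior: always reassign, so nn.Module.__setattr__ runs every call.
        self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                   dtype=query.dtype)

    def forward_guarded(self, query: torch.Tensor) -> None:
        # New behavior: only reassign when the device/dtype actually changed.
        if self.cos_sin_cache.device != query.device or \
                self.cos_sin_cache.dtype != query.dtype:
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)


if __name__ == "__main__":
    rope = Rope()
    q = torch.randn(8, 128)
    for name, fn in [("unconditional", rope.forward_unconditional),
                     ("guarded", rope.forward_guarded)]:
        start = time.perf_counter()
        for _ in range(10_000):
            fn(q)
        print(f"{name}: {time.perf_counter() - start:.4f}s")
```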
15 changes: 11 additions & 4 deletions vllm/v1/attention/backends/mla/common.py
@@ -222,8 +222,8 @@
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import (
    DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down

try:
@@ -627,8 +627,15 @@ def __init__(
self.v_head_dim = v_head_dim

self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
                                DeepseekScalingRotaryEmbedding)

if current_platform.is_cuda():
    # Hack for V1 for now to avoid torch library overhead (since we are
    # already inside an attention custom op), pull out the forward
    # method from the rotary embedding and call it directly (and avoid
    # calling forward_native, when we can call forward_cuda)
    # TODO(lucas): we should probably find a cleaner way to do this
    self.rotary_emb = rotary_emb.forward_cuda

self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
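The second change can also be sketched in isolation (toy classes below, not vLLM's actual `RotaryEmbedding` or MLA implementation): when running on CUDA, store the bound `forward_cuda` method instead of the module itself, so the hot path calls the fast-path method directly and skips `nn.Module.__call__` plus whatever custom-op dispatch wraps `forward`.

```python
# Standalone sketch of the direct-method-binding pattern, with hypothetical
# ToyRotaryEmbedding / ToyMLAImpl classes standing in for the real ones.
import torch
import torch.nn as nn


class ToyRotaryEmbedding(nn.Module):
    def forward_native(self, positions, q, k):
        # Portable fallback path.
        return q, k

    def forward_cuda(self, positions, q, k):
        # Fast path; in vLLM this would call the fused CUDA rotary kernel.
        return q, k

    def forward(self, positions, q, k):
        # Generic entry point; in vLLM this goes through a torch custom op,
        # which adds per-call dispatch overhead.
        return self.forward_native(positions, q, k)


class ToyMLAImpl:
    def __init__(self, rotary_emb: ToyRotaryEmbedding, is_cuda: bool):
        # Mirror of the PR's trick: on CUDA, stash the bound forward_cuda
        # method so later calls bypass nn.Module.__call__ entirely.
        if is_cuda:
            self.rotary_emb = rotary_emb.forward_cuda
        else:
            self.rotary_emb = rotary_emb

    def apply_rope(self, positions, q, k):
        # Same call syntax either way: both a bound method and an nn.Module
        # instance are callables.
        return self.rotary_emb(positions, q, k)


if __name__ == "__main__":
    impl = ToyMLAImpl(ToyRotaryEmbedding(), is_cuda=torch.cuda.is_available())
    q = torch.randn(2, 4, 64)
    k = torch.randn(2, 4, 64)
    positions = torch.arange(4)
    q, k = impl.apply_rope(positions, q, k)
```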