Revert "[Perf] Reduce MLA CPU overheads in V1 (#14384)" #14471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged 1 commit on Mar 8, 2025
9 changes: 2 additions & 7 deletions vllm/model_executor/layers/rotary_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,13 +161,8 @@ def forward_cuda(
) -> Tuple[torch.Tensor, torch.Tensor]:
from vllm import _custom_ops as ops

# __setattr__ in nn.Module (called by `self.cos_sin_cache = ...`)
# is expensive, so avoid calling it if possible
if self.cos_sin_cache.device != query.device or \
self.cos_sin_cache.dtype != query.dtype:
self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)

self.cos_sin_cache = self.cos_sin_cache.to(query.device,
dtype=query.dtype)
# ops.rotary_embedding()/batched_rotary_embedding()
# are in-place operations that update the query and key tensors.
if offsets is not None:
Expand Down
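For reference, here is a minimal standalone sketch (not vLLM code; the class and tensor shapes are illustrative) of the guard that the hunk above removes. The reverted optimization skipped the re-assignment when the cache was already on the right device and dtype, because assigning to an attribute of an `nn.Module` goes through `nn.Module.__setattr__`, which scans the module's parameter/buffer/submodule dicts and is noticeably more expensive than a plain attribute write.

```python
import torch
import torch.nn as nn


class RotaryCacheSketch(nn.Module):
    """Illustrative stand-in for a rotary-embedding module's cache handling."""

    def __init__(self) -> None:
        super().__init__()
        # Plain tensor attribute standing in for cos_sin_cache.
        self.cos_sin_cache = torch.randn(4096, 64)

    def forward(self, query: torch.Tensor) -> torch.Tensor:
        # Pattern removed by the revert: only re-assign (and re-trigger
        # nn.Module.__setattr__) when the cache actually needs to move.
        if (self.cos_sin_cache.device != query.device
                or self.cos_sin_cache.dtype != query.dtype):
            self.cos_sin_cache = self.cos_sin_cache.to(query.device,
                                                       dtype=query.dtype)
        # Real code would now apply the rotary embedding to query/key.
        return query
```

On the restored (unconditional) path, `.to()` typically returns the same tensor when device and dtype already match, but the attribute assignment still pays the `__setattr__` cost on every call, which is the overhead #14384 was trying to avoid.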
15 changes: 4 additions & 11 deletions vllm/v1/attention/backends/mla/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,8 +222,8 @@
Fp8LinearGenericOp, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.utils import cdiv, round_down

try:
Expand Down Expand Up @@ -627,15 +627,8 @@ def __init__(
self.v_head_dim = v_head_dim

self.rotary_emb = rotary_emb

if current_platform.is_cuda():
# Hack for V1 for now to avoid torch library overhead (since we are
# already inside an attention custom op), pull out the forward
# method from the rotary embedding and call it directly (and avoid
# calling forward_native, when we can call forward_cuda)
# TODO(lucas): we should probably find a cleaner way to do this
self.rotary_emb = rotary_emb.forward_cuda

self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
Expand Down
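For context, a minimal sketch (not vLLM code; class and parameter names are illustrative) of the pattern this hunk reverts: #14384 bound the module's `forward_cuda` method directly at init so per-step calls skipped `nn.Module.__call__`, while the restored code keeps the module itself and records the `use_yarn_rope` flag via the `isinstance` check shown above.

```python
import torch
import torch.nn as nn


class FakeRotaryEmbedding(nn.Module):
    """Illustrative stand-in; vLLM's real RotaryEmbedding has more machinery."""

    def forward_native(self, q: torch.Tensor) -> torch.Tensor:
        return q

    def forward_cuda(self, q: torch.Tensor) -> torch.Tensor:
        return q

    def forward(self, q: torch.Tensor) -> torch.Tensor:
        return self.forward_cuda(q) if q.is_cuda else self.forward_native(q)


class MLAImplSketch:
    """Hypothetical attention-impl wrapper showing both wiring styles."""

    def __init__(self, rotary_emb: FakeRotaryEmbedding,
                 bind_forward_directly: bool) -> None:
        if bind_forward_directly:
            # Reverted hack: store the bound method so later calls skip
            # nn.Module.__call__ (hooks, dispatch) entirely.
            self.rotary_emb = rotary_emb.forward_cuda
        else:
            # Restored behavior: keep the module and call it normally.
            self.rotary_emb = rotary_emb

    def apply_rope(self, q: torch.Tensor) -> torch.Tensor:
        # Works in both cases: a bound method and an nn.Module are
        # both callables taking the query tensor.
        return self.rotary_emb(q)
```

The real revert additionally restores the `use_yarn_rope = isinstance(rotary_emb, DeepseekScalingRotaryEmbedding)` flag exactly as shown in the hunk; the sketch omits it to keep the dispatch-bypass comparison isolated.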