[V1] Support Deepseek MTP #18435

Merged: 21 commits, May 23, 2025
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -2255,7 +2255,7 @@ def __post_init__(self):


SpeculativeMethod = Literal["ngram", "eagle", "medusa", "mlp_speculator",
"draft_model"]
"draft_model", "deepseek_mtp"]
SpeculativeAcceptanceMethod = Literal["rejection_sampler",
"typical_acceptance_sampler"]

@@ -2519,6 +2519,15 @@ def __post_init__(self):
elif (self.draft_model_config.hf_config.model_type ==
"mlp_speculator"):
self.method = "mlp_speculator"
elif (self.draft_model_config.hf_config.model_type ==
"deepseek_mtp"):
self.method = "deepseek_mtp"
if self.num_speculative_tokens > 1:
logger.warning(
"All Deepseek MTP models only have " \
"one layer. Might need some code changes " \
"to support multiple layers."
)
else:
self.method = "draft_model"

@@ -2739,7 +2748,7 @@ def num_lookahead_slots(self) -> int:
return self.num_speculative_tokens

def use_eagle(self) -> bool:
return self.method in ("eagle", "eagle3")
return self.method in ("eagle", "eagle3", "deepseek_mtp")

def __repr__(self) -> str:
method = self.method
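With this change, SpeculativeConfig resolves method="deepseek_mtp" whenever the draft model's hf_config reports model_type == "deepseek_mtp", and warns that the released MTP modules only have one layer. A minimal, hypothetical launch sketch, not part of this PR: the speculative_config keys mirror what arg_utils.py reads below, while the checkpoint name and parallelism settings are purely illustrative.

from vllm import LLM

# Sketch only: draw one speculative token per step from the checkpoint's
# built-in MTP head. "method" is spelled out here; the config.py change
# above can also infer it from the draft model_type.
llm = LLM(
    model="deepseek-ai/DeepSeek-V3",   # illustrative checkpoint
    tensor_parallel_size=8,            # illustrative; adjust to your hardware
    speculative_config={
        "method": "deepseek_mtp",
        "num_speculative_tokens": 1,   # released MTP modules have one layer
    },
)

print(llm.generate("The capital of France is")[0].outputs[0].text)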
2 changes: 1 addition & 1 deletion vllm/engine/arg_utils.py
@@ -1338,7 +1338,7 @@ def _is_v1_supported_oracle(self, model_config: ModelConfig) -> bool:
is_ngram_enabled = True
elif speculative_method == "medusa":
is_medusa_enabled = True
elif speculative_method in ("eagle", "eagle3"):
elif speculative_method in ("eagle", "eagle3", "deepseek_mtp"):
is_eagle_enabled = True
else:
speculative_model = self.speculative_config.get("model")
3 changes: 2 additions & 1 deletion vllm/model_executor/models/deepseek_mtp.py
@@ -19,6 +19,7 @@

from .deepseek_v2 import (DeepseekV2DecoderLayer,
get_spec_layer_idx_from_weight_name)
from .interfaces import SupportsPP
from .utils import maybe_prefix


@@ -145,7 +146,7 @@ def compute_logits(
return logits


class DeepSeekMTP(nn.Module):
class DeepSeekMTP(nn.Module, SupportsPP):

def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
super().__init__()
122 changes: 65 additions & 57 deletions vllm/v1/spec_decode/eagle.py
@@ -10,9 +10,10 @@
from vllm.logger import init_logger
from vllm.model_executor.model_loader import get_model
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.triton_utils import tl, triton
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

logger = init_logger(__name__)

@@ -25,12 +26,15 @@ def __init__(
self,
vllm_config: VllmConfig,
device: torch.device,
runner=None,
):
self.vllm_config = vllm_config
self.speculative_config = vllm_config.speculative_config
self.draft_model_config = self.speculative_config.draft_model_config
self.method = self.speculative_config.method

self.runner = runner

self.dtype = vllm_config.model_config.dtype
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size
@@ -106,24 +110,46 @@ def propose(
# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()

# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] - cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
if self.method in ["eagle", "eagle3"]:
# FIXME(woosuk): The below two ops cause synchronization. Optimize.
max_seq_len = seq_lens.max().item()
max_num_tokens = (cu_num_tokens[1:] -
cu_num_tokens[:-1]).max().item()
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_tokens,
max_query_len=max_num_tokens,
query_start_loc=cu_num_tokens,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=target_slot_mapping,
# TODO(woosuk): Support cascade attention.
use_cascade=False,
common_prefix_len=0,
cu_prefix_query_lens=None,
prefix_kv_lens=None,
suffix_kv_lens=None,
)
elif self.method == "deepseek_mtp":
query_lens = cu_num_tokens[1:] - cu_num_tokens[:-1]
max_query_len = query_lens.max().item()

common_attn_metadata = CommonAttentionMetadata(
query_start_loc=cu_num_tokens, seq_lens=seq_lens)

assert self.runner is not None

# FIXME: need to consider multiple kv_cache_groups
attn_metadata = self.runner.attn_metadata_builder.build(
num_reqs=batch_size,
num_actual_tokens=num_tokens,
max_query_len=max_query_len,
common_prefix_len=0,
common_attn_metadata=common_attn_metadata,
)
else:
raise ValueError(f"Unsupported method: {self.method}")

if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -136,11 +162,15 @@ def propose(
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
ret_hidden_states = self.model(
self.input_ids[:num_input_tokens],
self.positions[:num_input_tokens],
self.hidden_states[:num_input_tokens],
)
if self.method == "deepseek_mtp":
last_hidden_states = ret_hidden_states
else:
last_hidden_states, hidden_states = ret_hidden_states
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)
@@ -150,6 +180,10 @@ def propose(
# [batch_size, 1]
return draft_token_ids.view(-1, 1)

# TODO: Currently, MTP module released by deepseek only has
# one layer. Adapt this code to support multiple layers once
# there's a multi-layer MTP module.

# Generate the remaining draft tokens.
draft_token_ids_list = [draft_token_ids]

@@ -215,9 +249,9 @@ def propose(
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
self.input_ids[:input_batch_size],
self.positions[:input_batch_size],
self.hidden_states[:input_batch_size],
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
@@ -268,7 +302,7 @@ def prepare_inputs(

batch_size = num_rejected_tokens.shape[0]
BLOCK_SIZE = 1024
prepare_input_kernel[(batch_size, )](
prepare_eagle_input_kernel[(batch_size, )](
token_indices,
cu_target_query_lens,
cu_num_tokens,
@@ -320,9 +354,9 @@ def dummy_run(
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
self.input_ids[:num_tokens],
self.positions[:num_tokens],
self.hidden_states[:num_tokens],
)


@@ -367,29 +401,3 @@ def compute_probs_and_sample_next_token(
next_token_ids,
)
return next_token_ids, probs


@triton.jit
def prepare_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos

index_start = tl.load(cu_query_lens_ptr + pid)

num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
index_start + offset,
mask=offset < num_tokens,
)
27 changes: 27 additions & 0 deletions vllm/v1/spec_decode/utils.py
@@ -1,4 +1,5 @@
# SPDX-License-Identifier: Apache-2.0
from vllm.triton_utils import tl, triton
from vllm.v1.worker.gpu_input_batch import InputBatch


@@ -16,3 +17,29 @@ def is_spec_decode_supported(req_id: str, input_batch: InputBatch) -> bool:
return False

return True


@triton.jit
def prepare_eagle_input_kernel(
out_ptr,
cu_query_lens_ptr,
cu_num_tokens_ptr,
BLOCK_SIZE: tl.constexpr,
):
pid = tl.program_id(0)

# [start_pos, end_pos)
start_pos = tl.load(cu_num_tokens_ptr + pid)
end_pos = tl.load(cu_num_tokens_ptr + pid + 1)
num_tokens = end_pos - start_pos

index_start = tl.load(cu_query_lens_ptr + pid)

num_blocks = tl.cdiv(num_tokens, BLOCK_SIZE)
for i in tl.range(num_blocks):
offset = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
tl.store(
out_ptr + start_pos + offset,
index_start + offset,
mask=offset < num_tokens,
)
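The kernel above, moved here from eagle.py, builds the gather indices that map each retained (non-rejected) draft-input token back to its position in the flattened target batch. Below is a pure-PyTorch reference of the same computation with a small worked example; the function name and loop form are illustrative, not part of the PR.

import torch

def prepare_eagle_input_reference(
    cu_target_query_lens: torch.Tensor,  # [batch_size + 1], includes rejected tokens
    cu_num_tokens: torch.Tensor,         # [batch_size + 1], rejected tokens removed
) -> torch.Tensor:
    # For request i, write the indices
    #   cu_target_query_lens[i], ..., cu_target_query_lens[i] + n_i - 1
    # (with n_i = cu_num_tokens[i+1] - cu_num_tokens[i]) into
    # out[cu_num_tokens[i]:cu_num_tokens[i+1]], which is what the Triton
    # kernel does with one program id per request.
    out = torch.empty(int(cu_num_tokens[-1]), dtype=torch.int32)
    for i in range(cu_num_tokens.numel() - 1):
        start, end = int(cu_num_tokens[i]), int(cu_num_tokens[i + 1])
        index_start = int(cu_target_query_lens[i])
        out[start:end] = torch.arange(
            index_start, index_start + (end - start), dtype=torch.int32)
    return out

# Two requests with target query lengths [3, 4]; 1 and 2 tokens rejected.
print(prepare_eagle_input_reference(torch.tensor([0, 3, 7]),
                                    torch.tensor([0, 2, 4])))
# tensor([0, 1, 3, 4], dtype=torch.int32)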
19 changes: 14 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -188,12 +188,16 @@ def __init__(
self.use_aux_hidden_state_outputs = False
if self.speculative_config:
self.use_spec_decode = True

# NOTE(Jiayi): currently we put the entire draft model on
# the last PP rank. This is not ideal if there are many
# layers in the draft model.
if get_pp_group().is_last_rank:
if self.speculative_config.method == "ngram":
self.drafter = NgramProposer(self.vllm_config)
elif self.speculative_config.use_eagle():
self.drafter = EagleProposer(self.vllm_config,
self.device) # type: ignore
self.drafter = EagleProposer(self.vllm_config, self.device,
self) # type: ignore
if self.speculative_config.method == "eagle3":
self.use_aux_hidden_state_outputs = True
elif self.speculative_config.method == "medusa":
@@ -1362,6 +1366,12 @@ def execute_model(
device=self.device)
eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]

# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
block_table = eagle_attn_metadata.block_table
else:
block_table = None

if spec_decode_metadata is None:
# input_ids can be None for multimodal models.
target_token_ids = self.input_ids[:num_scheduled_tokens]
@@ -1407,7 +1417,7 @@ def execute_model(
target_slot_mapping=target_slot_mapping,
next_token_ids=next_token_ids,
cu_num_tokens=cu_num_tokens,
block_table=eagle_attn_metadata.block_table,
block_table=block_table,
sampling_metadata=sampling_metadata,
)
spec_token_ids = draft_token_ids.tolist()
@@ -1718,8 +1728,7 @@ def _dummy_run(
else:
hidden_states = outputs

if self.use_spec_decode and \
self.speculative_config.method in ('eagle', 'eagle3'):
if self.use_spec_decode and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
self.drafter.dummy_run(num_tokens)

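The block_table fallback added to execute_model above is needed because deepseek_mtp runs on MLA attention, whose metadata carries no block_table field. Below is a self-contained sketch of the same fallback written with getattr; the stub classes are illustrative stand-ins, not vLLM types.

from dataclasses import dataclass
from typing import Optional

import torch

@dataclass
class _FlashMetaStub:
    # Stand-in for FlashAttention metadata, which has a block table.
    block_table: torch.Tensor

@dataclass
class _MLAMetaStub:
    # Stand-in for MLA metadata (deepseek_mtp), which has none.
    seq_lens: torch.Tensor

def get_block_table(meta) -> Optional[torch.Tensor]:
    # Equivalent to the hasattr check in the execute_model hunk above.
    return getattr(meta, "block_table", None)

print(get_block_table(_FlashMetaStub(torch.zeros(2, 4, dtype=torch.int32))).shape)
print(get_block_table(_MLAMetaStub(torch.tensor([5, 7]))))  # -> None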