
[V1][Spec Decode] Support multi-layer eagle draft model #18030

Merged · 4 commits · May 24, 2025
3 changes: 3 additions & 0 deletions tests/v1/spec_decode/test_eagle.py
@@ -246,6 +246,9 @@ def create_deterministic_logits(token_ids):
# Assign the mock to the proposer
proposer.model = model_mock

# Assign draft attn_layer_names since load_model is not invoked
proposer.attn_layer_names = ["layer.0"]

# Create input tensors
cu_num_tokens = torch.tensor([0, seq_len_1, total_tokens],
dtype=torch.int32,
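The test assigns `attn_layer_names` by hand because `load_model` is never invoked in this mock setup. In the proposer itself (see the `eagle.py` hunk below), the list is derived by subtracting the target model's attention layers from all attention layers registered in the vLLM config. A minimal sketch of that derivation, with invented layer names (nothing below is taken from the PR):

```python
# Illustrative only: layer names are made up, not taken from vLLM.
target_attn_layer_names = {
    "model.layers.0.self_attn.attn",
    "model.layers.1.self_attn.attn",
}
# All attention layers registered in the config: the target's layers plus
# the layers contributed by the eagle draft model.
all_attn_layer_names = target_attn_layer_names | {
    "draft_model.layers.0.self_attn.attn",
    "draft_model.layers.1.self_attn.attn",
}
# The draft layers are whatever remains after removing the target's layers;
# with a multi-layer draft model this list has more than one entry.
attn_layer_names = sorted(all_attn_layer_names - target_attn_layer_names)
assert attn_layer_names == [
    "draft_model.layers.0.self_attn.attn",
    "draft_model.layers.1.self_attn.attn",
]
```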
33 changes: 29 additions & 4 deletions vllm/v1/spec_decode/eagle.py
@@ -12,6 +12,7 @@
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import (CommonAttentionMetadata,
FlashAttentionMetadata)
from vllm.v1.kv_cache_interface import KVCacheConfig
from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.spec_decode.utils import prepare_eagle_input_kernel

@@ -150,6 +151,11 @@ def propose(
else:
raise ValueError(f"Unsupported method: {self.method}")

# At this moment, we assume all eagle layers belong to the same KV
# cache group, thus using the same attention metadata.
per_layer_attn_metadata = {}
for layer_name in self.attn_layer_names:
per_layer_attn_metadata[layer_name] = attn_metadata
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
@@ -159,7 +165,7 @@
self.positions[:num_tokens] = target_positions
self.hidden_states[:num_tokens] = target_hidden_states

-with set_forward_context(attn_metadata,
+with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
ret_hidden_states = self.model(
@@ -245,7 +251,7 @@ def propose(
self.hidden_states[:batch_size] = hidden_states

# Run the model.
-with set_forward_context(attn_metadata,
+with set_forward_context(per_layer_attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
@@ -318,8 +324,8 @@ def load_model(self, target_model: nn.Module) -> None:
draft_attn_layer_names = (
get_layers_from_vllm_config(self.vllm_config, Attention).keys() -
target_attn_layer_names)
-assert len(draft_attn_layer_names) == 1
-self.attn_layer_name = next(iter(draft_attn_layer_names))

+self.attn_layer_names = list(draft_attn_layer_names)

# share embed_tokens with the target model if needed
if get_pp_group().world_size == 1:
@@ -355,6 +361,25 @@ def dummy_run(
self.hidden_states[:num_tokens],
)

def validate_same_kv_cache_group(self,
kv_cache_config: KVCacheConfig) -> None:
"""
Validate that all eagle layers belong to the same KVCacheGroup.
Need this assumption to ensure all eagle layers can use the
same AttentionMetadata.
May extend to multiple AttentionMetadata in the future.
"""
kv_cache_groups: dict[str, int] = {}
for id, kv_cache_group in enumerate(kv_cache_config.kv_cache_groups):
for layer_name in kv_cache_group.layer_names:
kv_cache_groups[layer_name] = id
assert len(
set([
kv_cache_groups[layer_name]
for layer_name in self.attn_layer_names
])
) == 1, "All eagle layers should belong to the same kv cache group"


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
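To see the new `validate_same_kv_cache_group` check in isolation, here is a self-contained sketch with stand-in classes; the real `KVCacheConfig` lives in `vllm/v1/kv_cache_interface.py`, and the group layout and layer names below are purely illustrative:

```python
from dataclasses import dataclass


@dataclass
class FakeKVCacheGroup:
    """Stand-in for a KV cache group: only the layer names it owns."""
    layer_names: list[str]


@dataclass
class FakeKVCacheConfig:
    """Stand-in for KVCacheConfig with just the field the check reads."""
    kv_cache_groups: list[FakeKVCacheGroup]


def validate_same_kv_cache_group(attn_layer_names: list[str],
                                 kv_cache_config: FakeKVCacheConfig) -> None:
    # Map each layer name to the index of the group that owns it.
    layer_to_group = {
        layer_name: group_id
        for group_id, group in enumerate(kv_cache_config.kv_cache_groups)
        for layer_name in group.layer_names
    }
    # All eagle layers must fall into a single group so they can share one
    # AttentionMetadata object.
    assert len({layer_to_group[name] for name in attn_layer_names}) == 1, (
        "All eagle layers should belong to the same kv cache group")


# Passes: both draft layers live in the second group.
config = FakeKVCacheConfig(kv_cache_groups=[
    FakeKVCacheGroup(layer_names=["target.0", "target.1"]),
    FakeKVCacheGroup(layer_names=["draft.0", "draft.1"]),
])
validate_same_kv_cache_group(["draft.0", "draft.1"], config)
```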
18 changes: 13 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -1360,11 +1360,13 @@ def execute_model(
scheduler_output.num_scheduled_tokens[req_id])
next_token_id = req_state.get_token_id(seq_len)
next_token_ids.append(next_token_id)
-next_token_ids = async_tensor_h2d(next_token_ids,
-                                  dtype=torch.int32,
-                                  target_device=self.device,
-                                  pin_memory=True)
-eagle_attn_metadata = attn_metadata[self.drafter.attn_layer_name]
+next_token_ids = torch.tensor(next_token_ids,
+                              dtype=torch.int32,
+                              device=self.device)
+# At this moment, we assume all eagle layers belong to the same KV
+# cache group, thus using the same attention metadata.
+eagle_attn_metadata = attn_metadata[
+    self.drafter.attn_layer_names[0]]

# NOTE: deepseek_mtp uses MLA which does not have `block_table`
if hasattr(eagle_attn_metadata, "block_table"):
@@ -2018,6 +2020,12 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
# KV cache specs.
raise ValueError("Unknown KV cache spec type.")

if self.speculative_config and self.speculative_config.use_eagle():
assert isinstance(self.drafter, EagleProposer)
# validate all draft model layers belong to the same kv cache
# group
self.drafter.validate_same_kv_cache_group(kv_cache_config)

bind_kv_cache(
kv_caches,
self.vllm_config.compilation_config.static_forward_context,
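Taken together, the validation above guarantees that one attention-metadata object is valid for every draft layer: the proposer fans the same object out to all of its layer names before `set_forward_context`, and the model runner only needs one representative entry, so it indexes `attn_layer_names[0]`. A minimal sketch of that sharing contract (the placeholder metadata class and layer names are illustrative):

```python
# Illustrative sketch of the sharing contract; not actual vLLM objects.
class FakeAttentionMetadata:
    """Placeholder for FlashAttentionMetadata."""


attn_layer_names = ["draft.layers.0.attn", "draft.layers.1.attn"]
shared_metadata = FakeAttentionMetadata()

# What the proposer builds before calling set_forward_context(...):
# every draft layer name maps to the same metadata object.
per_layer_attn_metadata = {name: shared_metadata for name in attn_layer_names}

# What the model runner relies on: the entries are interchangeable, so
# looking up the first layer name is enough.
eagle_attn_metadata = per_layer_attn_metadata[attn_layer_names[0]]
assert all(m is eagle_attn_metadata for m in per_layer_attn_metadata.values())
```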