Skip to content

[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE3 #17504

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
May 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 17 additions & 8 deletions vllm/model_executor/models/llama_eagle3.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig, VllmConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
Expand Down Expand Up @@ -76,17 +77,19 @@ def forward(
return hidden_states, residual


@support_torch_compile
class LlamaModel(nn.Module):

def __init__(
self,
*,
model_config: ModelConfig,
vllm_config: VllmConfig,
start_layer_id: int = 0,
prefix: str = "",
) -> None:
super().__init__()
self.config = model_config.hf_config
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
Expand Down Expand Up @@ -119,8 +122,7 @@ def forward(
hidden_states: torch.Tensor,
) -> tuple[torch.Tensor, torch.Tensor]:
input_embeds = self.embed_tokens(input_ids)
if (hidden_states.shape[-1] != input_embeds.shape[-1]):
hidden_states = self.fc(hidden_states)
assert hidden_states.shape[-1] == input_embeds.shape[-1]

residual = None
hidden_states, residual = self.layers[0](
Expand Down Expand Up @@ -169,9 +171,9 @@ class Eagle3LlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.model = LlamaModel(vllm_config=vllm_config,
start_layer_id=start_layer_id,
prefix="model")

Expand Down Expand Up @@ -214,6 +216,13 @@ def compute_logits(
logits_new[:, targets] = logits
return logits_new

def combine_hidden_states(
self,
hidden_states: torch.Tensor,
) -> torch.Tensor:
# combine multiple auxiliary hidden states returned by eagle3
return self.model.fc(hidden_states)

def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
loader = AutoWeightsLoader(
self,
Expand Down
42 changes: 19 additions & 23 deletions vllm/v1/spec_decode/eagle.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from vllm.model_executor.model_loader.loader import get_model_loader
from vllm.model_executor.model_loader.utils import set_default_torch_dtype
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llama_eagle3 import Eagle3LlamaForCausalLM
from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
from vllm.v1.sample.metadata import SamplingMetadata

Expand Down Expand Up @@ -39,11 +40,9 @@ def __init__(

self.hidden_size = vllm_config.model_config.get_hidden_size()

# TODO: make eagle3 compatible with cudagraph
self.use_cuda_graph = self.method != 'eagle3' and \
(self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)
self.use_cuda_graph = (self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)

self.cudagraph_batch_sizes = list(
reversed(
Expand Down Expand Up @@ -90,6 +89,12 @@ def propose(
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1

if self.method == "eagle3":
assert isinstance(self.model, Eagle3LlamaForCausalLM)
target_hidden_states = self.model.combine_hidden_states(
target_hidden_states)
assert target_hidden_states.shape[-1] == self.hidden_size

# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
Expand Down Expand Up @@ -126,20 +131,15 @@ def propose(
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions

if self.method == 'eagle':
self.hidden_states[:num_tokens] = target_hidden_states
hidden_states = self.hidden_states
else:
# TODO: make eagle3 compatible with cuda graph
hidden_states = target_hidden_states
self.hidden_states[:num_tokens] = target_hidden_states

with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=hidden_states[:num_input_tokens],
hidden_states=self.hidden_states[:num_input_tokens],
)
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
Expand Down Expand Up @@ -209,10 +209,7 @@ def propose(
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions

if self.method == 'eagle':
# TODO: make eagle3 compatible with cudagraph.
self.hidden_states[:batch_size] = hidden_states
hidden_states = self.hidden_states
self.hidden_states[:batch_size] = hidden_states

# Run the model.
with set_forward_context(attn_metadata,
Expand All @@ -221,7 +218,7 @@ def propose(
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=hidden_states[:input_batch_size],
hidden_states=self.hidden_states[:input_batch_size],
)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
Expand Down Expand Up @@ -314,12 +311,11 @@ def dummy_run(
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.method == 'eagle':
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
Expand Down