[V1][Spec Decode] Apply torch.compile & cudagraph to EAGLE #17211

Merged: 8 commits, Apr 29, 2025
14 changes: 12 additions & 2 deletions examples/offline_inference/eagle.py
@@ -36,6 +36,10 @@ def parse_args():
help="downloaded from the eagle repo " \
"https://github.com/SafeAILab/EAGLE/blob/main/eagle/data/"
)
parser.add_argument("--method",
type=str,
default='eagle',
choices=['eagle', 'eagle3'])
parser.add_argument("--max_num_seqs", type=int, default=8)
parser.add_argument("--num_prompts", type=int, default=80)
parser.add_argument("--num_spec_tokens", type=int, default=2)
@@ -53,7 +57,13 @@ def main():
args = parse_args()

model_dir = "meta-llama/Llama-3.1-8B-Instruct"
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"

if args.method == 'eagle':
eagle_dir = "yuhuili/EAGLE-LLaMA3.1-Instruct-8B"
elif args.method == 'eagle3':
eagle_dir = "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B"
else:
raise ValueError(f"unknown method: {args.method}")

max_model_len = 2048

@@ -81,7 +91,7 @@ def main():
max_num_seqs=args.max_num_seqs,
gpu_memory_utilization=0.8,
speculative_config={
"method": "eagle3" if "eagle3" in eagle_dir.lower() else "eagle",
"method": args.method,
"model": eagle_dir,
"num_speculative_tokens": args.num_spec_tokens,
"draft_tensor_parallel_size": args.draft_tp,
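
For context, the updated example wires the new --method flag straight into the speculative config, so the method string and the draft checkpoint always agree. A minimal sketch of what the script ends up constructing (dataset handling, gpu_memory_utilization, and the tensor-parallel settings from the full example are omitted):

```python
from vllm import LLM, SamplingParams

# Mirrors the updated example: the method string selects the draft checkpoint
# and is passed through to speculative_config unchanged.
method = "eagle"  # or "eagle3"
eagle_dir = {
    "eagle": "yuhuili/EAGLE-LLaMA3.1-Instruct-8B",
    "eagle3": "yuhuili/EAGLE3-LLaMA3.1-Instruct-8B",
}[method]

llm = LLM(
    model="meta-llama/Llama-3.1-8B-Instruct",
    max_model_len=2048,
    speculative_config={
        "method": method,
        "model": eagle_dir,
        "num_speculative_tokens": 2,
    },
)

outputs = llm.generate(["The capital of France is"],
                       SamplingParams(temperature=0.0, max_tokens=32))
print(outputs[0].outputs[0].text)
```
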
15 changes: 12 additions & 3 deletions vllm/compilation/backends.py
@@ -347,8 +347,12 @@ def configure_post_pass(self):
PASS_KEY = "post_grad_custom_post_pass"
if PASS_KEY in inductor_config:
# Config should automatically wrap all inductor passes
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
if isinstance(inductor_config[PASS_KEY], PostGradPassManager):
assert (inductor_config[PASS_KEY].uuid() ==
self.post_grad_pass_manager.uuid())
else:
assert isinstance(inductor_config[PASS_KEY], InductorPass)
self.post_grad_pass_manager.add(inductor_config[PASS_KEY])
inductor_config[PASS_KEY] = self.post_grad_pass_manager

def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
@@ -408,8 +412,13 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable:
)
self.compilation_config.cache_dir = cache_dir

cache_dir = self.compilation_config.cache_dir
if compilation_counter.num_graphs_seen > 0:
cache_dir = self.compilation_config.cache_dir + \
f'-{compilation_counter.num_graphs_seen}'
Member: I think we should be able to get the component's prefix to use as the cache directory; it could be more meaningful.

Collaborator: I'll make that change; @luyuzhe111 and I were discussing something similar above.

else:
cache_dir = self.compilation_config.cache_dir
os.makedirs(cache_dir, exist_ok=True)
self.compilation_config.cache_dir = cache_dir
Comment on lines +415 to +421
Collaborator (zou3519), May 1, 2025: @luyuzhe111 This is suspicious. Why do you need a different cache directory for each graph? Also, this looks like it modifies everything, even models that don't use EAGLE.

If there isn't a good reason I would prefer going back to the "single cache directory" that we had previously.

Contributor (author, luyuzhe111): @zou3519 Thanks for reviewing! If there isn't a separate cache directory, the compiled code for the draft model (EAGLE) will not be saved at all. For models without EAGLE, my understanding is that the backend is invoked only once, so this should not impact other models.

Collaborator (zou3519): @luyuzhe111 Thanks for the response and clarifying that. Woosuk also filled me in on some more details offline. I understand why we need a separate cache directory.

Which of the "original model" and the "eagle head" gets compiled first? (I'm trying to figure out whether the first cache dir is for the original model or for the eagle head.)

Contributor (author, luyuzhe111): @zou3519 The original model should be compiled first! Also, if you want to double-check, the transformed code of EAGLE in the cache directory has a slightly different signature, with hidden_states as an additional arg. If there is a more elegant solution, that would be great; I think my approach is a bit hacky, admittedly.

Collaborator (zou3519): Thanks for the discussion! I added some comments and an assertion in #17662; please take a look.

I think in the future we'll want a better way to handle multiple compiled regions in a vLLM model, but that will take some redesigning.

Collaborator (zou3519): @luyuzhe111 The asserts in #17662 triggered, which means that this PR does affect non-EAGLE models.

Contributor (author, luyuzhe111): @zou3519 Thanks for the catch! I guess a simple fix would be to create a separate cache directory only for EAGLE, for example by looking at the vLLM speculative config?

Collaborator (zou3519): Yeah, that would work.

rank = vllm_config.parallel_config.rank
dp_rank = vllm_config.parallel_config.data_parallel_rank
local_cache_dir = os.path.join(cache_dir, f"rank_{rank}_{dp_rank}")
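
To summarize the discussion above: with this change, the first graph compiled in a process (the target model) keeps the configured compilation cache directory, while each later graph (here, the EAGLE draft head) gets a numeric suffix so its Inductor artifacts are not overwritten. A standalone sketch of that scheme; the helper name resolve_compile_cache_dir is illustrative and not part of vLLM:

```python
import os


def resolve_compile_cache_dir(base_dir: str, num_graphs_seen: int) -> str:
    """Give every compiled graph after the first its own cache directory.

    Graph 0 (the target model) keeps base_dir; graph 1 (e.g. the EAGLE head)
    gets base_dir + "-1", and so on.
    """
    cache_dir = base_dir if num_graphs_seen == 0 else f"{base_dir}-{num_graphs_seen}"
    os.makedirs(cache_dir, exist_ok=True)
    return cache_dir


# Example: the second compiled graph in the process lands in ".../torch_compile_cache-1".
print(resolve_compile_cache_dir("/tmp/torch_compile_cache", 1))
```

As the thread notes, keying on a process-wide graph counter is not EAGLE-specific (the assertion added in #17662 fired for non-EAGLE models too), so deriving the directory from the component's prefix or from the speculative config would scope the change more tightly.
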
25 changes: 14 additions & 11 deletions vllm/model_executor/models/llama_eagle.py
@@ -6,7 +6,8 @@
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -37,17 +38,19 @@ def __init__(
self.input_layernorm = nn.Identity()


@support_torch_compile
class LlamaModel(nn.Module):

def __init__(
self,
*,
model_config: ModelConfig,
start_layer_id: int = 0,
vllm_config: VllmConfig,
prefix: str = "",
start_layer_id: int = 0,
) -> None:
super().__init__()
self.config = model_config.hf_config
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.vocab_size = self.config.vocab_size
self.embed_tokens = VocabParallelEmbedding(
self.config.vocab_size,
@@ -75,8 +78,7 @@ def forward(
hidden_states = self.fc(
torch.cat((input_embeds, hidden_states), dim=-1))
residual = None
for i in range(len(self.layers)):
layer = self.layers[i]
for layer in self.layers:
hidden_states, residual = layer(
positions,
hidden_states,
@@ -117,12 +119,13 @@ def load_weights(self, weights: Iterable[Tuple[str,

class EagleLlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
prefix="model")
self.config = vllm_config. \
speculative_config.draft_model_config.hf_config
self.model = LlamaModel(vllm_config=vllm_config,
prefix="model",
start_layer_id=start_layer_id)

logit_scale = getattr(self.config, "logit_scale", 1.0)
self.logits_processor = LogitsProcessor(self.config.vocab_size,
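
The constructor changes above exist to satisfy support_torch_compile: as used in vLLM models, the decorator compiles the wrapped module's forward pass under the compilation config and relies on the module being constructed with keyword vllm_config and prefix arguments, which is why LlamaModel no longer receives a bare ModelConfig. A rough sketch of the pattern; TinyDraftBlock and its body are made up for illustration and only mirror the EAGLE fc step:

```python
import torch
import torch.nn as nn

from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig


@support_torch_compile
class TinyDraftBlock(nn.Module):
    """Illustrative only; the real draft model is LlamaModel in llama_eagle.py."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None:
        super().__init__()
        hidden_size = vllm_config.model_config.get_hidden_size()
        # Mirrors the EAGLE "fc" projection that fuses the token embedding
        # with the target model's hidden state.
        self.fc = nn.Linear(2 * hidden_size, hidden_size, bias=False)

    def forward(self, input_embeds: torch.Tensor,
                hidden_states: torch.Tensor) -> torch.Tensor:
        return self.fc(torch.cat((input_embeds, hidden_states), dim=-1))
```
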
5 changes: 3 additions & 2 deletions vllm/model_executor/models/llama_eagle3.py
@@ -6,7 +6,7 @@
import torch.nn as nn
from transformers import LlamaConfig

from vllm.config import ModelConfig
from vllm.config import ModelConfig, VllmConfig
from vllm.logger import init_logger
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import QKVParallelLinear
@@ -167,8 +167,9 @@ def load_weights(self, weights: Iterable[Tuple[str,

class Eagle3LlamaForCausalLM(LlamaForCausalLM):

def __init__(self, *, model_config: ModelConfig, start_layer_id: int = 0):
def __init__(self, *, vllm_config: VllmConfig, start_layer_id: int = 0):
nn.Module.__init__(self)
model_config = vllm_config.speculative_config.draft_model_config
self.config = model_config.hf_config
self.model = LlamaModel(model_config=model_config,
start_layer_id=start_layer_id,
122 changes: 100 additions & 22 deletions vllm/v1/spec_decode/eagle.py
@@ -4,7 +4,7 @@
import triton
import triton.language as tl

from vllm.config import VllmConfig, set_current_vllm_config
from vllm.config import CompilationLevel, VllmConfig, set_current_vllm_config
from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor.model_loader.loader import get_model_loader
@@ -26,10 +26,41 @@ def __init__(
device: torch.device,
):
self.vllm_config = vllm_config
self.method = self.vllm_config.speculative_config.method
self.num_speculative_tokens = (
vllm_config.speculative_config.num_speculative_tokens)
self.max_model_len = vllm_config.model_config.max_model_len
self.block_size = vllm_config.cache_config.block_size

self.dtype = vllm_config.model_config.dtype

self.max_num_tokens = vllm_config.scheduler_config \
.max_num_batched_tokens

self.hidden_size = vllm_config.model_config.get_hidden_size()

# TODO: make eagle3 compatible with cudagraph
self.use_cuda_graph = self.method != 'eagle3' and \
(self.vllm_config.compilation_config.level
== CompilationLevel.PIECEWISE and
not self.vllm_config.model_config.enforce_eager)

self.cudagraph_batch_sizes = list(
reversed(
self.vllm_config.compilation_config.cudagraph_capture_sizes))

# persistent buffers for cuda graph
self.input_ids = torch.zeros(self.max_num_tokens,
dtype=torch.int32,
device=device)
self.positions = torch.zeros(self.max_num_tokens,
dtype=torch.int64,
device=device)

self.hidden_states = torch.zeros(
(self.max_num_tokens, self.hidden_size),
dtype=self.dtype,
device=device)
# We need +1 here because the arange is used to set query_start_loc,
# which has one more element than batch_size.
self.arange = torch.arange(vllm_config.scheduler_config.max_num_seqs +
@@ -59,13 +90,12 @@ def propose(
batch_size = next_token_ids.shape[0]
last_token_indices = cu_num_tokens[1:] - 1

input_ids = torch.empty_like(target_token_ids)
# Shift the input ids by one token.
# E.g., [a1, b1, b2, c1, c2, c3] -> [b1, b2, c1, c2, c3, c3]
input_ids[:-1] = target_token_ids[1:]
self.input_ids[:num_tokens - 1] = target_token_ids[1:]
# Replace the last token with the next token.
# E.g., [b1, b2, c1, c2, c3, c3] -> [a2, b2, b3, c2, c3, c4]
input_ids[last_token_indices] = next_token_ids
self.input_ids[last_token_indices] = next_token_ids

# FA requires seq_len to have dtype int32.
seq_lens = (target_positions[last_token_indices] + 1).int()
@@ -88,14 +118,30 @@ def propose(
prefix_kv_lens=None,
suffix_kv_lens=None,
)
if self.use_cuda_graph and \
num_tokens <= self.cudagraph_batch_sizes[-1]:
num_input_tokens = self.vllm_config.pad_for_cudagraph(num_tokens)
else:
num_input_tokens = num_tokens
# copy inputs to buffer for cudagraph
self.positions[:num_tokens] = target_positions

with set_forward_context(attn_metadata, self.vllm_config):
hidden_states_logits, hidden_states_fwd = self.model(
input_ids=input_ids,
hidden_states=target_hidden_states,
positions=target_positions,
if self.method == 'eagle':
self.hidden_states[:num_tokens] = target_hidden_states
hidden_states = self.hidden_states
else:
# TODO: make eagle3 compatible with cuda graph
hidden_states = target_hidden_states

with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=num_input_tokens):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:num_input_tokens],
positions=self.positions[:num_input_tokens],
hidden_states=hidden_states[:num_input_tokens],
)
sample_hidden_states = hidden_states_logits[last_token_indices]
sample_hidden_states = last_hidden_states[last_token_indices]
logits = self.model.compute_logits(sample_hidden_states, None)
draft_token_ids = logits.argmax(dim=-1)

@@ -108,13 +154,20 @@ def propose(
draft_token_ids_list = [draft_token_ids]

positions = target_positions[last_token_indices]
hidden_states = hidden_states_fwd[last_token_indices]
hidden_states = hidden_states[last_token_indices]
if self.use_cuda_graph and \
batch_size <= self.cudagraph_batch_sizes[-1]:
input_batch_size = self.vllm_config.pad_for_cudagraph(batch_size)
else:
input_batch_size = batch_size
attn_metadata.num_actual_tokens = batch_size
attn_metadata.max_query_len = 1
attn_metadata.query_start_loc = self.arange[:batch_size + 1]
for _ in range(self.num_speculative_tokens - 1):
# Update the inputs.
input_ids = draft_token_ids_list[-1]
# cast to int32 is crucial when eagle model is compiled.
# tensor.argmax() returns int64 by default.
input_ids = draft_token_ids_list[-1].int()
positions += 1

# NOTE(woosuk): We should handle the case where the draft model
@@ -152,14 +205,27 @@ def propose(
attn_metadata.slot_mapping.masked_fill_(exceeds_max_model_len,
PADDING_SLOT_ID)

# copy inputs to buffer for cudagraph
self.input_ids[:batch_size] = input_ids
self.positions[:batch_size] = clamped_positions

if self.method == 'eagle':
# TODO: make eagle3 compatible with cudagraph.
self.hidden_states[:batch_size] = hidden_states
hidden_states = self.hidden_states

# Run the model.
with set_forward_context(attn_metadata, self.vllm_config):
hidden_states_logits, hidden_states = self.model(
input_ids=input_ids,
hidden_states=hidden_states,
positions=clamped_positions,
with set_forward_context(attn_metadata,
self.vllm_config,
num_tokens=input_batch_size):
last_hidden_states, hidden_states = self.model(
input_ids=self.input_ids[:input_batch_size],
positions=self.positions[:input_batch_size],
hidden_states=hidden_states[:input_batch_size],
)
logits = self.model.compute_logits(hidden_states_logits, None)
hidden_states = hidden_states[:batch_size]
logits = self.model.compute_logits(last_hidden_states[:batch_size],
None)
draft_token_ids = logits.argmax(dim=-1)
draft_token_ids_list.append(draft_token_ids)

@@ -227,13 +293,11 @@ def load_model(self, target_model: nn.Module) -> None:
draft_model_cls, arch = ModelRegistry.resolve_model_cls(
draft_model_config.architectures)
self.model = draft_model_cls(
model_config=draft_model_config,
vllm_config=self.vllm_config,
start_layer_id=target_layer_num).to(target_device)

loaded_weights = self.model.load_weights(
loader.get_all_weights(
self.vllm_config.speculative_config.draft_model_config,
self.model))
loader.get_all_weights(draft_model_config, self.model))
if self.vllm_config.speculative_config.method == "eagle3":
if "model.embed_tokens.weight" not in loaded_weights:
logger.info(
Expand All @@ -243,6 +307,20 @@ def load_model(self, target_model: nn.Module) -> None:
logger.info("Loading EAGLE LM head weights from the target model.")
self.model.lm_head = target_model.lm_head

@torch.inference_mode()
def dummy_run(
self,
num_tokens: int,
) -> None:
with set_forward_context(None, self.vllm_config,
num_tokens=num_tokens):
if self.method == 'eagle':
self.model(
input_ids=self.input_ids[:num_tokens],
positions=self.positions[:num_tokens],
hidden_states=self.hidden_states[:num_tokens],
)


# NOTE(woosuk): Currently, the below code is not used and we always use argmax
# to sample the draft tokens. We will use this after we find a way to manage
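
Taken together, the changes in this file follow the usual V1 CUDA-graph recipe: stage inputs into fixed pre-allocated buffers, pad the token count up to the nearest captured size, and use dummy_run to drive a forward pass at each capture size so the graphs are recorded against those persistent buffers. A self-contained sketch of that recipe, independent of vLLM's classes (the capture sizes, hidden size, and pad_to_capture_size helper are illustrative):

```python
import bisect

import torch

MAX_NUM_TOKENS = 2048
CAPTURE_SIZES = [1, 2, 4, 8, 16, 32, 64]  # ascending cudagraph capture sizes
HIDDEN_SIZE = 4096
device = "cuda" if torch.cuda.is_available() else "cpu"

# Persistent buffers: the captured graph always reads from these tensors,
# so only their contents (never their storage) change between replays.
input_ids = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int32, device=device)
positions = torch.zeros(MAX_NUM_TOKENS, dtype=torch.int64, device=device)
hidden_states = torch.zeros(MAX_NUM_TOKENS, HIDDEN_SIZE,
                            dtype=torch.float16, device=device)


def pad_to_capture_size(num_tokens: int) -> int:
    """Round up to the smallest captured size; fall back to the raw count
    when the batch is larger than anything that was captured."""
    if num_tokens > CAPTURE_SIZES[-1]:
        return num_tokens
    return CAPTURE_SIZES[bisect.bisect_left(CAPTURE_SIZES, num_tokens)]


def run_draft_step(model, step_ids, step_positions, step_hidden):
    n = step_ids.shape[0]
    # Stage this step's inputs into the persistent buffers.
    input_ids[:n] = step_ids.int()  # argmax output is int64; the buffer is int32
    positions[:n] = step_positions
    hidden_states[:n] = step_hidden
    padded = pad_to_capture_size(n)
    # The model sees padded-size views of the persistent buffers, matching the
    # shapes the graphs were captured with; the extra rows are just padding.
    last_hidden, out_hidden = model(input_ids[:padded], positions[:padded],
                                    hidden_states[:padded])
    return last_hidden[:n], out_hidden[:n]
```

In the diff, pad_for_cudagraph and dummy_run play the roles of pad_to_capture_size and the warm-up call here; the key invariant is that the captured graphs always read from the same input_ids, positions, and hidden_states tensors, so propose() only ever rewrites their contents.
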