5 changes: 5 additions & 0 deletions vllm/model_executor/models/llama.py
@@ -604,6 +604,11 @@ def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
         self.model.aux_hidden_state_layers = layers
 
     def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Override to return the default aux hidden state layers for Llama.
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
         num_layers = len(self.model.layers)
         return (2, num_layers // 2, num_layers - 3)
 
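For reference, the default picks an early, a middle, and a late layer. A minimal sketch of the arithmetic, assuming a hypothetical 32-layer model (the layer count is illustrative):

```python
# Default EAGLE3 auxiliary layer selection from the diff above,
# evaluated for an assumed 32-layer model.
num_layers = 32
aux_layers = (2, num_layers // 2, num_layers - 3)
print(aux_layers)  # (2, 16, 29) -> early, middle, late layer indices
```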
8 changes: 7 additions & 1 deletion vllm/model_executor/models/llama_eagle3.py
@@ -21,6 +21,7 @@
 )
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.model_executor.models.llama import LlamaDecoderLayer, LlamaForCausalLM
+from vllm.multimodal.inputs import NestedTensors
 
 from .utils import AutoWeightsLoader, maybe_prefix
 
@@ -241,7 +242,12 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             requires_grad=False,
         )
 
-    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+    def get_input_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        multimodal_embeddings: Optional[NestedTensors] = None,
+        is_multimodal: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         return self.model.get_input_embeddings(input_ids)
 
     def forward(
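The widened signature keeps the text-only Eagle3 drafter call-compatible with multimodal callers: the extra arguments are accepted but ignored. A self-contained sketch of that pattern, using an illustrative stand-in class rather than vLLM's actual drafter:

```python
from typing import Optional

import torch
import torch.nn as nn


class TextOnlyDrafter(nn.Module):
    """Illustrative stand-in for a text-only draft model."""

    def __init__(self) -> None:
        super().__init__()
        self.embed = nn.Embedding(100, 16)

    def get_input_embeddings(
        self,
        input_ids: torch.Tensor,
        multimodal_embeddings: Optional[list[torch.Tensor]] = None,
        is_multimodal: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Multimodal arguments are accepted for interface parity and ignored.
        return self.embed(input_ids)


drafter = TextOnlyDrafter()
out = drafter.get_input_embeddings(torch.tensor([1, 2, 3]), multimodal_embeddings=None)
print(out.shape)  # torch.Size([3, 16])
```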
27 changes: 25 additions & 2 deletions vllm/model_executor/models/mllama4.py
@@ -64,7 +64,12 @@
 from vllm.sequence import IntermediateTensors
 from vllm.utils.tensor_schema import TensorSchema, TensorShape
 
-from .interfaces import MultiModalEmbeddings, SupportsMultiModal, SupportsPP
+from .interfaces import (
+    MultiModalEmbeddings,
+    SupportsEagle3,
+    SupportsMultiModal,
+    SupportsPP,
+)
 from .llama4 import Llama4ForCausalLM
 from .utils import AutoWeightsLoader, flatten_bn, maybe_prefix
 from .vision import run_dp_sharded_vision_model
@@ -717,7 +722,9 @@ def get_dummy_mm_data(
     info=Mllama4ProcessingInfo,
     dummy_inputs=Mllama4DummyInputsBuilder,
 )
-class Llama4ForConditionalGeneration(nn.Module, SupportsMultiModal, SupportsPP):
+class Llama4ForConditionalGeneration(
+    nn.Module, SupportsMultiModal, SupportsPP, SupportsEagle3
+):
     packed_modules_mapping = {
         "qkv_proj": ["q_proj", "k_proj", "v_proj"],
         "gate_up_proj": ["gate_proj", "up_proj"],
@@ -767,6 +774,22 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
             self.language_model.make_empty_intermediate_tensors
         )
 
+    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
+        """Set which layers should output auxiliary hidden states for EAGLE3."""
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "set_aux_hidden_state_layers")
+        self.language_model.set_aux_hidden_state_layers(layers)
+
+    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
+        """Get the layer indices for auxiliary hidden state outputs.
+
+        Note: The GPU model runner will override this with layers from
+        the speculative config if available, providing dynamic configuration.
+        """
+        # Delegate to underlying language model (Llama4ForCausalLM)
+        assert hasattr(self.language_model, "get_eagle3_aux_hidden_state_layers")
+        return self.language_model.get_eagle3_aux_hidden_state_layers()
+
     def _parse_and_validate_image_input(
         self, **kwargs: object
     ) -> Optional[Llama4ImagePatchInputs]:
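Both new methods are pure delegation: the multimodal wrapper satisfies the EAGLE3 interface by forwarding to the wrapped language model. A self-contained sketch of the pattern, with illustrative stand-in classes:

```python
class InnerLM:
    """Stand-in for the wrapped Llama4ForCausalLM."""

    def __init__(self) -> None:
        self.aux_layers: tuple[int, ...] = ()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        self.aux_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        return (2, 24, 45)  # e.g. defaults for an assumed 48-layer model


class MultiModalWrapper:
    """Stand-in for Llama4ForConditionalGeneration."""

    def __init__(self) -> None:
        self.language_model = InnerLM()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        # Forward rather than reimplement; the inner model owns the state.
        self.language_model.set_aux_hidden_state_layers(layers)

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        return self.language_model.get_eagle3_aux_hidden_state_layers()


wrapper = MultiModalWrapper()
wrapper.set_aux_hidden_state_layers(wrapper.get_eagle3_aux_hidden_state_layers())
print(wrapper.language_model.aux_layers)  # (2, 24, 45)
```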
8 changes: 8 additions & 0 deletions vllm/transformers_utils/configs/speculators/algos.py
@@ -21,10 +21,18 @@ def update_eagle3(config_dict: dict, vllm_config: dict) -> None:
     - draft_vocab_size: Size of the draft model's vocabulary
     - target_hidden_size: Hidden size of the target model
     - norm_before_residual: Whether to apply norm before residual connection
+    - eagle_aux_hidden_state_layer_ids: List of layer indices from the base
+      model to use as auxiliary inputs for the Eagle3 drafter. These layers
+      provide intermediate hidden states that help the drafter make better
+      predictions. This is the standard field used in Eagle3 checkpoints.
     """
 
     vllm_config["draft_vocab_size"] = config_dict.get("draft_vocab_size")
     if config_dict.get("target_hidden_size") is not None:
         vllm_config["target_hidden_size"] = config_dict["target_hidden_size"]
     vllm_config["norm_before_residual"] = config_dict.get("norm_before_residual", True)
     vllm_config["architectures"] = ["Eagle3LlamaForCausalLM"]
+    if config_dict.get("eagle_aux_hidden_state_layer_ids"):
+        vllm_config["eagle_aux_hidden_state_layer_ids"] = config_dict[
+            "eagle_aux_hidden_state_layer_ids"
+        ]
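Assuming vLLM is installed and the module path matches the file above, a sketch of how `update_eagle3` now maps a speculators-format Eagle3 config onto the vLLM draft-model config (all input values are illustrative):

```python
from vllm.transformers_utils.configs.speculators.algos import update_eagle3

# Illustrative speculators-format checkpoint config.
config_dict = {
    "draft_vocab_size": 32000,
    "norm_before_residual": True,
    "eagle_aux_hidden_state_layer_ids": [2, 16, 29],
}
vllm_config: dict = {}
update_eagle3(config_dict, vllm_config)
# vllm_config now holds:
# {"draft_vocab_size": 32000, "norm_before_residual": True,
#  "architectures": ["Eagle3LlamaForCausalLM"],
#  "eagle_aux_hidden_state_layer_ids": [2, 16, 29]}
```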
43 changes: 38 additions & 5 deletions vllm/v1/worker/gpu_model_runner.py
@@ -2943,15 +2943,24 @@ def load_model(self, eep_scale_up: bool = False) -> None:
logger.info("Loading drafter model...")
self.drafter.load_model(self.model)
if self.use_aux_hidden_state_outputs:
if supports_eagle3(self.model):
self.model.set_aux_hidden_state_layers(
self.model.get_eagle3_aux_hidden_state_layers()
)
else:
if not supports_eagle3(self.model):
raise RuntimeError(
"Model does not support EAGLE3 interface but "
"aux_hidden_state_outputs was requested"
)

# Try to get auxiliary layers from speculative config,
# otherwise use model's default layers
aux_layers = self._get_eagle3_aux_layers_from_config()
if aux_layers:
logger.info(
"Using auxiliary layers from speculative config: %s",
aux_layers,
)
else:
aux_layers = self.model.get_eagle3_aux_hidden_state_layers()

self.model.set_aux_hidden_state_layers(aux_layers)
time_after_load = time.perf_counter()
self.model_memory_usage = m.consumed_memory
logger.info(
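The new control flow gives config-specified auxiliary layers precedence over the model's built-in defaults. A self-contained sketch of that resolution order (the function name is illustrative):

```python
from typing import Optional


def resolve_aux_layers(
    from_config: Optional[tuple[int, ...]],
    model_defaults: tuple[int, ...],
) -> tuple[int, ...]:
    # Layers from the speculative config win; otherwise fall back to defaults.
    return from_config if from_config else model_defaults


print(resolve_aux_layers((1, 23, 44), (2, 16, 29)))  # (1, 23, 44)
print(resolve_aux_layers(None, (2, 16, 29)))         # (2, 16, 29)
```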
@@ -3006,6 +3015,30 @@ def load_model(self, eep_scale_up: bool = False) -> None:
             self.model, self.vllm_config, CUDAGraphMode.NONE, self.device
         )
 
+    def _get_eagle3_aux_layers_from_config(self) -> Optional[tuple[int, ...]]:
+        """Extract Eagle3 auxiliary layer indices from speculative config.
+
+        These indices specify which hidden states from the base model should
+        be used as auxiliary inputs for the Eagle3 drafter model during
+        speculative decoding.
+
+        Returns:
+            Tuple of layer indices if found in draft model config,
+            None otherwise.
+        """
+        if not (self.speculative_config and self.speculative_config.draft_model_config):
+            return None
+
+        hf_config = self.speculative_config.draft_model_config.hf_config
+        if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
+            return None
+
+        layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
+        if layer_ids and isinstance(layer_ids, (list, tuple)):
+            return tuple(layer_ids)
+
+        return None
+
     def reload_weights(self) -> None:
         assert getattr(self, "model", None) is not None, (
             "Cannot reload weights before model is loaded."
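The helper's validation can be exercised standalone. A sketch against mock config objects, where `SimpleNamespace` stands in for the real config classes and the attribute names match the code above:

```python
from types import SimpleNamespace
from typing import Optional


def get_aux_layers(speculative_config) -> Optional[tuple[int, ...]]:
    """Same logic as _get_eagle3_aux_layers_from_config, free of the runner."""
    if not (speculative_config and speculative_config.draft_model_config):
        return None
    hf_config = speculative_config.draft_model_config.hf_config
    if not hasattr(hf_config, "eagle_aux_hidden_state_layer_ids"):
        return None
    layer_ids = hf_config.eagle_aux_hidden_state_layer_ids
    if layer_ids and isinstance(layer_ids, (list, tuple)):
        return tuple(layer_ids)
    return None


hf = SimpleNamespace(eagle_aux_hidden_state_layer_ids=[1, 23, 44])
cfg = SimpleNamespace(draft_model_config=SimpleNamespace(hf_config=hf))
print(get_aux_layers(cfg))   # (1, 23, 44)
print(get_aux_layers(None))  # None
```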