[Bugfix] Fix EAGLE vocab embedding for multimodal target model #19570

Merged
merged 4 commits on Jun 13, 2025

Changes from all commits

19 changes: 12 additions & 7 deletions vllm/v1/spec_decode/eagle.py

@@ -329,16 +329,24 @@ def load_model(self, target_model: nn.Module) -> None:

         self.attn_layer_names = list(draft_attn_layer_names)

+        if supports_multimodal(target_model):
+            # handle multimodality
+            self.model.config.image_token_index = (
+                target_model.config.image_token_index)
+            target_language_model = target_model.get_language_model()
+        else:
+            target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
             and self.model.model.embed_tokens.weight.shape \
-                == target_model.model.embed_tokens.weight.shape:
+                == target_language_model.model.embed_tokens.weight.shape:

Contributor review comment (medium):

Consider adding a check to ensure that target_language_model.model.embed_tokens is not None before accessing its weight attribute to avoid potential AttributeError exceptions, e.g.:

    if target_language_model.model.embed_tokens and self.model.model.embed_tokens.weight.shape == target_language_model.model.embed_tokens.weight.shape:

(A standalone sketch of this guarded check appears after this hunk.)

             logger.info(
                 "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             del self.model.model.embed_tokens
-            self.model.model.embed_tokens = target_model.model.embed_tokens
+            self.model.model.embed_tokens = (
+                target_language_model.model.embed_tokens)
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately" \
@@ -349,12 +357,9 @@ def load_model(self, target_model: nn.Module) -> None:
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
         if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_model, "lm_head"):
+                hasattr(target_language_model, "lm_head"):

Contributor review comment (medium):

Consider adding a check to ensure that target_language_model has the attribute lm_head and that it is not None before accessing it to avoid potential AttributeError exceptions, e.g.:

    if hasattr(target_language_model, "lm_head") and target_language_model.lm_head is not None:

(A standalone sketch of this guarded check appears after this hunk.)

logger.info("Loading EAGLE LM head weights from the target model.")
if supports_multimodal(target_model):
self.model.lm_head = target_model.get_language_model().lm_head
else:
self.model.lm_head = target_model.lm_head
self.model.lm_head = target_language_model.lm_head

@torch.inference_mode()
def dummy_run(
Expand Down
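
Similarly, a minimal sketch (hypothetical names, not vLLM's actual classes) of the lm_head reuse step above, with the hasattr/None guard suggested in the review comment: the draft head only borrows the target's lm_head when the target actually defines one, since some models tie lm_head to embed_tokens and never create the attribute.

    import torch.nn as nn

    class _TargetLM(nn.Module):
        def __init__(self, hidden: int, vocab_size: int, with_lm_head: bool = True):
            super().__init__()
            if with_lm_head:
                # models that reuse embed_tokens as lm_head would skip this
                self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    class _DraftHead(nn.Module):
        def __init__(self, hidden: int, vocab_size: int):
            super().__init__()
            self.lm_head = nn.Linear(hidden, vocab_size, bias=False)

    def share_lm_head(draft: _DraftHead, target_lm: _TargetLM) -> bool:
        # Guarded reuse: only share when the attribute exists and is not None.
        if hasattr(target_lm, "lm_head") and target_lm.lm_head is not None:
            draft.lm_head = target_lm.lm_head  # share the projection weights
            return True
        return False  # keep the draft's own lm_head

    target = _TargetLM(hidden=64, vocab_size=1000)
    draft = _DraftHead(hidden=64, vocab_size=1000)
    assert share_lm_head(draft, target)
    assert draft.lm_head is target.lm_head

    # A target without lm_head (e.g. a weight-tied model) leaves the draft head unchanged.
    tied_target = _TargetLM(hidden=64, vocab_size=1000, with_lm_head=False)
    assert not share_lm_head(_DraftHead(hidden=64, vocab_size=1000), tied_target)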