Skip to content

Commit b549e95

Browse files
zixi-qi and minpeter
authored and committed
[Bugfix] Fix EAGLE vocab embedding for multimodal target model (vllm-project#19570)
Signed-off-by: qizixi <qizixi@meta.com>
Signed-off-by: minpeter <kali2005611@gmail.com>
1 parent e7e92e9 commit b549e95

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

vllm/v1/spec_decode/eagle.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -329,16 +329,24 @@ def load_model(self, target_model: nn.Module) -> None:
329329

330330
self.attn_layer_names = list(draft_attn_layer_names)
331331

332+
if supports_multimodal(target_model):
333+
# handle multimodality
334+
self.model.config.image_token_index = (
335+
target_model.config.image_token_index)
336+
target_language_model = target_model.get_language_model()
337+
else:
338+
target_language_model = target_model
332339
# share embed_tokens with the target model if needed
333340
if get_pp_group().world_size == 1 \
334341
and self.model.model.embed_tokens.weight.shape \
335-
== target_model.model.embed_tokens.weight.shape:
342+
== target_language_model.model.embed_tokens.weight.shape:
336343
logger.info(
337344
"Assuming the EAGLE head shares the same vocab embedding" \
338345
" with the target model."
339346
)
340347
del self.model.model.embed_tokens
341-
self.model.model.embed_tokens = target_model.model.embed_tokens
348+
self.model.model.embed_tokens = (
349+
target_language_model.model.embed_tokens)
342350
else:
343351
logger.info(
344352
"The EAGLE head's vocab embedding will be loaded separately" \
@@ -349,12 +357,9 @@ def load_model(self, target_model: nn.Module) -> None:
349357
# some model definition do not define lm_head explicitly
350358
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
351359
if self.vllm_config.speculative_config.method != "eagle3" and \
352-
hasattr(target_model, "lm_head"):
360+
hasattr(target_language_model, "lm_head"):
353361
logger.info("Loading EAGLE LM head weights from the target model.")
354-
if supports_multimodal(target_model):
355-
self.model.lm_head = target_model.get_language_model().lm_head
356-
else:
357-
self.model.lm_head = target_model.lm_head
362+
self.model.lm_head = target_language_model.lm_head
358363

359364
@torch.inference_mode()
360365
def dummy_run(

0 commit comments

Comments (0)