@@ -329,16 +329,24 @@ def load_model(self, target_model: nn.Module) -> None:
 
         self.attn_layer_names = list(draft_attn_layer_names)
 
+        if supports_multimodal(target_model):
+            # handle multimodality
+            self.model.config.image_token_index = (
+                target_model.config.image_token_index)
+            target_language_model = target_model.get_language_model()
+        else:
+            target_language_model = target_model
         # share embed_tokens with the target model if needed
         if get_pp_group().world_size == 1 \
             and self.model.model.embed_tokens.weight.shape \
-                == target_model.model.embed_tokens.weight.shape:
+                == target_language_model.model.embed_tokens.weight.shape:
             logger.info(
                 "Assuming the EAGLE head shares the same vocab embedding" \
                 " with the target model."
             )
             del self.model.model.embed_tokens
-            self.model.model.embed_tokens = target_model.model.embed_tokens
+            self.model.model.embed_tokens = (
+                target_language_model.model.embed_tokens)
         else:
             logger.info(
                 "The EAGLE head's vocab embedding will be loaded separately" \
@@ -349,12 +357,9 @@ def load_model(self, target_model: nn.Module) -> None:
         # some model definition do not define lm_head explicitly
         # and reuse embed_tokens for lm_head, e.g., CohereForCausalLM
         if self.vllm_config.speculative_config.method != "eagle3" and \
-                hasattr(target_model, "lm_head"):
+                hasattr(target_language_model, "lm_head"):
             logger.info("Loading EAGLE LM head weights from the target model.")
-            if supports_multimodal(target_model):
-                self.model.lm_head = target_model.get_language_model().lm_head
-            else:
-                self.model.lm_head = target_model.lm_head
+            self.model.lm_head = target_language_model.lm_head
 
     @torch.inference_mode()
     def dummy_run(
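The refactor above resolves the target's text backbone ("target_language_model") once, then reuses it for both embed_tokens sharing and lm_head loading instead of branching on multimodality at each site. The minimal sketch below illustrates that pattern only; the classes and the resolve_language_model() helper are hypothetical stand-ins (duck-typing on get_language_model()), not vLLM's actual model definitions or its supports_multimodal() check.

import torch.nn as nn


class _TextModel(nn.Module):
    """Stand-in for a text-only target model exposing .model.embed_tokens and .lm_head."""

    def __init__(self, vocab_size: int = 8, hidden: int = 4) -> None:
        super().__init__()
        self.model = nn.Module()
        self.model.embed_tokens = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)


class _MultimodalWrapper(nn.Module):
    """Stand-in for a multimodal target model that wraps a language model."""

    def __init__(self) -> None:
        super().__init__()
        self.language_model = _TextModel()

    def get_language_model(self) -> nn.Module:
        return self.language_model


def resolve_language_model(target_model: nn.Module) -> nn.Module:
    # Hypothetical helper: where the diff calls supports_multimodal(),
    # this sketch simply duck-types on get_language_model().
    if hasattr(target_model, "get_language_model"):
        return target_model.get_language_model()
    return target_model


for target in (_TextModel(), _MultimodalWrapper()):
    lm = resolve_language_model(target)
    # Both branches now expose the same attributes the loader needs,
    # so embed_tokens sharing and lm_head reuse need no further branching.
    print(type(target).__name__,
          tuple(lm.model.embed_tokens.weight.shape),
          tuple(lm.lm_head.weight.shape))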