-
-
Notifications
You must be signed in to change notification settings - Fork 8.4k
[Bugfix] Fix EAGLE vocab embedding for multimodal target model #19570
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
b561297
small refactor for eagle base model with mm support
zixi-qi 89c3c1d
Merge branch 'vllm-project:main' into eagle-mm-refactor
zixi-qi da5af78
small refactor for eagle base model with mm support
zixi-qi 2d62609
remove image_token_index check
zixi-qi File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -329,16 +329,24 @@ def load_model(self, target_model: nn.Module) -> None: | |
|
||
self.attn_layer_names = list(draft_attn_layer_names) | ||
|
||
if supports_multimodal(target_model): | ||
# handle multimodality | ||
self.model.config.image_token_index = ( | ||
target_model.config.image_token_index) | ||
target_language_model = target_model.get_language_model() | ||
else: | ||
target_language_model = target_model | ||
# share embed_tokens with the target model if needed | ||
if get_pp_group().world_size == 1 \ | ||
and self.model.model.embed_tokens.weight.shape \ | ||
== target_model.model.embed_tokens.weight.shape: | ||
== target_language_model.model.embed_tokens.weight.shape: | ||
logger.info( | ||
"Assuming the EAGLE head shares the same vocab embedding" \ | ||
" with the target model." | ||
) | ||
del self.model.model.embed_tokens | ||
self.model.model.embed_tokens = target_model.model.embed_tokens | ||
self.model.model.embed_tokens = ( | ||
target_language_model.model.embed_tokens) | ||
else: | ||
logger.info( | ||
"The EAGLE head's vocab embedding will be loaded separately" \ | ||
|
@@ -349,12 +357,9 @@ def load_model(self, target_model: nn.Module) -> None: | |
# some model definition do not define lm_head explicitly | ||
# and reuse embed_tokens for lm_head, e.g., CohereForCausalLM | ||
if self.vllm_config.speculative_config.method != "eagle3" and \ | ||
hasattr(target_model, "lm_head"): | ||
hasattr(target_language_model, "lm_head"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
logger.info("Loading EAGLE LM head weights from the target model.") | ||
if supports_multimodal(target_model): | ||
self.model.lm_head = target_model.get_language_model().lm_head | ||
else: | ||
self.model.lm_head = target_model.lm_head | ||
self.model.lm_head = target_language_model.lm_head | ||
|
||
@torch.inference_mode() | ||
def dummy_run( | ||
|
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Consider adding a check to ensure that
target_language_model.model.embed_tokens
is not None before accessing itsweight
attribute to avoid potentialAttributeError
exceptions.