[Speculative Decoding] Enable arbitrary model inputs #5101


Closed

Conversation

@abhigoyal1997 (Contributor) commented May 29, 2024

This PR changes the ModelRunner to support models whose input signature differs from the default. This mainly benefits speculative decoding methods whose draft models are not standard Transformer-based LMs, e.g., Medusa, EAGLE, RNN-based models, etc. (and, in the future, any LLM with a non-default input signature).

For this, we need support for two things:

  • When ModelRunner calls the model's forward method, pass only the inputs the model expects.
    • Match the prepared inputs against the signature of the model's forward method.
    • Skip preparing unnecessary inputs (a nice-to-have, since it may reduce some overhead).
  • Add support for models that require additional inputs beyond the default ones, e.g., hidden_states in Medusa.
    • Allow models to specify an optional config (shape and dtype) for the additional inputs (needed to capture CUDA graphs).
    • Prepare these additional inputs in ModelRunner (or in the model itself?) and pass them to the forward call.
      • Support inputs that come from the sequence (via seq_group_metadata_list)
      • Support inputs that live inside the Worker/ModelRunner as state preserved from the previous iteration
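The first bullet (matching prepared inputs against the model's forward signature) can be sketched with the standard-library inspect module. This is an illustrative sketch, not vLLM's actual implementation; filter_model_inputs and MedusaDraft are hypothetical names:

```python
import inspect
from typing import Any, Callable, Dict


def filter_model_inputs(forward: Callable[..., Any],
                        inputs: Dict[str, Any]) -> Dict[str, Any]:
    """Keep only the kwargs that the forward callable accepts.

    If forward declares **kwargs, pass everything through unchanged.
    """
    params = inspect.signature(forward).parameters
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values()):
        return dict(inputs)
    return {k: v for k, v in inputs.items() if k in params}


class MedusaDraft:
    # A Medusa-style draft head: consumes hidden_states rather than
    # the default input_ids/positions of a standard Transformer LM.
    def forward(self, hidden_states):
        return [h * 2 for h in hidden_states]


model = MedusaDraft()
# Everything the ModelRunner prepared for this step:
prepared = {
    "input_ids": [1, 2, 3],
    "positions": [0, 1, 2],
    "hidden_states": [1.0, 2.0],
}
# Only hidden_states survives the filter and reaches forward().
kept = filter_model_inputs(model.forward, prepared)
out = model.forward(**kept)
```

A natural refinement (the second sub-bullet) is to run this filter before input preparation, so unneeded tensors are never built at all.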

Part of refactoring #4978

@abhigoyal1997 abhigoyal1997 marked this pull request as draft May 29, 2024 10:58
@abhigoyal1997 abhigoyal1997 changed the title [Misc] [Speculative Decoding] Enable arbitrary model inputs [Speculative Decoding] Enable arbitrary model inputs May 29, 2024
@DarkLight1337 (Member) commented

This would also be great for multi-modal LLMs which accept inputs from other modalities.

I am currently working on #4197, which enables additional inputs to be passed in by decorating the model class with input processors, but it assumes that the inputs are tied to specific modalities. Perhaps we can further generalize that idea in your PR?

@abhigoyal1997 (Contributor, Author) commented

> This would also be great for multi-modal LLMs which accept inputs from other modalities.
>
> I am currently working on #4197 which enables additional inputs to be passed in via decorating the model class with input processors, but it assumes that the inputs are tied to specific modalities. Perhaps we can further generalize that idea in your PR?

Hi @DarkLight1337

This looks like a reasonable idea to try out. We could generalize the MultiModalRegistry (maybe as an InputRegistry) and register all additional inputs (including multi-modal ones) using the same mechanism.

However, I opened this PR just to refactor and simplify the additional inputs in the Medusa implementation (#4978), and I am more focused on that for now. Once that PR is closed, I can look more into this.
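The generalized InputRegistry floated above might combine both requirements from the PR description: a shape/dtype spec per additional input (so CUDA graph buffers can be sized) plus a prepare hook. A minimal sketch under those assumptions; InputSpec, InputRegistry, and the hidden size of 4096 are all hypothetical:

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict, Tuple


@dataclass
class InputSpec:
    """Declares one additional model input."""
    shape: Tuple[int, ...]                       # per-token shape, e.g. (hidden_size,)
    dtype: str                                   # e.g. "float16", for graph capture
    prepare: Callable[[Dict[str, Any]], Any]     # builds the input from runner state


class InputRegistry:
    """Generalizes a modality-specific registry to arbitrary extra inputs."""

    def __init__(self) -> None:
        self._specs: Dict[type, Dict[str, InputSpec]] = {}

    def register(self, model_cls: type, name: str, spec: InputSpec) -> None:
        self._specs.setdefault(model_cls, {})[name] = spec

    def prepare(self, model_cls: type, ctx: Dict[str, Any]) -> Dict[str, Any]:
        specs = self._specs.get(model_cls, {})
        return {name: spec.prepare(ctx) for name, spec in specs.items()}


class MedusaModel:
    pass


registry = InputRegistry()
# hidden_states is carried over from the previous iteration's runner state.
registry.register(
    MedusaModel, "hidden_states",
    InputSpec(shape=(4096,), dtype="float16",
              prepare=lambda ctx: ctx["prev_hidden_states"]),
)
extra = registry.prepare(MedusaModel, {"prev_hidden_states": [0.0] * 4})
```

With this shape/dtype metadata, the ModelRunner could pre-allocate static buffers for each registered input before capturing a CUDA graph.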
