Description
Motivation.
Prompt embedding inputs are a niche but frequently requested feature in vLLM. #15428 introduced them in the v0 engine, but they have not yet been ported to the v1 engine. Prompt embedding users will be stuck on older versions of vLLM unless the feature is also introduced into the v1 engine.
Related historical issues/PRs:
v0 Support:
- [Core] [Bugfix] Add Input Embeddings #15428
- [Core] Gate prompt_embeds behind a feature flag #17607
- Feature/vllm/input embedding completion api #17590
- [Misc] V0 fallback for --enable-prompt-embeds #17615
- [Core] [Bugfix]: tensor parallel with prompt embeds #18171
- [Bugfix]: Batch generation from prompt_embeds fails for long prompts #21390
- [Bugfix] Fix isinstance check for tensor types in _load_prompt_embeds to use dtype comparison #21612
Open Issues that would require v1 support to fix:
- [Bug]: Prompt Embedding returns 500 internal error for Qwen 2.5 VL model #20757 (Although this issue is closed because the user's original problem was resolved, it revealed a more fundamental incompatibility between the v0 implementation and multi-modal models that was not resolved.)
- [Usage]: embed prompts #19746
Proposed Change.
Input processing
Input pre-processing
#15428 introduced many changes for how inputs are pre-processed. Luckily most of those changes are shared between v0 and v1, so few changes will be needed for input pre-processing.
v1 Processor
The changes are minor and mirror those made to the v0 processor. Namely, validating the model input should consider the length of prompt_embeds OR prompt_token_ids. Additionally, the v1 EngineCoreRequest requires a new prompt_embeds field to pass the prompt embeddings to the engine.
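For illustration, the length check could look roughly like the sketch below; the function name, argument layout, and error messages are assumptions for this RFC rather than the existing processor API.

```python
# Sketch only: validate prompt length from either prompt_token_ids or
# prompt_embeds. Names and structure here are assumptions, not the final API.
from typing import Optional

import torch


def validate_model_input_length(
    prompt_token_ids: Optional[list[int]],
    prompt_embeds: Optional[torch.Tensor],
    max_model_len: int,
) -> int:
    if prompt_token_ids is not None:
        num_prompt_tokens = len(prompt_token_ids)
    elif prompt_embeds is not None:
        # prompt_embeds is assumed to have shape (num_prompt_tokens, hidden_size).
        num_prompt_tokens = prompt_embeds.shape[0]
    else:
        raise ValueError("Either prompt_token_ids or prompt_embeds must be provided.")
    if num_prompt_tokens > max_model_len:
        raise ValueError(
            f"Prompt length {num_prompt_tokens} exceeds max_model_len {max_model_len}.")
    return num_prompt_tokens
```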
There are several other places between input processing and the scheduler where a new prompt_embeds field will be needed in some struct (see the sketch after this list), including:
- Request
- NewRequestData
- RequestState (both the init and the from_new_request methods)
- EngineCoreRequest
- CachedRequestState
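As a purely illustrative sketch (using EngineCoreRequest as the example, with all existing fields omitted), the new field could look like this:

```python
from dataclasses import dataclass
from typing import Optional

import torch


@dataclass
class EngineCoreRequest:
    # Existing fields omitted; this only sketches the proposed addition.
    request_id: str
    prompt_token_ids: Optional[list[int]]
    # New field: per-request prompt embeddings with shape
    # (num_prompt_tokens, hidden_size); None for regular token-id requests.
    prompt_embeds: Optional[torch.Tensor] = None
```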
The Request.__init__ initializes a private variable self._all_token_ids to be a copy of the input tokens, but when only prompt_embeds are provided it should instead be filled with placeholder tokens, something like:
self._all_token_ids: list[int] = (
    self.prompt_token_ids
    if self.prompt_token_ids is not None
    else [0] * self.num_prompt_tokens
).copy()
Note
To the best of my knowledge, this should be the only place where placeholder tokens are necessary, which avoids much of the placeholder-token handling logic required in the v0 engine's implementation.
The FastDetokenizer assumes prompt_token_ids are available for priming the detokenizer; this should be refactored to handle requests where prompt_token_ids are unavailable.
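A hedged sketch of what that refactor could look like; the class, method names, and priming behavior below are assumptions about the detokenizer internals, not the actual API.

```python
# Sketch: only prime the incremental detokenizer when real prompt token ids
# exist. Prompt-embeds-only requests have no prompt tokens to prime from.
from typing import Optional


class PromptEmbedsAwareDetokenizer:
    def __init__(self) -> None:
        self._token_ids: list[int] = []

    def init_from_prompt(self, prompt_token_ids: Optional[list[int]]) -> None:
        if prompt_token_ids:
            # Existing behavior: seed the streaming decoder with the prompt
            # tokens so the first generated token detokenizes correctly.
            self._token_ids = list(prompt_token_ids)
        else:
            # Prompt-embeds request: skip priming rather than feeding
            # placeholder tokens into the detokenizer state.
            self._token_ids = []
```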
Scheduler
Because prompt_embeds and prompt_token_ids cannot both be passed into the decoder model forward in the same batch, it is crucial that these be scheduled in separate batches. Upon scheduling a batch, the v0 engine picks either prompt_embeds or prompt_token_ids requests to fill the batch (favoring whichever the oldest remaining decode request happens to be). We propose porting this logic to the v1 scheduler.
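A minimal sketch of that selection, assuming a waiting queue of requests and a helper that reports whether a request carries prompt embeddings (both are placeholders for whatever the v1 scheduler actually exposes):

```python
# Sketch: never mix prompt_embeds and prompt_token_ids requests in one batch.
# The batch "type" follows the oldest waiting request, mirroring the v0 logic.
def select_schedulable_requests(waiting: list) -> list:
    if not waiting:
        return []
    batch_uses_embeds = uses_prompt_embeds(waiting[0])
    return [req for req in waiting if uses_prompt_embeds(req) == batch_uses_embeds]


def uses_prompt_embeds(request) -> bool:
    # Placeholder: the real check would look at the request's prompt_embeds field.
    return getattr(request, "prompt_embeds", None) is not None
```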
Additional optimizations may be possible, such as treating prompt_token_ids requests as two steps: an "encoding" step that looks up prompt embeds from the token ids (somewhat similar to how multimodal requests are handled), followed by a "decode" step that can be batched with raw prompt_embeds requests. Such ideas should be out of scope while trying to achieve feature parity, as they increase the complexity of the scheduler.
Model Execution
GPU
The v1 GPU executor is sufficiently different from v0 that not much can be reused. The new architecture, which handles multimodal inputs by generating the appropriate inputs_embeds, already has much of the support needed to take raw inputs_embeds and pass them to the models.
When a new request comes in, its prompt embeds can be forwarded directly to the model. When a request has previously generated new token ids, those should be converted to inputs_embeds by passing them through the model's get_input_embeddings() module during the next decoding step. The new inputs_embeds corresponding to each new token should then be appended to the existing cached ones.
This logic will look something like (pseudo code):
if self.is_multimodal_model and get_pp_group().is_first_rank:
    # NOTE(woosuk): To unify token ids and soft tokens (vision
    # embeddings), we always use embeddings (rather than token ids)
    # as input to the multimodal model, even when the input is text.
    input_ids = self.input_ids[:num_scheduled_tokens]
    if mm_embeds:
        inputs_embeds = self.model.get_input_embeddings(
            input_ids, mm_embeds)
    else:
        inputs_embeds = self.model.get_input_embeddings(input_ids)
    # TODO(woosuk): Avoid the copy. Optimize.
    self.inputs_embeds[:num_scheduled_tokens].copy_(inputs_embeds)
    inputs_embeds = self.inputs_embeds[:num_input_tokens]
    input_ids = None
elif self.model_config.enable_prompt_embeds and <batch has inputs_embeds>:
    self.inputs_embeds[:num_scheduled_tokens].copy_(<batch's inputs_embeds tensor>)
    new_inputs_embeds = self.model.get_input_embeddings(
        <batch's previously generated new tokens only>)
    self.inputs_embeds[num_scheduled_tokens:num_input_tokens].copy_(new_inputs_embeds)
    inputs_embeds = self.inputs_embeds[:num_input_tokens]
    input_ids = None
else:
    # This logic path remains unchanged when prompt_embeds is disabled (or the
    # batch does not contain prompt_embeds) to avoid the performance regression
    # in https://github.com/vllm-project/vllm/pull/11032.
    input_ids = self.input_ids[:num_input_tokens]
    inputs_embeds = None
This could be optimized by caching the newly generated tokens' inputs_embeds on each step, so that only one token (per sequence) needs to be converted each step, though the best structure to store those in isn't clear to me yet. Inspiration may be found in the existing logic for executing and gathering embeds from the multimodal encoders, which are cached prior to the decoding step.
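One possible shape for such a cache, purely as an illustration (the class, its name, and how it would hook into the model runner are all assumptions):

```python
import torch


class GeneratedTokenEmbedsCache:
    """Sketch: per-request cache of embeddings for already-generated tokens,
    so each decode step only has to embed the newest sampled token."""

    def __init__(self, hidden_size: int, device: torch.device, dtype: torch.dtype):
        self._cache: dict[str, torch.Tensor] = {}
        self._hidden_size = hidden_size
        self._device = device
        self._dtype = dtype

    def append(self, req_id: str, new_embed: torch.Tensor) -> None:
        # new_embed: (1, hidden_size) embedding of the latest sampled token.
        prev = self._cache.get(
            req_id,
            torch.empty(0, self._hidden_size, device=self._device, dtype=self._dtype),
        )
        self._cache[req_id] = torch.cat([prev, new_embed], dim=0)

    def get(self, req_id: str) -> torch.Tensor:
        return self._cache[req_id]

    def evict(self, req_id: str) -> None:
        # Drop the entry when the request finishes or is preempted.
        self._cache.pop(req_id, None)
```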
Care should be taken when implementing this to avoid tensor parallel issues, like those solved in #18171.
Cudagraph capture
In the v0 engine, we do two separate cudagraph captures for each batch size (one with input_ids and one with inputs_embeds/prompt embeddings). This allows us to maintain the execution speed-up from torch compile with either type of input. This RFC proposes porting over that two-cudagraph solution as is, only compiling the prompt embedding graph if --enable-prompt-embeds is set.
Additional optimizations may be possible, given the differences between cudagraph capture in v1/v0 and the fact that for most models, the inputs_embeds graph is a strict subset of the input_ids graph. These ideas are out of scope for this RFC which aims primarily for feature parity with the v0 engine. Additional optimizations can be investigated in the future to speed up start-time and/or decrease the memory requirements of supporting compiled prompt embeds.
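For illustration, the capture and dispatch could be keyed on whether the batch carries prompt embeddings; the capture_graph callable, config names, and dict layout below are assumptions rather than the existing capture code.

```python
# Sketch: capture one CUDA graph per (batch_size, uses_prompt_embeds) pair,
# and only pay the extra capture cost when --enable-prompt-embeds is set.
def capture_all_graphs(capture_graph, batch_sizes, enable_prompt_embeds):
    graphs = {}
    for batch_size in batch_sizes:
        graphs[(batch_size, False)] = capture_graph(batch_size, use_inputs_embeds=False)
        if enable_prompt_embeds:
            graphs[(batch_size, True)] = capture_graph(batch_size, use_inputs_embeds=True)
    return graphs


# At execution time, dispatch on the padded batch size and batch contents:
# graph = graphs[(padded_batch_size, batch_has_prompt_embeds)]
# graph.replay()
```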
CPU
#19746 requests CPU support for prompt_embeds. In the v1 engine, this should be straightforward since the CPU model runner inherits from the GPU model runner. Changes needed to support CPU in this case should be minimal to none.
TPU/XPU/other model runners
Supporting other model runners is likely not much additional work, but it was not investigated as part of this RFC, which focuses on feature parity with the v0 engine. This can be done in follow-up work.
Output Processing
Since the model executor in this proposal will handle generating inputs_embeds for newly generated tokens from previous steps, there is no need to generate the corresponding embeddings for those tokens as part of post-processing and sampling, as is currently done in v0. That said, care must be taken to ensure that the bug fixed in #21386 is not reintroduced.
Other Changes
- The list of other supported parameters in the /v1/completions user guide does not include prompt_embeds, but should: https://docs.vllm.ai/en/stable/serving/openai_compatible_server.html?h=completions#completions-api_1
- [Misc] V0 fallback for --enable-prompt-embeds #17615 adds a fallback to the v0 engine when prompt embeds are enabled. Once this change is complete, that fallback will be unnecessary.
- Speculative decoding in general does not make sense with a single set of prompt_embeds and should be forced to be disabled when prompt embeds are enabled (a sketch of such a check follows this list). This is not currently enforced in v0 and can cause crashes. @Nan2018 has pointed out that, in principle, we could support speculative decoding, but it would require either that the full model and the speculative model share an input embedding space OR that two separate sets of input embeddings be provided. Regardless, supporting those options increases the complexity of the logic and is out of scope for this RFC, which focuses on feature parity.
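A hedged sketch of the config-time check described in the speculative decoding bullet above; the attribute names are assumptions about the config objects, not the existing validation code.

```python
# Sketch: fail fast at startup instead of crashing mid-generation when prompt
# embeds and speculative decoding are enabled together.
def verify_prompt_embeds_compat(model_config, speculative_config) -> None:
    if getattr(model_config, "enable_prompt_embeds", False) and speculative_config is not None:
        raise ValueError(
            "--enable-prompt-embeds is currently incompatible with speculative "
            "decoding; please disable one of the two.")
```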
Feedback Period.
I think the standard feedback period of 1 week mentioned in the docs is sufficient, especially given that the v0 engine is actively being removed from vLLM; if this is delayed, vLLM might ship a version where prompt embeddings are not available at all. It will take some time to implement this, and I plan on having a draft PR open soon, but probably not within a week.
CC List.
@DarkLight1337 @Nan2018 @CandiedCode @WoosukKwon
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.