Motivation.
The purpose of this work is to hide details of the vLLM engine implementation from logits processor implementors; currently the vLLM logits processor programming model requires the logits processor implementor to consider complex bookkeeping details about persistent batch state changes. The proposed changes decrease the amount of bookkeeping related to batch ordering which the implementor must consider, by taking advantage of upcoming changes to the vLLM model runner.
This change impacts:
- How the existing built-in logits processors based on the `LogitsProcessor` base class (min-p, logit bias, and min tokens) are implemented
- How new built-in logits processors would be implemented
- How custom logits processors would be implemented
There are still some logits processors hard-coded into the vLLM engine implementation; this interface change is a prerequisite for porting those hard-coded logits processors to become subclasses of `LogitsProcessor`.
Current logits processor programming model
Currently, to define a logits processor, you must subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
- `__init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)`:
  - `vllm_config`: engine configuration data structure
  - `device`: hardware accelerator device info
  - `is_pin_memory`: flag indicating whether pin memory is available to support the logits processor implementation
- `apply(self, logits: torch.Tensor) -> torch.Tensor`:
  - Consume a `(num_requests) x (vocab_size)` logits tensor (`logits`)
  - Apply the logits processor transformation at batch granularity
  - Return a transformed `(num_requests) x (vocab_size)` logits tensor
  - The input logits tensor may be modified in-place or out-of-place; in-place is more memory-efficient
- `is_argmax_invariant(self) -> bool`:
  - Return `True` if the logits processor is argmax-invariant (it never changes which token ID has the highest logit value for a given request), `False` if the logits processor may modify argmax
  - `is_argmax_invariant()` is evaluated once at startup; if `True`, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling
- `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`:
  - Consume a `BatchUpdate` data structure representing persistent batch state changes at the beginning of the current engine step
  - Use the `BatchUpdate` members to update logits processor internal state
  - Note: the batch update data structure may be `None`, signaling no change to the batch constituents. In this case, the logits processor might still want to update its state based on the updated `output_token_ids` lists that it could have retained when the requests were added.
These methods are also the interface presented to the engine for invoking the logits processor. So currently, the logits processor interface to the vLLM engine is the same as the programming model for defining a logits processor.
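For illustration, a minimal sketch of a custom logits processor under the current interface might look like the following. The `BannedTokenLogitsProcessor` name and its behavior are hypothetical, and the `BatchUpdate` replay logic in `update_state()` is elided because it is exactly the bookkeeping this RFC aims to reduce:

```python
from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


class BannedTokenLogitsProcessor(LogitsProcessor):
    """Hypothetical example: masks one banned token id per request."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        self.device = device
        # Maps batch index -> banned token id for requests that specify one
        self.banned: dict[int, int] = {}

    def is_argmax_invariant(self) -> bool:
        # Masking a token can change which token has the highest logit
        return False

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        if batch_update is None:
            return
        # The implementor must replay every add/remove/move recorded in
        # `batch_update` so that `self.banned` stays keyed by the correct
        # batch indices -- this is the bookkeeping the proposal removes.
        ...

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for index, token_id in self.banned.items():
            logits[index, token_id] = float("-inf")
        return logits
```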
Problems
- Currently, when the logits processor implementor writes `update_state(self, batch_update: Optional["BatchUpdate"]) -> None`, they must handle complex bookkeeping resulting from changes in vLLM persistent batch state. Every added, removed, or reordered request in a given step potentially results in an update to logits processor internal state. These batch state details are not conceptually related to logits processors, but must be considered because vLLM logits processors operate at batch granularity.
- Relatedly, there is a WIP revision to the vLLM model runner ( GPU Model Runner V2 #25266 ) which will simplify the representation of persistent batch state and eliminate the complex bookkeeping; however, it will require a change in the logits processor engine interface.
Proposed Change.
- GPU Model Runner V2 #25266 will eliminate batch "reordering" because the persistent batch state is unordered; each step applies a new batch order on-the-fly (call this "stepwise batch ordering"). Thus the `update_state()` method in the logits processor interface must accept a mapping from persistent batch index (call this the "persistent index") to stepwise index
- Separate "add request" and "remove request" functionality out of `update_state()`, to improve separation of concerns
Supporting data structures
Type annotations for added and removed requests:

```python
# Persistent index of removed request
RemovedRequest = int

# (params, prompt length, persistent index) tuples for new
# requests added to the batch
AddedRequest = tuple[SamplingParams, int, int]
```
These two type aliases will be used in the new logits processor `add_request()` and `remove_request()` methods.

The `BatchUpdate` data structure is no longer necessary.
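As a small illustration (the concrete values are hypothetical, and in practice the engine, not the user, constructs these), an `AddedRequest` tuple would be built when a request joins the persistent batch, and the same persistent index would later be passed as a `RemovedRequest` when it leaves:

```python
from vllm import SamplingParams

# Using the type aliases defined above.
params = SamplingParams(temperature=0.0)

prompt_len = 17        # number of prompt tokens in the new request
persistent_index = 3   # slot assigned to the request in the persistent batch

added: AddedRequest = (params, prompt_len, persistent_index)
# ... engine calls logits_processor.add_request(added) ...

removed: RemovedRequest = persistent_index
# ... engine calls logits_processor.remove_request(removed) when the request leaves ...
```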
Defining a logits processor
To define a logits processor, you would subclass `vllm.v1.sample.logits_processor.LogitsProcessor` and define (at minimum) the following methods:
- `__init__(...)`, `is_argmax_invariant(...)`: unchanged
- `apply(self, logits: torch.Tensor) -> torch.Tensor`:
  - Essentially unchanged from the current specification. However, note that the mapping from `logits` rows to requests is determined by the current stepwise ordering.
  - Thus `update_state()` must leverage its `index_mapping` argument such that `apply()` knows how to transform each row of `logits`
- `add_request(self, req_info: AddedRequest) -> None`:
  - Possible outcomes: (1) add info about a new request to the logits processor internal state, or (2) no change.
- `remove_request(self, req_pdx: RemovedRequest) -> None`:
  - Remove information about the specified request from logits processor internal state, if such information is present
- `update_state(self, index_mapping: Optional[npt.NDArray], token_ids: npt.NDArray, seq_lens: npt.NDArray) -> None`:
  - `index_mapping`: optional stepwise request ordering, as a 1D numpy array with length equal to the size of the persistent batch. `None` means the stepwise ordering is unchanged from the previous step
  - `token_ids`: a `(persistent batch size) x (max num tokens)` numpy array. Each row index corresponds to the persistent index of a request in the persistent batch; the contents of the row are the concatenated prompt token IDs and output token IDs for that request
  - `seq_lens`: a 1D numpy array with length equal to the size of the persistent batch. Each index corresponds to the persistent index of a request in the persistent batch; the value at each index is the total sequence length (prompt + output tokens) for that request
  - Note: if `index_mapping` is `None`, the logits processor might still want to update its state based on the updated `token_ids` or `seq_lens`.
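Putting this together, a hedged sketch of the earlier banned-token example ported to the proposed interface might look like the following. The class name and the `extra_args` lookup are hypothetical, and the orientation of `index_mapping` is assumed here to be a stepwise ordering whose i-th entry is the persistent index of the request occupying row i of this step's logits:

```python
from typing import Optional

import numpy.typing as npt
import torch

from vllm import SamplingParams

# Type aliases from "Supporting data structures" above
RemovedRequest = int
AddedRequest = tuple[SamplingParams, int, int]


class BannedTokenLogitsProcessor:  # would subclass LogitsProcessor
    """Hypothetical sketch of the proposed interface."""

    def __init__(self) -> None:
        # Per-request state, keyed by persistent index
        self.banned_by_pdx: dict[int, int] = {}
        # (stepwise row, banned token id) pairs computed in update_state()
        self.stepwise_banned: list[tuple[int, int]] = []
        self.index_mapping: Optional[npt.NDArray] = None

    def is_argmax_invariant(self) -> bool:
        return False

    def add_request(self, req_info: AddedRequest) -> None:
        params, _prompt_len, pdx = req_info
        # Hypothetical per-request parameter carried via SamplingParams.extra_args
        token_id = (getattr(params, "extra_args", None) or {}).get("banned_token_id")
        if token_id is not None:
            self.banned_by_pdx[pdx] = token_id

    def remove_request(self, req_pdx: RemovedRequest) -> None:
        self.banned_by_pdx.pop(req_pdx, None)

    def update_state(
        self,
        index_mapping: Optional[npt.NDArray],
        token_ids: npt.NDArray,
        seq_lens: npt.NDArray,
    ) -> None:
        if index_mapping is not None:
            self.index_mapping = index_mapping
        if self.index_mapping is None:
            return
        # Assumed orientation: index_mapping[row] == persistent index of the
        # request occupying row `row` of this step's logits tensor.
        pdx_to_row = {int(pdx): row for row, pdx in enumerate(self.index_mapping)}
        self.stepwise_banned = [
            (pdx_to_row[pdx], tok)
            for pdx, tok in self.banned_by_pdx.items()
            if pdx in pdx_to_row
        ]

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        for row, token_id in self.stepwise_banned:
            logits[row, token_id] = float("-inf")
        return logits
```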
Why this new approach is beneficial:
While the logits processor implementor still has to consider details of batch ordering, it is much easier to apply what is effectively a single permutation (the `index_mapping`) than the complex sequence of add/remove/reorder operations required previously.
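As a simplified illustration of that point, suppose a logits processor keeps per-request state in a numpy array indexed by persistent index; under the `index_mapping` orientation assumed in the sketch above, reordering that state for the current step is a single gather:

```python
import numpy as np

# Per-request state held by a logits processor, indexed by persistent index
state_by_persistent_index = np.array([0.1, 0.5, 0.9, 0.3])

# Stepwise ordering supplied via update_state(): row i of this step's logits
# corresponds to persistent index index_mapping[i] (assumed orientation)
index_mapping = np.array([2, 0, 3, 1])

# One gather reorders the state to match the stepwise logits rows
state_in_stepwise_order = state_by_persistent_index[index_mapping]
# -> array([0.9, 0.1, 0.3, 0.5])
```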
Feedback Period.
1 Week
CC List.
@njhill @WoosukKwon @simon-mo @robertgshaw2-redhat
Any Other Things.
No response
Before submitting a new issue...
- Make sure you already searched for relevant issues, and asked the chatbot living at the bottom right corner of the documentation page, which can answer lots of frequently asked questions.