
[RFC]: Revise Logits Processor Programming Model #25389

@afeldman-nm

Description


Motivation.

The purpose of this work is to hide details of the vLLM engine implementation from logits processor implementors. Currently, the logits processor programming model requires the implementor to handle complex bookkeeping around persistent batch state changes. The proposed changes reduce the batch-ordering bookkeeping the implementor must handle, by taking advantage of upcoming changes to the vLLM model runner.

This change impacts

  1. How existing builtin logits processors based on the LogitsProcessor base class (min-p, logit bias, and min tokens) are implemented
  2. How new builtin logits processors would be implemented
  3. How custom logits processors would be implemented

There are still some logits processors hard-coded into the vLLM engine implementation; this interface change is a prerequisite for porting those hard-coded logits processors to LogitsProcessor subclasses.

Current logits processor programming model

Currently, to define a logits processor, you must subclass vllm.v1.sample.logits_processor.LogitsProcessor and define (at minimum) the following methods:

  • __init__(self, vllm_config: VllmConfig, device: torch.device, is_pin_memory: bool)
    • vllm_config: engine configuration data structure
    • device: hardware accelerator device info
    • is_pin_memory: flag indicating whether pin memory is available to support logits processor implementation
  • apply(self, logits: torch.Tensor) -> torch.Tensor:
    • Consume a (num_requests) x (vocab_size) logits tensor (logits)
    • Apply logits processor transformation at batch granularity
    • Return a transformed (num_requests) x (vocab_size) logits tensor
    • You can modify the input logits tensor in-place or out-of-place; in-place is more memory-efficient
  • is_argmax_invariant(self) -> bool:
    • Return True if the logits processor is argmax invariant (never changes what is the highest-logit-value token ID for a given request), False if the logits processor may modify argmax
    • is_argmax_invariant() is evaluated once at startup; if True, vLLM will skip applying this logits processor in a given step when all requests use greedy sampling
  • update_state(self, batch_update: Optional["BatchUpdate"]) -> None:
    • Consume a BatchUpdate data structure representing persistent batch state changes at the beginning of the current engine step
    • Use the BatchUpdate members to update logits processor internal state
    • Note: the batch update data structure may be None, signaling no change to the batch constituents. In this case, the LogitsProcessor might still want to update its state based on the updated output_token_ids lists it may have retained from when the requests were added.

These methods are also the interface presented to the engine for invoking the logits processor. So currently, the logits processor interface to the vLLM engine is the same as the programming model for defining a logits processor.
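
For concreteness, here is a minimal sketch of a processor written against the current model. It is stateless and argmax-invariant (it only shifts each row so its maximum logit becomes zero), so its update_state() can ignore the batch update entirely; the class name and behavior are illustrative only, and the BatchUpdate import path is assumed to match the LogitsProcessor module referenced above.

from typing import Optional

import torch

from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import BatchUpdate, LogitsProcessor


class MaxShiftLogitsProcessor(LogitsProcessor):
    """Illustrative only: shift each row so its max logit becomes zero."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        self.device = device
        self.pin_memory = is_pin_memory

    def is_argmax_invariant(self) -> bool:
        # Subtracting a per-row constant never changes the argmax.
        return True

    def update_state(self, batch_update: Optional[BatchUpdate]) -> None:
        # A stateful processor would have to walk the added/removed/moved
        # requests in batch_update here to keep its per-request state
        # aligned with the persistent batch; this sketch has no state.
        pass

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        # In-place modification is permitted and more memory-efficient.
        logits -= logits.amax(dim=-1, keepdim=True)
        return logits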

Problems

  1. Currently, when the logits processor implementor writes update_state(self, batch_update: Optional["BatchUpdate"]) -> None, they must handle complex bookkeeping resulting from changes to the vLLM persistent batch state. Every added, removed, or reordered request in a given step potentially results in an update to logits processor internal state. These batch state details are not conceptually related to logits processors, but they must be considered because vLLM logits processors operate at batch granularity.

  2. Relatedly, there is a WIP revision to the vLLM model runner (GPU Model Runner V2 #25266) which will simplify the representation of persistent batch state and eliminate the complex bookkeeping; however, it will require a change to the logits processor engine interface.

Proposed Change.

  • GPU Model Runner V2 #25266 will eliminate batch "reordering" because the persistent batch state is unordered; each step applies a new batch order on-the-fly (call this "stepwise batch ordering"). Thus the update_state() method in the logits processor interface must accept a mapping from persistent batch index (call this "persistent index") to stepwise index
  • Separate "add request" and "remove request" functionality out from update_state(), to improve separation of concerns

Supporting data structures

Type annotations for added and removed requests:

# Persistent index of removed request
RemovedRequest = int

# (params, prompt length, persistent index) tuples for new
# requests added to the batch
AddedRequest = tuple[SamplingParams, int, int]

These two type annotations will be used in the new logits processor add_request() and remove_request() methods; a brief illustration follows.
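
For illustration only, the values the engine would pass to these methods might be constructed roughly as follows (hypothetical variable names, reusing the AddedRequest and RemovedRequest aliases above; this is not actual model runner code):

from vllm import SamplingParams

params = SamplingParams(min_tokens=16)
prompt_len, persistent_idx = 128, 3

added: AddedRequest = (params, prompt_len, persistent_idx)  # argument to add_request()
removed: RemovedRequest = persistent_idx                    # argument to remove_request()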

The BatchUpdate data structure is no longer necessary.

Defining a logits processor

To define a logits processor, you would subclass vllm.v1.sample.logits_processor.LogitsProcessor and define (at minimum) the following methods (an illustrative sketch follows this list):

  • __init__(...), is_argmax_invariant(...): unchanged
  • apply(self, logits: torch.Tensor) -> torch.Tensor:
    • Essentially unchanged from the current specification. However, note that the mapping from logits rows to requests is determined by the current stepwise ordering.
    • Thus update_state() must leverage its index_mapping argument such that apply() knows how to transform each row of logits
  • add_request(self, req_info: AddedRequest) -> None
    • Possible outcomes: (1) add info about a new request to the logits processor. (2) no change.
  • remove_request(self, req_pdx: RemovedRequest) -> None
    • Remove information about a specified request from logits processor internal state, if such information is present
  • update_state(self, index_mapping: Optional[npt.NDArray], token_ids: npt.NDArray, seq_lens: npt.NDArray) -> None:
    • index_mapping: optional stepwise request ordering, as a 1D numpy array with length equal to the size of the persistent batch. None means the stepwise ordering is unchanged from the previous step
    • token_ids: a (persistent batch size) x (max num tokens) numpy array. Each row index corresponds to the persistent index of a request in the persistent batch; the contents of the row are the prompt ids and output ids concatenated for that request
    • seq_lens: a 1D numpy array with length equal to the size of the persistent batch. Each index corresponds to the persistent index of a request in the persistent batch; the value at each index is the total sequence length (prompt + output tokens)
    • Note: If index_mapping is None, the LogitsProcessor might still want to update its state based on the updated token_ids or seq_lens.
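
To make the proposed model concrete, below is a minimal sketch of a hypothetical min-tokens-style processor written against the new interface. It keeps per-request state keyed by persistent index, assumes that index_mapping[persistent index] gives that request's row in the current stepwise ordering (with None meaning the previous ordering is reused), and masks a request's stop tokens until min_tokens outputs have been generated. The class name, conventions, and import paths are illustrative, not a final API.

from typing import Optional

import numpy.typing as npt
import torch

from vllm import SamplingParams
from vllm.config import VllmConfig
from vllm.v1.sample.logits_processor import LogitsProcessor

# Per the supporting data structures above
RemovedRequest = int
AddedRequest = tuple[SamplingParams, int, int]  # (params, prompt length, persistent index)


class MinTokensSketch(LogitsProcessor):
    """Illustrative only: mask stop tokens until min_tokens are generated."""

    def __init__(self, vllm_config: VllmConfig, device: torch.device,
                 is_pin_memory: bool):
        # persistent index -> (min_tokens, prompt_len, stop_token_ids)
        self.reqs: dict[int, tuple[int, int, list[int]]] = {}
        self.index_mapping: Optional[npt.NDArray] = None
        self.seq_lens: Optional[npt.NDArray] = None

    def is_argmax_invariant(self) -> bool:
        # Masking stop tokens can change the argmax.
        return False

    def add_request(self, req_info: AddedRequest) -> None:
        params, prompt_len, pdx = req_info
        if params.min_tokens and params.stop_token_ids:
            self.reqs[pdx] = (params.min_tokens, prompt_len,
                              list(params.stop_token_ids))

    def remove_request(self, req_pdx: RemovedRequest) -> None:
        self.reqs.pop(req_pdx, None)

    def update_state(self, index_mapping: Optional[npt.NDArray],
                     token_ids: npt.NDArray,
                     seq_lens: npt.NDArray) -> None:
        if index_mapping is not None:
            # A new stepwise ordering; None means reuse the previous one.
            self.index_mapping = index_mapping
        self.seq_lens = seq_lens
        # token_ids is unused here, but a processor that inspects generated
        # tokens (e.g. bad-words filtering) would read it at this point.

    def apply(self, logits: torch.Tensor) -> torch.Tensor:
        if not self.reqs or self.index_mapping is None:
            return logits
        for pdx, (min_tokens, prompt_len, stop_ids) in self.reqs.items():
            # Assumed convention: index_mapping[persistent index] gives the
            # row of `logits` under the current stepwise ordering.
            row = int(self.index_mapping[pdx])
            num_output = int(self.seq_lens[pdx]) - prompt_len
            if num_output < min_tokens:
                logits[row, stop_ids] = float("-inf")
        return logits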

Why this new approach is beneficial:

While the logits processor implementor still has to consider details of batch ordering, it is much easier to apply what is effectively a single permutation (index_mapping) than to process a complex sequence of add/remove/reorder operations as was required previously.
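
As a small, self-contained illustration of this simplification, assuming index_mapping[persistent index] gives the stepwise row, per-request state kept in persistent order can be rearranged into stepwise order with a single indexing operation:

import numpy as np

# Per-request state keyed by persistent index (e.g. min_tokens thresholds).
state_by_persistent_idx = np.array([16, 0, 4, 8])

# Hypothetical stepwise ordering: persistent index i occupies row index_mapping[i].
index_mapping = np.array([2, 0, 3, 1])

# Scatter persistent-ordered state into stepwise order in one operation.
state_by_stepwise_row = np.empty_like(state_by_persistent_idx)
state_by_stepwise_row[index_mapping] = state_by_persistent_idx

print(state_by_stepwise_row)  # [ 0  8 16  4]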

Feedback Period.

1 Week

CC List.

@njhill @WoosukKwon @simon-mo @robertgshaw2-redhat

Any Other Things.

No response

