Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/offline_inference/logits_processor/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ class object.
BatchUpdate,
LogitsProcessor,
)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
from vllm.v1.sample.logits_processor.interface import process_dict_updates


# Hypothetical custom logits processor
Expand Down
2 changes: 1 addition & 1 deletion tests/v1/logits_processors/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
AdapterLogitsProcessor,
BatchUpdate, LogitsProcessor,
RequestLogitsProcessor)
from vllm.v1.sample.logits_processor.builtin import process_dict_updates
from vllm.v1.sample.logits_processor.interface import process_dict_updates

logger = init_logger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions vllm/v1/sample/logits_processor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,11 @@
from vllm.sampling_params import SamplingParams
from vllm.v1.sample.logits_processor.builtin import (LogitBiasLogitsProcessor,
MinPLogitsProcessor,
MinTokensLogitsProcessor,
process_dict_updates)
MinTokensLogitsProcessor)
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
MoveDirectionality,
process_dict_updates)
from vllm.v1.sample.logits_processor.state import (BatchUpdateBuilder,
LogitsProcessors)

Expand All @@ -36,9 +36,9 @@
LOGITSPROCS_GROUP = 'vllm.logits_processors'

BUILTIN_LOGITS_PROCESSORS: list[type[LogitsProcessor]] = [
MinTokensLogitsProcessor,

Check failure on line 39 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]

Check failure on line 39 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]

Check failure on line 39 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]
LogitBiasLogitsProcessor,

Check failure on line 40 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]

Check failure on line 40 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]
MinPLogitsProcessor,

Check failure on line 41 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]

Check failure on line 41 in vllm/v1/sample/logits_processor/__init__.py

View workflow job for this annotation

GitHub Actions / pre-commit

Only concrete class can be given where "type[LogitsProcessor[Any]]" is expected [type-abstract]
]


Expand Down
49 changes: 3 additions & 46 deletions vllm/v1/sample/logits_processor/builtin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,19 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from collections.abc import Sequence
from typing import TYPE_CHECKING, Callable, Optional, TypeVar
from typing import TYPE_CHECKING, Optional

import torch

from vllm import SamplingParams
from vllm.v1.sample.logits_processor.interface import (BatchUpdate,
LogitsProcessor,
MoveDirectionality)
MoveDirectionality,
process_dict_updates)

if TYPE_CHECKING:
from vllm.config import VllmConfig

T = TypeVar("T")


class MinPLogitsProcessor(LogitsProcessor):

Expand Down Expand Up @@ -231,45 +230,3 @@ def apply(self, logits: torch.Tensor) -> torch.Tensor:
# Inhibit EOS token for requests which have not reached min length
logits[self.logits_slice] = -float("inf")
return logits


def process_dict_updates(
    req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
    new_state: Callable[[SamplingParams, Optional[list[int]], list[int]],
                        Optional[T]]
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Args:
        req_entries: mapping from persistent-batch index to per-request
            state; mutated in place to reflect the batch update.
        batch_update: description of added/removed/moved requests, or
            `None` if the batch makeup is unchanged.
        new_state: callback producing initial state for a newly-added
            request, or `None` if the processor does not apply to it.

    Returns:
        `True` if `req_entries` was modified, `False` otherwise.
    """

    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids,
                               output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # A stale entry from a previous request occupied this slot.
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            # Compare against None (not truthiness) so that falsy-but-valid
            # state values (e.g. 0, "", empty containers) still register as
            # an update when evicted; mirrors the `added` loop above.
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                # For a one-way move, b's old entry is simply discarded.
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated
92 changes: 83 additions & 9 deletions vllm/v1/sample/logits_processor/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from collections.abc import Sequence
from dataclasses import dataclass
from enum import Enum, auto
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Callable, Generic, Optional, TypeVar

import torch

Expand All @@ -13,6 +13,8 @@
if TYPE_CHECKING:
from vllm.config import VllmConfig

T = TypeVar("T")


class MoveDirectionality(Enum):
# One-way i1->i2 req move within batch
Expand Down Expand Up @@ -56,12 +58,13 @@
moved: Sequence[MovedRequest]


class LogitsProcessor(ABC):
class LogitsProcessor(ABC, Generic[T]):

@abstractmethod
def __init__(self, vllm_config: "VllmConfig", device: torch.device,
is_pin_memory: bool) -> None:
raise NotImplementedError

# Per-request logits processor state
self.states: dict[int, T] = {}

@abstractmethod
def apply(self, logits: torch.Tensor) -> torch.Tensor:
Expand All @@ -72,17 +75,41 @@
"""
raise NotImplementedError

@abstractmethod
def is_argmax_invariant(self) -> bool:
"""True if logits processor has no impact on the
argmax computation in greedy sampling.
"""True if logits processor has no impact on the argmax computation in
greedy sampling; causes logits processor to be optimized away in greedy
sampling scenarios. Base-class default is false but can be overridden by
subclass.
NOTE: may or may not have the same value for all
instances of a given LogitsProcessor subclass,
depending on subclass implementation.
"""
raise NotImplementedError
return False

@abstractmethod
def get_state_from_params(self, params: SamplingParams,
prompt_tok_ids: list[int],
out_tok_ids: list[int]) -> Optional[T]:
"""Produce a minimal representation of initial logits processor state
for a newly-added request

Args:
params: `SamplingParams` instance for request newly-added to batch
prompt_tok_ids: list of new request prompt token ids
out_tok_ids: list of request generated tokens as of current engine
step

Returns:
`None` if logits processor is not applicable to request; otherwise,
instance of initial logits processor state representation
"""
raise NotImplementedError

def state_update_callback(self) -> None:
"""Override to implement specialized optimizations to logits processor
state management."""
pass

def update_state(
self,
batch_update: Optional["BatchUpdate"],
Expand All @@ -94,4 +121,51 @@
batch_update: Non-None iff there have been changes
to the batch makeup.
"""
raise NotImplementedError
needs_update = process_dict_updates(self.states, batch_update,
self.get_state_from_params)

Check failure on line 125 in vllm/v1/sample/logits_processor/interface.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 3 to "process_dict_updates" has incompatible type "Callable[[Any, list[int], list[int]], T | None]"; expected "Callable[[Any, list[int] | None, list[int]], T | None]" [arg-type]

Check failure on line 125 in vllm/v1/sample/logits_processor/interface.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 3 to "process_dict_updates" has incompatible type "Callable[[Any, list[int], list[int]], T | None]"; expected "Callable[[Any, list[int] | None, list[int]], T | None]" [arg-type]

Check failure on line 125 in vllm/v1/sample/logits_processor/interface.py

View workflow job for this annotation

GitHub Actions / pre-commit

Argument 3 to "process_dict_updates" has incompatible type "Callable[[Any, list[int], list[int]], Optional[T]]"; expected "Callable[[Any, Optional[list[int]], list[int]], Optional[T]]" [arg-type]

if needs_update:
# Apply custom subclass state-update logic, if any
self.state_update_callback()


def process_dict_updates(
    req_entries: dict[int, T], batch_update: Optional[BatchUpdate],
    new_state: Callable[[SamplingParams, Optional[list[int]], list[int]],
                        Optional[T]]
) -> bool:
    """Utility function to update dict state for sparse LogitsProcessors.

    Args:
        req_entries: mapping from persistent-batch index to per-request
            state; mutated in place to reflect the batch update.
        batch_update: description of added/removed/moved requests, or
            `None` if the batch makeup is unchanged.
        new_state: callback producing initial state for a newly-added
            request, or `None` if the processor does not apply to it.

    Returns:
        `True` if `req_entries` was modified, `False` otherwise.
    """

    if not batch_update:
        # Nothing to do.
        return False

    updated = False
    for index, params, prompt_tok_ids, output_tok_ids in batch_update.added:
        if (state := new_state(params, prompt_tok_ids,
                               output_tok_ids)) is not None:
            req_entries[index] = state
            updated = True
        elif req_entries.pop(index, None) is not None:
            # A stale entry from a previous request occupied this slot.
            updated = True

    if req_entries:
        # Process removed requests.
        for index in batch_update.removed:
            # Compare against None (not truthiness) so that falsy-but-valid
            # state values (e.g. 0, "", empty containers) still register as
            # an update when evicted; mirrors the `added` loop above.
            if req_entries.pop(index, None) is not None:
                updated = True

        # Process moved requests, unidirectional (a->b) and
        # swapped (a<->b)
        for a_index, b_index, direct in batch_update.moved:
            a_entry = req_entries.pop(a_index, None)
            b_entry = req_entries.pop(b_index, None)
            if a_entry is not None:
                req_entries[b_index] = a_entry
                updated = True
            if b_entry is not None:
                # For a one-way move, b's old entry is simply discarded.
                updated = True
                if direct == MoveDirectionality.SWAP:
                    req_entries[a_index] = b_entry

    return updated
Loading