Skip to content

[Core]: Option To Use Prompt Token Ids Inside Logits Processor #4985

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 15 commits into from
May 23, 2024
Merged
17 changes: 14 additions & 3 deletions vllm/model_executor/layers/logits_processor.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional

import torch
Expand Down Expand Up @@ -95,15 +96,25 @@ def _apply_logits_processors(
seq_ids = seq_group.seq_ids
sampling_params = seq_group.sampling_params
logits_processors = sampling_params.logits_processors

if logits_processors:
found_logits_processors = True

for seq_id, logits_row_idx in zip(seq_ids,
seq_group.sample_indices):
logits_row = logits[logits_row_idx]
token_ids = seq_group.seq_data[seq_id].output_token_ids
past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

for logits_processor in logits_processors:
logits_row = logits_processor(token_ids, logits_row)
parameters = inspect.signature(logits_processor).parameters
if len(parameters) == 3:
logits_row = logits_processor(prompt_tokens_ids,
past_tokens_ids,
logits_row)
else:
logits_row = logits_processor(past_tokens_ids,
logits_row)

logits[logits_row_idx] = logits_row

logits_processed += len(seq_group.sample_indices) + len(
Expand Down
15 changes: 10 additions & 5 deletions vllm/sampling_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,14 @@ class SamplingType(IntEnum):
BEAM = 3


LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
"""LogitsProcessor is a function that takes a list of previously generated
tokens and a tensor of the logits for the next token, and returns a modified
tensor of logits to sample from."""
LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
Callable[[List[int], List[int], torch.Tensor],
torch.Tensor]]
"""LogitsProcessor is a function that takes a list
of previously generated tokens, the logits tensor
for the next token and, optionally, prompt tokens as a
first argument, and returns a modified tensor of logits
to sample from."""


class SamplingParams:
Expand Down Expand Up @@ -95,7 +99,8 @@ class SamplingParams:
spaces_between_special_tokens: Whether to add spaces between special
tokens in the output. Defaults to True.
logits_processors: List of functions that modify logits based on
previously generated tokens.
previously generated tokens, and optionally prompt tokens as
a first argument.
truncate_prompt_tokens: If set to an integer k, will use only the last k
tokens from the prompt (i.e., left truncation). Defaults to None
(i.e., no truncation).
Expand Down
Loading