Commit
[Core]: Option To Use Prompt Token Ids Inside Logits Processor (vllm-project#4985)

Co-authored-by: Elisei Smirnov <el.smirnov@innopolis.university>
2 people authored and robertgshaw2-neuralmagic committed Jul 14, 2024
1 parent db8e4c4 commit 0dc8e26
Showing 2 changed files with 24 additions and 8 deletions.
17 changes: 14 additions & 3 deletions vllm/model_executor/layers/logits_processor.py
@@ -1,4 +1,5 @@
"""A layer that compute logits from hidden_stats."""
import inspect
from typing import Optional

import torch
@@ -95,15 +96,25 @@ def _apply_logits_processors(
         seq_ids = seq_group.seq_ids
         sampling_params = seq_group.sampling_params
         logits_processors = sampling_params.logits_processors

         if logits_processors:
             found_logits_processors = True

             for seq_id, logits_row_idx in zip(seq_ids,
                                               seq_group.sample_indices):
                 logits_row = logits[logits_row_idx]
-                token_ids = seq_group.seq_data[seq_id].output_token_ids
+                past_tokens_ids = seq_group.seq_data[seq_id].output_token_ids
+                prompt_tokens_ids = seq_group.seq_data[seq_id].prompt_token_ids

                 for logits_processor in logits_processors:
-                    logits_row = logits_processor(token_ids, logits_row)
+                    parameters = inspect.signature(logits_processor).parameters
+                    if len(parameters) == 3:
+                        logits_row = logits_processor(prompt_tokens_ids,
+                                                      past_tokens_ids,
+                                                      logits_row)
+                    else:
+                        logits_row = logits_processor(past_tokens_ids,
+                                                      logits_row)

                 logits[logits_row_idx] = logits_row

         logits_processed += len(seq_group.sample_indices) + len(
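In short, the dispatch above inspects each processor's call signature: a callable with three positional parameters is invoked as (prompt_token_ids, past_token_ids, logits), while a two-parameter callable keeps the original (past_token_ids, logits) form. A minimal sketch of both shapes follows; the function names and the banning heuristic are illustrative only and are not part of this commit.

```python
from typing import List

import torch


def ban_prompt_tokens(prompt_token_ids: List[int],
                      past_token_ids: List[int],
                      logits: torch.Tensor) -> torch.Tensor:
    # Three parameters, so the dispatch above (len(parameters) == 3) passes
    # the prompt token ids as the first argument. Here we simply forbid
    # sampling any token that already appeared in the prompt.
    logits[list(set(prompt_token_ids))] = -float("inf")
    return logits


def sharpen_late(past_token_ids: List[int],
                 logits: torch.Tensor) -> torch.Tensor:
    # The legacy two-parameter form keeps working unchanged: sharpen the
    # distribution once 20 tokens have been generated.
    return logits / 0.7 if len(past_token_ids) > 20 else logits
```

Note that the dispatch counts the parameters reported by inspect.signature, so a callable with some arguments already bound (e.g. via functools.partial) is routed by the parameters that remain visible.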
15 changes: 10 additions & 5 deletions vllm/sampling_params.py
@@ -18,10 +18,14 @@ class SamplingType(IntEnum):
     BEAM = 3


-LogitsProcessor = Callable[[List[int], torch.Tensor], torch.Tensor]
-"""LogitsProcessor is a function that takes a list of previously generated
-tokens and a tensor of the logits for the next token, and returns a modified
-tensor of logits to sample from."""
+LogitsProcessor = Union[Callable[[List[int], torch.Tensor], torch.Tensor],
+                        Callable[[List[int], List[int], torch.Tensor],
+                                 torch.Tensor]]
+"""LogitsProcessor is a function that takes a list
+of previously generated tokens, the logits tensor
+for the next token and, optionally, prompt tokens as a
+first argument, and returns a modified tensor of logits
+to sample from."""


 class SamplingParams:
@@ -95,7 +99,8 @@ class SamplingParams:
         spaces_between_special_tokens: Whether to add spaces between special
             tokens in the output. Defaults to True.
         logits_processors: List of functions that modify logits based on
-            previously generated tokens.
+            previously generated tokens, and optionally prompt tokens as
+            a first argument.
         truncate_prompt_tokens: If set to an integer k, will use only the last k
             tokens from the prompt (i.e., left truncation). Defaults to None
             (i.e., no truncation).
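For context, a processor of either shape is passed in through SamplingParams.logits_processors, which is what the docstring change above documents. Below is a minimal end-to-end sketch assuming the standard vllm offline entry points (LLM, SamplingParams, LLM.generate); the model name is a placeholder and ban_prompt_tokens is the hypothetical processor from the earlier sketch, not part of this commit.

```python
from typing import List

import torch
from vllm import LLM, SamplingParams


def ban_prompt_tokens(prompt_token_ids: List[int], past_token_ids: List[int],
                      logits: torch.Tensor) -> torch.Tensor:
    # Three-argument form: receives the prompt token ids introduced here.
    logits[list(set(prompt_token_ids))] = -float("inf")
    return logits


llm = LLM(model="facebook/opt-125m")  # placeholder model
params = SamplingParams(max_tokens=32, logits_processors=[ban_prompt_tokens])
outputs = llm.generate(["The capital of France is"], params)
print(outputs[0].outputs[0].text)
```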
