
Commit ad51984

rebase, address comments
Signed-off-by: 22quinn <33176974+22quinn@users.noreply.github.com>
1 parent e1f0835 commit ad51984

9 files changed: +221 additions, -23 deletions

tests/v1/sample/test_sampler.py

Lines changed: 72 additions & 0 deletions
@@ -77,6 +77,46 @@ def _create_allowed_token_ids(
     return mask
 
 
+def _create_bad_words_token_ids(
+        batch_size: int, vocab_size: int,
+        bad_words_lengths: list[tuple[int]]) -> dict[int, list[list[int]]]:
+    bad_words_token_ids = {}
+    for batch_idx in range(batch_size):
+        token_ids_single_batch = []
+        for bad_words_length in bad_words_lengths:
+            token_ids = np.random.choice(vocab_size,
+                                         size=bad_words_length,
+                                         replace=True).tolist()
+            token_ids_single_batch.append(token_ids)
+        bad_words_token_ids[batch_idx] = token_ids_single_batch
+    return bad_words_token_ids
+
+
+def _update_output_token_ids_for_bad_words(metadata: SamplingMetadata,
+                                           vocab_size: int) -> list[list[int]]:
+    bad_words_last_tokens = []
+    for batch_idx in range(len(metadata.bad_words_token_ids)):
+        bad_words_token_ids = metadata.bad_words_token_ids[batch_idx]
+        output_token_ids = metadata.output_token_ids[batch_idx]
+        bad_words_last_token: list[int] = []
+        for i, bad_word_token_ids in enumerate(bad_words_token_ids):
+            if len(bad_word_token_ids) == 1:
+                # Single token id always affects logits
+                bad_words_last_token.append(bad_word_token_ids[0])
+            else:
+                prefix_length = len(bad_word_token_ids) - 1
+                has_bad_words = np.random.choice([True, False])
+                if has_bad_words:
+                    output_token_ids[-prefix_length:] = bad_word_token_ids[:-1]
+                    bad_words_last_token.append(bad_word_token_ids[-1])
+                    break  # Maximum one update to output_token_ids
+                else:  # Make sure no accidental match to bad words
+                    output_token_ids[-1] = (bad_word_token_ids[-2] +
+                                            1) % vocab_size
+        bad_words_last_tokens.append(bad_words_last_token)
+    return bad_words_last_tokens
+
+
 def _create_default_sampling_metadata(
     num_output_tokens: int,
     batch_size: int,
@@ -112,6 +152,7 @@ def _create_default_sampling_metadata(
         min_tokens={},
         logit_bias=[None] * batch_size,
         allowed_token_ids_mask=None,
+        bad_words_token_ids={},
     )
     return fake_sampling_metadata
 
@@ -467,3 +508,34 @@ def test_sampler_allowed_token_ids(device: str, batch_size: int,
                     "inf"), f"{batch_idx}, {token_id}"
             else:
                 assert logits_for_req[token_id] != -float("inf")
+
+
+@pytest.mark.parametrize("device", CUDA_DEVICES)
+@pytest.mark.parametrize("batch_size", [1, 2, 32])
+@pytest.mark.parametrize("bad_words_lengths", [(1, ), (1, 3), (2, 2)])
+def test_sampler_bad_words(device: str, batch_size: int,
+                           bad_words_lengths: list[tuple[int]]):
+    """
+    Test to verify that when the bad words restriction is present, tokens
+    are penalized based on their match with the bad words.
+    """
+    torch.set_default_device(device)
+    # Create fake logits where each token is assigned the same
+    # logit value.
+    fake_logits = _create_fake_logits(batch_size, VOCAB_SIZE)
+    sampling_metadata = _create_default_sampling_metadata(
+        NUM_OUTPUT_TOKENS, batch_size, VOCAB_SIZE, torch.device(device))
+    sampling_metadata.bad_words_token_ids = _create_bad_words_token_ids(
+        batch_size, VOCAB_SIZE, bad_words_lengths)
+    bad_words_last_tokens = _update_output_token_ids_for_bad_words(
+        sampling_metadata, VOCAB_SIZE)
+    sampler = Sampler()
+    logits = sampler.apply_bad_words(fake_logits, sampling_metadata)
+    logits = logits.cpu()
+    for batch_idx in range(batch_size):
+        logits_for_req = logits[batch_idx]
+        for token_id in range(VOCAB_SIZE):
+            if token_id in bad_words_last_tokens[batch_idx]:
+                assert logits_for_req[token_id] == -float("inf")
+            else:
+                assert logits_for_req[token_id] != -float("inf")

tests/v1/worker/test_gpu_input_batch.py

Lines changed: 6 additions & 0 deletions
@@ -100,6 +100,7 @@ def _construct_expected_sampling_metadata(
                                          VOCAB_SIZE,
                                          dtype=torch.bool,
                                          device=device)
+    bad_words_token_ids = {}
     for req in reqs:
         if req.req_id not in req_ids_retained:
             continue
@@ -123,6 +124,8 @@ def _construct_expected_sampling_metadata(
         if req.sampling_params.allowed_token_ids:
             allowed_token_ids_mask[index_in_input_batch][
                 req.sampling_params.allowed_token_ids] = True
+        bad_words_token_ids[
+            index_in_input_batch] = req.sampling_params.bad_words_token_ids
 
     return SamplingMetadata(
         temperature=torch.tensor(temperature, dtype=torch.float,
@@ -159,6 +162,7 @@ def _construct_expected_sampling_metadata(
                           and all(x == 1 for x in repetition_penalties)),
         logit_bias=logit_bias,
        allowed_token_ids_mask=allowed_token_ids_mask,
+        bad_words_token_ids=bad_words_token_ids,
     )
 
 
@@ -284,6 +288,8 @@ def same(t1: Optional[torch.Tensor], t2: Optional[torch.Tensor]) -> bool:
     assert torch.allclose(
         expected_sampling_metadata.allowed_token_ids_mask,
         sampling_metadata.allowed_token_ids_mask)
+    assert expected_sampling_metadata.bad_words_token_ids == \
+        sampling_metadata.bad_words_token_ids
 
 
 @pytest.mark.parametrize("device", CUDA_DEVICES)

vllm/sampling_params.py

Lines changed: 51 additions & 1 deletion
@@ -11,6 +11,8 @@
 
 from vllm.logger import init_logger
 from vllm.logits_process import LogitsProcessor
+from vllm.transformers_utils.tokenizer import AnyTokenizer
+from vllm.transformers_utils.tokenizers.mistral import MistralTokenizer
 
 logger = init_logger(__name__)
 
@@ -199,7 +201,6 @@ class SamplingParams(
     seed: Optional[int] = None
     stop: Optional[Union[str, list[str]]] = None
     stop_token_ids: Optional[list[int]] = None
-    bad_words: Optional[list[str]] = None
     ignore_eos: bool = False
     max_tokens: Optional[int] = 16
     min_tokens: int = 0
@@ -228,6 +229,10 @@ class SamplingParams(
     logit_bias: Optional[dict[int, float]] = None
     allowed_token_ids: Optional[list[int]] = None
 
+    # Fields used for bad words
+    bad_words: Optional[list[str]] = None
+    _bad_words_token_ids: list[list[int]] = msgspec.field(default_factory=list)
+
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
@@ -458,6 +463,46 @@ def update_from_generation_config(
                     eos_ids.update(self.stop_token_ids)
                 self.stop_token_ids = list(eos_ids)
 
+    def update_from_tokenizer(self, tokenizer: AnyTokenizer) -> None:
+        if self.bad_words is None:
+            return
+        for bad_word in self.bad_words:
+            # To prohibit words both at the beginning
+            # and in the middle of text
+            # (related to add_prefix_space tokenizer parameter)
+            for add_prefix_space in [False, True]:
+                prefix = " " if add_prefix_space else ""
+                prompt = prefix + bad_word.lstrip()
+
+                if isinstance(tokenizer, MistralTokenizer):
+                    # Mistral tokenizers should not add special tokens
+                    prompt_token_ids = tokenizer.encode(text=prompt)
+                else:
+                    prompt_token_ids = tokenizer.encode(
+                        text=prompt, add_special_tokens=False)
+
+                # If no space at the beginning
+                # or if prefix space produces a new word token
+                if (not add_prefix_space) or (
+                        add_prefix_space and prompt_token_ids[0]
+                        != self._bad_words_token_ids[-1][0]
+                        and len(prompt_token_ids) == len(
+                            self._bad_words_token_ids[-1])):
+                    self._bad_words_token_ids.append(prompt_token_ids)
+
+        invalid_token_ids = [
+            token_id for bad_words_token_ids in self._bad_words_token_ids
+            for token_id in bad_words_token_ids
+            if token_id < 0 or token_id > tokenizer.max_token_id
+        ]
+        if len(invalid_token_ids) > 0:
+            raise ValueError(
+                f"The model vocabulary size is {tokenizer.max_token_id+1},"
+                f" but the following tokens"
+                f" were specified as bad: {invalid_token_ids}."
+                f" All token id values should be integers satisfying:"
+                f" 0 <= token_id <= {tokenizer.max_token_id}.")
+
     @cached_property
     def sampling_type(self) -> SamplingType:
         if self.temperature < _SAMPLING_EPS:
@@ -470,6 +515,11 @@ def sampling_type(self) -> SamplingType:
     def all_stop_token_ids(self) -> set[int]:
         return self._all_stop_token_ids
 
+    @property
+    def bad_words_token_ids(self) -> list[list[int]]:
+        # For internal use only. Backward compatibility not guaranteed
+        return self._bad_words_token_ids
+
     def clone(self) -> "SamplingParams":
         """Deep copy, but maybe not the LogitsProcessor objects.
 
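To see what these two additions produce together, here is a minimal sketch (illustrative only: it assumes get_tokenizer from vllm.transformers_utils.tokenizer and the example model name "gpt2"; the resulting token ids depend entirely on the tokenizer, so none are asserted):

from vllm import SamplingParams
from vllm.transformers_utils.tokenizer import get_tokenizer

# Example model name; any tokenizer reachable by get_tokenizer works.
tokenizer = get_tokenizer("gpt2")
params = SamplingParams(bad_words=["forbidden phrase"])
params.update_from_tokenizer(tokenizer)

# Each bad word may contribute up to two sequences: one tokenized without a
# leading space and one with it (the add_prefix_space handling above).
print(params.bad_words_token_ids)

Deferring tokenization to update_from_tokenizer is why bad_words stays a plain list of strings on SamplingParams: the engine only converts it to token ids once a concrete tokenizer is known, as wired up in the processor change below.
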
vllm/v1/engine/processor.py

Lines changed: 2 additions & 0 deletions
@@ -209,6 +209,8 @@ def process_inputs(
         sampling_params = params.clone()
         sampling_params.update_from_generation_config(
             self.generation_config_fields, eos_token_id)
+        sampling_params.update_from_tokenizer(
+            self.tokenizer.get_lora_tokenizer(lora_request))
 
         # Multimodal related.
         # Compute MM hashes (if enabled)

vllm/v1/sample/metadata.py

Lines changed: 3 additions & 0 deletions
@@ -38,3 +38,6 @@ class SamplingMetadata:
     # `allowed_token_ids_mask` is a 2D bool tensor of shape (max batch size,
     # vocab size).
     allowed_token_ids_mask: Optional[torch.Tensor]
+
+    # req_index -> bad_words_token_ids
+    bad_words_token_ids: dict[int, list[list[int]]]

vllm/v1/sample/ops/bad_words.py

Lines changed: 38 additions & 0 deletions
@@ -0,0 +1,38 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import torch
+
+_SMALLEST_LOGIT = float("-inf")
+
+
+def _apply_bad_words_single_batch(
+    logits: torch.Tensor,
+    bad_words_token_ids: list[list[int]],
+    past_tokens_ids: list[int],
+) -> None:
+    for bad_word_ids in bad_words_token_ids:
+        if len(bad_word_ids) > len(past_tokens_ids) + 1:
+            continue
+
+        prefix_length = len(bad_word_ids) - 1
+        last_token_id = bad_word_ids[-1]
+        if prefix_length > 0:
+            actual_prefix = past_tokens_ids[-prefix_length:]
+        else:
+            actual_prefix = []
+        expected_prefix = bad_word_ids[:prefix_length]
+
+        assert len(actual_prefix) == len(expected_prefix)
+
+        if actual_prefix == expected_prefix:
+            logits[last_token_id] = _SMALLEST_LOGIT
+
+
+def apply_bad_words(
+    logits: torch.Tensor,
+    bad_words_token_ids: dict[int, list[list[int]]],
+    past_tokens_ids: list[list[int]],
+) -> None:
+    for i in range(logits.shape[0]):
+        _apply_bad_words_single_batch(logits[i], bad_words_token_ids[i],
+                                      past_tokens_ids[i])

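A small worked example of the masking rule above, as an illustrative sketch over a made-up six-token vocabulary: a single-token bad word is always masked, while the last token of a longer bad word is masked only when the rest of the sequence matches the tail of the tokens generated so far.

import torch

from vllm.v1.sample.ops.bad_words import apply_bad_words

# Two requests over a toy 6-token vocabulary, uniform logits.
logits = torch.zeros(2, 6)

# Request 0 bans token [3] and the sequence [1, 2]; request 1 bans [4, 5].
bad_words_token_ids = {0: [[3], [1, 2]], 1: [[4, 5]]}

# Tokens generated so far by each request.
past_tokens_ids = [[0, 1], [0, 0]]

apply_bad_words(logits, bad_words_token_ids, past_tokens_ids)

# Request 0: token 3 is always masked; token 2 is masked because the last
# generated token matches the prefix [1] of the bad word [1, 2].
assert logits[0, 3] == float("-inf") and logits[0, 2] == float("-inf")
# Request 1: the prefix [4] does not match the tail [0], so token 5 is untouched.
assert logits[1, 5] == 0.0
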
vllm/v1/sample/sampler.py

Lines changed: 16 additions & 0 deletions
@@ -6,6 +6,7 @@
 
 from vllm.v1.outputs import LogprobsTensors, SamplerOutput
 from vllm.v1.sample.metadata import SamplingMetadata
+from vllm.v1.sample.ops.bad_words import apply_bad_words
 from vllm.v1.sample.ops.penalties import (apply_all_penalties,
                                           apply_min_token_penalties)
 from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler
@@ -38,6 +39,8 @@ def forward(
         logits = logits.to(torch.float32)
         # Apply allowed token ids.
         logits = self.apply_allowed_token_ids(logits, sampling_metadata)
+        # Apply bad words exclusion.
+        logits = self.apply_bad_words(logits, sampling_metadata)
         # Apply logits bias.
         logits = self.apply_logits_bias(logits, sampling_metadata)
         # Apply penalties (e.g., min_tokens, freq_penalties).
@@ -237,3 +240,16 @@ def apply_allowed_token_ids(
             logits.masked_fill_(sampling_metadata.allowed_token_ids_mask,
                                 float("-inf"))
         return logits
+
+    def apply_bad_words(
+        self,
+        logits: torch.Tensor,
+        sampling_metadata: SamplingMetadata,
+    ) -> torch.Tensor:
+        if sampling_metadata.bad_words_token_ids:
+            apply_bad_words(
+                logits,
+                sampling_metadata.bad_words_token_ids,
+                sampling_metadata.output_token_ids,
+            )
+        return logits

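End to end, the feature is exercised through the usual generate path; a minimal sketch, assuming a V1 engine build and an example model name (swap in whichever model you have locally):

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # example model, illustration only
params = SamplingParams(max_tokens=32,
                        bad_words=["badword", "multi token phrase"])
outputs = llm.generate(["The quick brown fox"], params)
# The banned strings should not appear in the output: whenever their prefix
# has just been generated, the final token of the sequence is masked to -inf.
print(outputs[0].outputs[0].text)
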