
Commit a4f1ee3

vincent-4, b8zhong, and hmellor authored
Deprecate best_of Sampling Parameter in anticipation of vLLM V1 (#13997)
Signed-off-by: vincent-4 <vincentzhongy+githubvincent4@gmail.com>
Signed-off-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Signed-off-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
Co-authored-by: Brayden Zhong <b8zhong@uwaterloo.ca>
Co-authored-by: Harry Mellor <19981378+hmellor@users.noreply.github.com>
1 parent a32c866 commit a4f1ee3
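
Migration note: this change removes the `best_of` field from `SamplingParams` and from the OpenAI-compatible request models, and drops the benchmark `--best-of` flag; `n` alone now controls how many sequences are returned. A minimal before/after sketch (illustrative only, not taken from the diff):

    # Sketch: parameter values below are placeholders.
    from vllm import SamplingParams

    # Before this change:
    #   params = SamplingParams(n=2, best_of=5, temperature=0.8, top_p=0.95)
    # After this change, `best_of` is gone; request the candidates you need with `n`:
    params = SamplingParams(n=2, temperature=0.8, top_p=0.95)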

File tree

12 files changed: +16 additions, -88 deletions


benchmarks/backend_request_func.py

Lines changed: 0 additions & 5 deletions
@@ -27,7 +27,6 @@ class RequestFuncInput:
     output_len: int
     model: str
     model_name: Optional[str] = None
-    best_of: int = 1
     logprobs: Optional[int] = None
     extra_body: Optional[dict] = None
     multi_modal_content: Optional[dict] = None
@@ -58,7 +57,6 @@ async def async_request_tgi(
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
         params = {
-            "best_of": request_func_input.best_of,
             "max_new_tokens": request_func_input.output_len,
             "do_sample": True,
             "temperature": 0.01,  # TGI does not accept 0.0 temperature.
@@ -130,7 +128,6 @@ async def async_request_trt_llm(
 
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
         payload = {
             "accumulate_tokens": True,
             "text_input": request_func_input.prompt,
@@ -195,7 +192,6 @@ async def async_request_deepspeed_mii(
 ) -> RequestFuncOutput:
     async with aiohttp.ClientSession(trust_env=True,
                                      timeout=AIOHTTP_TIMEOUT) as session:
-        assert request_func_input.best_of == 1
 
         payload = {
             "prompt": request_func_input.prompt,
@@ -249,7 +245,6 @@ async def async_request_openai_completions(
             if request_func_input.model_name else request_func_input.model,
             "prompt": request_func_input.prompt,
             "temperature": 0.0,
-            "best_of": request_func_input.best_of,
             "max_tokens": request_func_input.output_len,
             "logprobs": request_func_input.logprobs,
             "stream": True,

benchmarks/benchmark_serving.py

Lines changed: 0 additions & 14 deletions
@@ -560,7 +560,6 @@ async def benchmark(
     tokenizer: PreTrainedTokenizerBase,
     input_requests: list[tuple[str, int, int]],
     logprobs: Optional[int],
-    best_of: int,
     request_rate: float,
     burstiness: float,
     disable_tqdm: bool,
@@ -592,7 +591,6 @@ async def benchmark(
         prompt_len=test_prompt_len,
         output_len=test_output_len,
         logprobs=logprobs,
-        best_of=best_of,
         multi_modal_content=test_mm_content,
         ignore_eos=ignore_eos,
     )
@@ -619,7 +617,6 @@ async def benchmark(
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=test_mm_content,
             ignore_eos=ignore_eos)
         profile_output = await request_func(request_func_input=profile_input)
@@ -668,7 +665,6 @@ async def limited_request_func(request_func_input, pbar):
             prompt_len=prompt_len,
             output_len=output_len,
             logprobs=logprobs,
-            best_of=best_of,
             multi_modal_content=mm_content,
             ignore_eos=ignore_eos)
         tasks.append(
@@ -686,7 +682,6 @@ async def limited_request_func(request_func_input, pbar):
             prompt_len=test_prompt_len,
             output_len=test_output_len,
             logprobs=logprobs,
-            best_of=best_of,
         )
         profile_output = await request_func(request_func_input=profile_input)
         if profile_output.success:
@@ -958,7 +953,6 @@ def main(args: argparse.Namespace):
         tokenizer=tokenizer,
         input_requests=input_requests,
         logprobs=args.logprobs,
-        best_of=args.best_of,
         request_rate=args.request_rate,
         burstiness=args.burstiness,
         disable_tqdm=args.disable_tqdm,
@@ -983,7 +977,6 @@ def main(args: argparse.Namespace):
         result_json["backend"] = backend
         result_json["model_id"] = model_id
         result_json["tokenizer_id"] = tokenizer_id
-        result_json["best_of"] = args.best_of
         result_json["num_prompts"] = args.num_prompts
 
         # Metadata
@@ -1081,13 +1074,6 @@ def main(args: argparse.Namespace):
         help=
         "Name or path of the tokenizer, if not using the default tokenizer.",  # noqa: E501
     )
-    parser.add_argument(
-        "--best-of",
-        type=int,
-        default=1,
-        help="Generates `best_of` sequences per prompt and "
-        "returns the best one.",
-    )
     parser.add_argument("--use-beam-search", action="store_true")
     parser.add_argument(
         "--num-prompts",

examples/offline_inference/llm_engine_example.py

Lines changed: 0 additions & 1 deletion
@@ -15,7 +15,6 @@ def create_test_prompts() -> list[tuple[str, SamplingParams]]:
          SamplingParams(temperature=0.8, top_k=5, presence_penalty=0.2)),
         ("What is the meaning of life?",
          SamplingParams(n=2,
-                        best_of=5,
                         temperature=0.8,
                         top_p=0.95,
                         frequency_penalty=0.1)),

examples/online_serving/opentelemetry/dummy_client.py

Lines changed: 0 additions & 1 deletion
@@ -28,7 +28,6 @@
     "model": "facebook/opt-125m",
     "prompt": prompt,
     "max_tokens": 10,
-    "best_of": 20,
     "n": 3,
     "use_beam_search": "true",
     "temperature": 0.0,

tests/core/test_scheduler.py

Lines changed: 0 additions & 4 deletions
@@ -617,7 +617,6 @@ def test_schedule_decode_blocks_to_copy_update():
                               num_gpu_blocks=16)
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     curr_loras = None
     scheduler._allocate_and_set_running(seq_group)
@@ -686,7 +685,6 @@ def test_schedule_swapped_cannot_swap_in():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -717,7 +715,6 @@ def test_infeasible_swap():
     for i in range(2):
         _, seq_group = create_dummy_prompt(str(i),
                                            prompt_length=60,
-                                           best_of=2,
                                            block_size=block_size)
         scheduler._allocate_and_set_running(seq_group)
         append_new_token_seq_group(60, seq_group, 1)
@@ -747,7 +744,6 @@ def test_schedule_swapped_blocks_to_copy():
     curr_loras = None
     _, seq_group = create_dummy_prompt("1",
                                        prompt_length=60,
-                                       best_of=2,
                                        block_size=block_size)
     scheduler._allocate_and_set_running(seq_group)
     append_new_token_seq_group(60, seq_group, 1)

tests/core/utils.py

Lines changed: 13 additions & 14 deletions
@@ -18,7 +18,6 @@ def create_dummy_prompt(
     prompt_length: int = -1,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
     prompt_tokens: Optional[list[int]] = None,
     min_tokens: int = 0,
     max_tokens: int = 16,
@@ -32,17 +31,19 @@ def create_dummy_prompt(
         prompt_tokens = list(range(prompt_length))
 
     prompt_str = " ".join([str(t) for t in prompt_tokens])
-    prompt = Sequence(int(request_id),
-                      inputs=token_inputs(prompt_tokens, prompt=prompt_str),
-                      block_size=block_size)
-    seq_group = SequenceGroup(request_id=request_id,
-                              seqs=[prompt],
-                              arrival_time=time.time(),
-                              sampling_params=SamplingParams(
-                                  best_of=best_of,
-                                  max_tokens=max_tokens,
-                                  min_tokens=min_tokens),
-                              lora_request=lora_request)
+    prompt = Sequence(
+        int(request_id),
+        inputs=token_inputs(prompt_tokens, prompt=prompt_str),
+        block_size=block_size,
+    )
+    seq_group = SequenceGroup(
+        request_id=request_id,
+        seqs=[prompt],
+        arrival_time=time.time(),
+        sampling_params=SamplingParams(max_tokens=max_tokens,
+                                       min_tokens=min_tokens),
+        lora_request=lora_request,
+    )
 
     return prompt, seq_group
 
@@ -72,7 +73,6 @@ def create_dummy_prompt_encoder_decoder(
     encoder_prompt_length: int,
     block_size: Optional[int] = None,
     lora_request: Optional[LoRARequest] = None,
-    best_of: int = 1,
 ) -> tuple[Sequence, Sequence, SequenceGroup]:
     if not block_size:
         block_size = decoder_prompt_length
@@ -102,7 +102,6 @@ def create_dummy_prompt_encoder_decoder(
 
     seq_group = SequenceGroup(request_id=request_id,
                               seqs=[decoder_prompt],
-                              sampling_params=SamplingParams(best_of=best_of),
                               arrival_time=time.time(),
                               lora_request=lora_request,
                               encoder_seq=encoder_prompt)

tests/v1/sample/test_sampling_params_e2e.py

Lines changed: 0 additions & 8 deletions
@@ -25,14 +25,6 @@ def test_n_gt_1(model):
     assert len(outputs[0].outputs) == 3
 
 
-def test_best_of(model):
-    """Raise a ValueError since best_of is deprecated."""
-
-    params = SamplingParams(n=2, best_of=3)
-    with pytest.raises(ValueError):
-        _ = model.generate(PROMPT, params)
-
-
 def test_penalties(model):
     """Check that we do not get errors if applied."""

vllm/entrypoints/llm.py

Lines changed: 1 addition & 4 deletions
@@ -97,10 +97,7 @@ class LLM:
             throughput. However, if the value is too high, it may cause out-of-
             memory (OOM) errors.
         swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
-            This can be used for temporarily storing the states of the requests
-            when their `best_of` sampling parameters are larger than 1. If all
-            requests will have `best_of=1`, you can safely set this to 0.
-            Otherwise, too small values may cause out-of-memory (OOM) errors.
+            Too small values may cause out-of-memory (OOM) errors.
         cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
             the model weights. This virtually increases the GPU memory space
             you can use to hold the model weights, at the cost of CPU-GPU data
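
For reference, `swap_space` remains configurable when constructing the engine; a minimal usage sketch (the value shown is illustrative, and the argument itself predates this commit):

    from vllm import LLM

    llm = LLM(
        model="facebook/opt-125m",
        swap_space=4,  # GiB of CPU memory per GPU; too small a value may cause OOM errors
    )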

vllm/entrypoints/openai/protocol.py

Lines changed: 0 additions & 4 deletions
@@ -242,7 +242,6 @@ class ChatCompletionRequest(OpenAIBaseModel):
     user: Optional[str] = None
 
     # doc: begin-chat-completion-sampling-params
-    best_of: Optional[int] = None
     use_beam_search: bool = False
     top_k: Optional[int] = None
     min_p: Optional[float] = None
@@ -479,7 +478,6 @@ def to_sampling_params(
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
@@ -650,7 +648,6 @@ class CompletionRequest(OpenAIBaseModel):
     # https://platform.openai.com/docs/api-reference/completions/create
     model: Optional[str] = None
     prompt: Union[list[int], list[list[int]], str, list[str]]
-    best_of: Optional[int] = None
     echo: Optional[bool] = False
     frequency_penalty: Optional[float] = 0.0
     logit_bias: Optional[dict[str, float]] = None
@@ -848,7 +845,6 @@ def to_sampling_params(
 
         return SamplingParams.from_optional(
             n=self.n,
-            best_of=self.best_of,
             presence_penalty=self.presence_penalty,
             frequency_penalty=self.frequency_penalty,
             repetition_penalty=repetition_penalty,
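
With `best_of` removed from `CompletionRequest` and `ChatCompletionRequest`, clients should simply omit it. A client-side sketch of a /v1/completions call after this change (the URL, model name, and use of the `requests` package are assumptions, not taken from the diff):

    import requests

    payload = {
        "model": "facebook/opt-125m",
        "prompt": "San Francisco is a",
        "max_tokens": 16,
        "n": 3,  # request several candidates directly instead of using best_of
        "temperature": 0.8,
    }
    resp = requests.post("http://localhost:8000/v1/completions", json=payload)
    # Pick among the returned candidates on the client side if desired.
    print([choice["text"] for choice in resp.json()["choices"]])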

vllm/entrypoints/openai/serving_completion.py

Lines changed: 2 additions & 6 deletions
@@ -168,12 +168,8 @@ async def create_completion(
         model_name = self._get_model_name(request.model, lora_request)
         num_prompts = len(engine_prompts)
 
-        # Similar to the OpenAI API, when n != best_of, we do not stream the
-        # results. In addition, we do not stream the results when use
-        # beam search.
-        stream = (request.stream
-                  and (request.best_of is None or request.n == request.best_of)
-                  and not request.use_beam_search)
+        # We do not stream the results when use beam search.
+        stream = (request.stream and not request.use_beam_search)
 
         # Streaming response
         if stream:
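
The streaming decision now depends only on `request.stream` and beam search, not on `best_of`. A tiny sketch of the simplified rule, using a stand-in object rather than the actual request class:

    from types import SimpleNamespace

    def should_stream(request) -> bool:
        # Mirrors the updated condition: stream only when requested and not using beam search.
        return bool(request.stream and not request.use_beam_search)

    assert should_stream(SimpleNamespace(stream=True, use_beam_search=False))
    assert not should_stream(SimpleNamespace(stream=True, use_beam_search=True))
    assert not should_stream(SimpleNamespace(stream=False, use_beam_search=False))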

vllm/sampling_params.py

Lines changed: 0 additions & 24 deletions
@@ -116,10 +116,6 @@ class SamplingParams(
 
     Args:
         n: Number of output sequences to return for the given prompt.
-        best_of: Number of output sequences that are generated from the prompt.
-            From these `best_of` sequences, the top `n` sequences are returned.
-            `best_of` must be greater than or equal to `n`. By default,
-            `best_of` is set to `n`.
         presence_penalty: Float that penalizes new tokens based on whether they
             appear in the generated text so far. Values > 0 encourage the model
             to use new tokens, while values < 0 encourage the model to repeat
@@ -187,7 +183,6 @@ class SamplingParams(
     """
 
     n: int = 1
-    best_of: Optional[int] = None
     _real_n: Optional[int] = None
     presence_penalty: float = 0.0
     frequency_penalty: float = 0.0
@@ -231,7 +226,6 @@ class SamplingParams(
     @staticmethod
     def from_optional(
         n: Optional[int] = 1,
-        best_of: Optional[int] = None,
         presence_penalty: Optional[float] = 0.0,
         frequency_penalty: Optional[float] = 0.0,
         repetition_penalty: Optional[float] = 1.0,
@@ -270,7 +264,6 @@ def from_optional(
 
         return SamplingParams(
            n=1 if n is None else n,
-           best_of=best_of,
            presence_penalty=0.0
            if presence_penalty is None else presence_penalty,
            frequency_penalty=0.0
@@ -303,20 +296,6 @@ def from_optional(
        )
 
    def __post_init__(self) -> None:
-        # how we deal with `best_of``:
-        # if `best_of`` is not set, we default to `n`;
-        # if `best_of`` is set, we set `n`` to `best_of`,
-        # and set `_real_n`` to the original `n`.
-        # when we return the result, we will check
-        # if we need to return `n` or `_real_n` results
-        if self.best_of:
-            if self.best_of < self.n:
-                raise ValueError(
-                    f"best_of must be greater than or equal to n, "
-                    f"got n={self.n} and best_of={self.best_of}.")
-            if not self._real_n:
-                self._real_n = self.n
-                self.n = self.best_of
 
        if 0 < self.temperature < _MAX_TEMP:
            logger.warning(
@@ -423,9 +402,6 @@ def _verify_args(self) -> None:
            raise ValueError(
                "stop strings are only supported when detokenize is True. "
                "Set detokenize=True to use stop.")
-        if self.best_of != self._real_n and self.output_kind == (
-                RequestOutputKind.DELTA):
-            raise ValueError("best_of must equal n to use output_kind=DELTA")
 
    def _verify_greedy_sampling(self) -> None:
        if self.n > 1:
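
The removed `__post_init__` block was where `best_of` over-generation happened internally: `n` was bumped to `best_of` and `_real_n` remembered the originally requested count. A hedged sketch of reproducing that behavior on the caller side now that the parameter is gone (`cumulative_logprob` is a field of vLLM's `CompletionOutput`; the helper function and its ranking policy are examples, not part of this commit):

    from vllm import LLM, SamplingParams

    def generate_top_n(llm: LLM, prompt: str, n: int, num_candidates: int):
        # num_candidates plays the role `best_of` used to play; it must be >= n.
        assert num_candidates >= n
        params = SamplingParams(n=num_candidates, temperature=0.8)
        request_output = llm.generate([prompt], params)[0]
        ranked = sorted(request_output.outputs,
                        key=lambda o: o.cumulative_logprob or 0.0,
                        reverse=True)
        return ranked[:n]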

vllm/v1/engine/processor.py

Lines changed: 0 additions & 3 deletions
@@ -93,9 +93,6 @@ def _validate_supported_sampling_params(
         self,
         params: SamplingParams,
     ) -> None:
-        # Best of not yet supported.
-        if params.best_of:
-            raise ValueError("VLLM V1 does not yet support best_of.")
         # Bad words not yet supported.
         if params.bad_words:
             raise ValueError("VLLM V1 does not yet support bad_words.")
