Reinstate best_of for V0 #14356

Merged: 2 commits merged on Mar 6, 2025
8 changes: 8 additions & 0 deletions tests/v1/sample/test_sampling_params_e2e.py
@@ -25,6 +25,14 @@ def test_n_gt_1(model):
assert len(outputs[0].outputs) == 3


def test_best_of(model):
"""Raise a ValueError since best_of is deprecated."""

params = SamplingParams(n=2, best_of=3)
with pytest.raises(ValueError):
_ = model.generate(PROMPT, params)


def test_penalties(model):
"""Check that we do not get errors if applied."""

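For contrast with the V1 test above, a minimal sketch of the behaviour this PR reinstates on the V0 engine (the model name and prompt are illustrative placeholders, not taken from this PR):

```python
from vllm import LLM, SamplingParams

# Placeholder model; any model served by the V0 engine would do for this sketch.
llm = LLM(model="facebook/opt-125m")

# Generate 3 candidate sequences and return the best 2 of them.
# On the V1 engine this same combination raises ValueError, as tested above.
params = SamplingParams(n=2, best_of=3, temperature=0.8)
outputs = llm.generate("Hello, my name is", params)
print(len(outputs[0].outputs))  # expected: 2 (the top-n of the 3 candidates)
```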
6 changes: 5 additions & 1 deletion vllm/entrypoints/llm.py
@@ -97,7 +97,11 @@ class LLM:
throughput. However, if the value is too high, it may cause out-of-
memory (OOM) errors.
swap_space: The size (GiB) of CPU memory per GPU to use as swap space.
Too small values may cause out-of-memory (OOM) errors.
This can be used for temporarily storing the states of requests
whose `best_of` sampling parameter is larger than 1. If all requests
have `best_of=1`, you can safely set this to 0. Note that `best_of`
is only supported in V0. Otherwise, too small values may cause
out-of-memory (OOM) errors.
cpu_offload_gb: The size (GiB) of CPU memory to use for offloading
the model weights. This virtually increases the GPU memory space
you can use to hold the model weights, at the cost of CPU-GPU data
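A hedged sketch of how that `swap_space` guidance is applied at construction time (the 4 GiB figure and the model name are illustrative assumptions, not recommendations from this PR):

```python
from vllm import LLM, SamplingParams

# Keep some CPU swap space because requests will use best_of > 1 (V0 only).
# With best_of=1 everywhere, swap_space=0 would be safe per the docstring above.
llm = LLM(model="facebook/opt-125m", swap_space=4)

params = SamplingParams(n=1, best_of=4)
outputs = llm.generate("The capital of France is", params)
```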
4 changes: 4 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -242,6 +242,7 @@ class ChatCompletionRequest(OpenAIBaseModel):
user: Optional[str] = None

# doc: begin-chat-completion-sampling-params
best_of: Optional[int] = None
use_beam_search: bool = False
top_k: Optional[int] = None
min_p: Optional[float] = None
@@ -478,6 +479,7 @@ def to_sampling_params(

return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
@@ -648,6 +650,7 @@ class CompletionRequest(OpenAIBaseModel):
# https://platform.openai.com/docs/api-reference/completions/create
model: Optional[str] = None
prompt: Union[list[int], list[list[int]], str, list[str]]
best_of: Optional[int] = None
echo: Optional[bool] = False
frequency_penalty: Optional[float] = 0.0
logit_bias: Optional[dict[str, float]] = None
@@ -845,6 +848,7 @@ def to_sampling_params(

return SamplingParams.from_optional(
n=self.n,
best_of=self.best_of,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
repetition_penalty=repetition_penalty,
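With `best_of` back on `CompletionRequest` and `ChatCompletionRequest`, an OpenAI-compatible client can send it again. A hedged example against a locally running `vllm serve` instance (the URL and model name are assumptions for illustration):

```python
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",  # assumed local server (V0 engine)
    json={
        "model": "facebook/opt-125m",        # placeholder model name
        "prompt": "San Francisco is a",
        "n": 2,
        "best_of": 4,        # accepted again by CompletionRequest after this PR
        "max_tokens": 16,
    },
)
print(len(resp.json()["choices"]))  # expected: 2 choices picked from 4 candidates
```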
8 changes: 6 additions & 2 deletions vllm/entrypoints/openai/serving_completion.py
@@ -168,8 +168,12 @@ async def create_completion(
model_name = self._get_model_name(request.model, lora_request)
num_prompts = len(engine_prompts)

# We do not stream the results when use beam search.
stream = (request.stream and not request.use_beam_search)
# As in the OpenAI API, we do not stream the results when n != best_of
# (note that best_of is only supported in V0). We also do not stream
# the results when using beam search.
stream = (request.stream
and (request.best_of is None or request.n == request.best_of)
and not request.use_beam_search)

# Streaming response
if stream:
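The streaming condition above can be read as a small predicate. A restatement for clarity (the helper name is invented for illustration and does not exist in vLLM):

```python
from typing import Optional

def should_stream(stream: bool, n: int, best_of: Optional[int],
                  use_beam_search: bool) -> bool:
    # An unset best_of behaves like best_of == n, so streaming stays allowed.
    n_matches_best_of = best_of is None or n == best_of
    return stream and n_matches_best_of and not use_beam_search

assert should_stream(True, 2, None, False)      # best_of defaults to n
assert should_stream(True, 2, 2, False)         # n == best_of
assert not should_stream(True, 2, 4, False)     # best 2 of 4 needs all candidates first
assert not should_stream(True, 1, None, True)   # beam search never streams
```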
24 changes: 24 additions & 0 deletions vllm/sampling_params.py
@@ -116,6 +116,10 @@ class SamplingParams(

Args:
n: Number of output sequences to return for the given prompt.
best_of: Number of output sequences that are generated from the prompt.
From these `best_of` sequences, the top `n` sequences are returned.
`best_of` must be greater than or equal to `n`. By default,
`best_of` is set to `n`. Warning: this is only supported in V0.
presence_penalty: Float that penalizes new tokens based on whether they
appear in the generated text so far. Values > 0 encourage the model
to use new tokens, while values < 0 encourage the model to repeat
@@ -183,6 +187,7 @@ class SamplingParams(
"""

n: int = 1
best_of: Optional[int] = None
_real_n: Optional[int] = None
presence_penalty: float = 0.0
frequency_penalty: float = 0.0
@@ -226,6 +231,7 @@ class SamplingParams(
@staticmethod
def from_optional(
n: Optional[int] = 1,
best_of: Optional[int] = None,
presence_penalty: Optional[float] = 0.0,
frequency_penalty: Optional[float] = 0.0,
repetition_penalty: Optional[float] = 1.0,
@@ -264,6 +270,7 @@ def from_optional(

return SamplingParams(
n=1 if n is None else n,
best_of=best_of,
presence_penalty=0.0
if presence_penalty is None else presence_penalty,
frequency_penalty=0.0
@@ -296,6 +303,20 @@ def from_optional(
)

def __post_init__(self) -> None:
# How we deal with `best_of`:
# if `best_of` is not set, we default to `n`;
# if `best_of` is set, we set `n` to `best_of`
# and set `_real_n` to the original `n`.
# When we return the result, we will check
# whether we need to return `n` or `_real_n` results.
if self.best_of:
if self.best_of < self.n:
raise ValueError(
f"best_of must be greater than or equal to n, "
f"got n={self.n} and best_of={self.best_of}.")
if not self._real_n:
self._real_n = self.n
self.n = self.best_of

if 0 < self.temperature < _MAX_TEMP:
logger.warning(
@@ -402,6 +423,9 @@ def _verify_args(self) -> None:
raise ValueError(
"stop strings are only supported when detokenize is True. "
"Set detokenize=True to use stop.")
if self.best_of != self._real_n and self.output_kind == (
RequestOutputKind.DELTA):
raise ValueError("best_of must equal n to use output_kind=DELTA")

def _verify_greedy_sampling(self) -> None:
if self.n > 1:
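To make the `n`/`_real_n` bookkeeping in `__post_init__` concrete, a hedged sketch of the resulting attribute values (`_real_n` is an internal field, inspected here only for illustration):

```python
from vllm import SamplingParams

params = SamplingParams(n=2, best_of=3)
# __post_init__ remembers the user-facing n and inflates n to best_of,
# so the engine generates 3 sequences and later returns the top 2.
assert params.n == 3
assert params._real_n == 2

# best_of < n is rejected immediately.
try:
    SamplingParams(n=4, best_of=2)
except ValueError as err:
    print(err)  # best_of must be greater than or equal to n, got n=4 and best_of=2.
```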
3 changes: 3 additions & 0 deletions vllm/v1/engine/processor.py
@@ -93,6 +93,9 @@ def _validate_supported_sampling_params(
self,
params: SamplingParams,
) -> None:
# Best of not yet supported.
if params.best_of is not None and params.best_of > 1:
raise ValueError("VLLM V1 does not yet support best_of.")
# Bad words not yet supported.
if params.bad_words:
raise ValueError("VLLM V1 does not yet support bad_words.")
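This processor check is what `test_best_of` at the top of the diff exercises end to end. A hedged sketch of hitting it directly (selecting the V1 engine via `VLLM_USE_V1` and the model name are assumptions for illustration, not part of this PR):

```python
import os

os.environ["VLLM_USE_V1"] = "1"  # assumed switch to force the V1 engine

from vllm import LLM, SamplingParams

llm = LLM(model="facebook/opt-125m")  # placeholder model
try:
    llm.generate("Hello", SamplingParams(n=2, best_of=3))
except ValueError as err:
    print(err)  # VLLM V1 does not yet support best_of.
```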