Default to generation_config from model #12622

Merged · 15 commits · Mar 8, 2025
Changes from all commits
2 changes: 1 addition & 1 deletion tests/entrypoints/openai/correctness/test_lmeval.py
@@ -20,7 +20,7 @@
TASK = "gsm8k"
FILTER = "exact_match,strict-match"
RTOL = 0.03
-EXPECTED_VALUE = 0.58
+EXPECTED_VALUE = 0.54
DEFAULT_ARGS = ["--max-model-len", "4096", "--disable-log-requests"]
MORE_ARGS_LIST = [
[], # Default
1 change: 1 addition & 0 deletions tests/entrypoints/openai/test_serving_chat.py
@@ -38,6 +38,7 @@ class MockModelConfig:
diff_sampling_param: Optional[dict] = None
allowed_local_media_path: str = ""
encoder_config = None
+generation_config: str = "auto"

def get_diff_sampling_param(self):
return self.diff_sampling_param or {}
8 changes: 4 additions & 4 deletions tests/test_config.py
@@ -289,7 +289,7 @@ def test_uses_mrope(model_id, uses_mrope):
def test_generation_config_loading():
model_id = "Qwen/Qwen2.5-1.5B-Instruct"

-# When set generation_config to None, the default generation config
+# When set generation_config to "vllm", the default generation config
# will not be loaded.
model_config = ModelConfig(model_id,
task="auto",
@@ -298,7 +298,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
-generation_config=None)
+generation_config="vllm")
assert model_config.get_diff_sampling_param() == {}

# When set generation_config to "auto", the default generation config
@@ -340,7 +340,7 @@ def test_generation_config_loading():

assert model_config.get_diff_sampling_param() == override_result

-# When generation_config is set to None and override_generation_config
+# When generation_config is set to "vllm" and override_generation_config
# is set, the override_generation_config should be used directly.
model_config = ModelConfig(
model_id,
@@ -350,7 +350,7 @@ def test_generation_config_loading():
trust_remote_code=False,
seed=0,
dtype="float16",
-generation_config=None,
+generation_config="vllm",
override_generation_config=override_generation_config)

assert model_config.get_diff_sampling_param() == override_generation_config
15 changes: 6 additions & 9 deletions vllm/config.py
@@ -255,7 +255,7 @@ def __init__(
override_neuron_config: Optional[dict[str, Any]] = None,
override_pooler_config: Optional["PoolerConfig"] = None,
logits_processor_pattern: Optional[str] = None,
-generation_config: Optional[str] = None,
+generation_config: str = "auto",
enable_sleep_mode: bool = False,
override_generation_config: Optional[dict[str, Any]] = None,
model_impl: Union[str, ModelImpl] = ModelImpl.AUTO,
@@ -951,7 +951,7 @@ def get_multimodal_config(self) -> "MultiModalConfig":
return self.multimodal_config

def try_get_generation_config(self) -> dict[str, Any]:
-if self.generation_config is None or self.generation_config == "auto":
+if self.generation_config in ("auto", "vllm"):
config = try_get_generation_config(
self.hf_config_path or self.model,
trust_remote_code=self.trust_remote_code,
@@ -971,17 +971,14 @@ def try_get_generation_config(self) -> dict[str, Any]:
def get_diff_sampling_param(self) -> dict[str, Any]:
"""
This method returns a dictionary containing the parameters
-that differ from the default sampling parameters, but only
-if `generation_config` is set. If `generation_config` is not
-set, an empty dictionary is returned.
+that differ from the default sampling parameters. If
+`generation_config` is `"vllm"`, an empty dictionary is returned.

Returns:
dict[str, Any]: A dictionary with the differing sampling
-parameters if `generation_config` is set, otherwise an
-empty dictionary.
+parameters, if `generation_config` is `"vllm"` an empty dictionary.
"""
-if self.generation_config is None:
-# When generation_config is not set
+if self.generation_config == "vllm":
config = {}
else:
config = self.try_get_generation_config()
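For readers skimming the hunks above: the net effect in config.py is that `"auto"` (now the default) and a folder path both load a `generation_config.json`, while `"vllm"` opts out entirely. Below is a minimal standalone sketch of that resolution logic under those assumptions; `load_generation_config_json` is a hypothetical stand-in, not vLLM's internal `try_get_generation_config`, and the sketch skips the step where the loaded values are diffed against vLLM's defaults and merged with `override_generation_config`.

```python
import json
import os
from typing import Any


def load_generation_config_json(path: str) -> dict[str, Any]:
    """Hypothetical stand-in for vLLM's loader: read generation_config.json
    from `path`, returning {} if the file is absent."""
    config_file = os.path.join(path, "generation_config.json")
    if not os.path.exists(config_file):
        return {}
    with open(config_file) as f:
        return json.load(f)


def default_sampling_params(generation_config: str, model: str) -> dict[str, Any]:
    """Sketch of how the three --generation-config settings are treated."""
    if generation_config == "vllm":
        # Opt out: keep vLLM's built-in sampling defaults, no overrides.
        return {}
    # "auto" reads from the model path; any other value is a folder path.
    path = model if generation_config == "auto" else generation_config
    return load_generation_config_json(path)
```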
14 changes: 7 additions & 7 deletions vllm/engine/arg_utils.py
@@ -205,7 +205,7 @@ class EngineArgs:

kv_transfer_config: Optional[KVTransferConfig] = None

-generation_config: Optional[str] = None
+generation_config: Optional[str] = "auto"
override_generation_config: Optional[Dict[str, Any]] = None
enable_sleep_mode: bool = False
model_impl: str = "auto"
@@ -1018,13 +1018,13 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
parser.add_argument(
"--generation-config",
type=nullable_str,
-default=None,
+default="auto",
help="The folder path to the generation config. "
-"Defaults to None, no generation config is loaded, vLLM defaults "
-"will be used. If set to 'auto', the generation config will be "
-"loaded from model path. If set to a folder path, the generation "
-"config will be loaded from the specified folder path. If "
-"`max_new_tokens` is specified in generation config, then "
+"Defaults to 'auto', the generation config will be loaded from "
+"model path. If set to 'vllm', no generation config is loaded, "
+"vLLM defaults will be used. If set to a folder path, the "
+"generation config will be loaded from the specified folder path. "
+"If `max_new_tokens` is specified in generation config, then "
"it sets a server-wide limit on the number of output tokens "
"for all requests.")

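For illustration only, a small argparse sketch of the flag as it now behaves; `nullable_str` is swapped for plain `str` so the snippet is self-contained, and the help text is paraphrased rather than copied:

```python
import argparse

parser = argparse.ArgumentParser()
# Mirrors the updated --generation-config semantics: "auto" (the new default)
# loads generation_config.json from the model path, "vllm" skips loading and
# keeps vLLM's built-in defaults, and any other string is read as a folder path.
parser.add_argument(
    "--generation-config",
    type=str,
    default="auto",
    help="'auto': load from the model path; 'vllm': use vLLM defaults; "
         "otherwise: folder path to load the generation config from.")

args = parser.parse_args(["--generation-config", "vllm"])
assert args.generation_config == "vllm"
```

Operators who relied on the old default (no generation config loaded) should presumably pass `--generation-config vllm` to keep the previous behaviour.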
6 changes: 4 additions & 2 deletions vllm/entrypoints/openai/serving_chat.py
@@ -109,8 +109,10 @@ def __init__(
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
-logger.info("Overwriting default chat sampling param with: %s",
-self.default_sampling_params)
+source = self.model_config.generation_config
+source = "model" if source == "auto" else source
+logger.info("Using default chat sampling params from %s: %s",
+source, self.default_sampling_params)

async def create_chat_completion(
self,
7 changes: 4 additions & 3 deletions vllm/entrypoints/openai/serving_completion.py
@@ -55,9 +55,10 @@ def __init__(
self.default_sampling_params = (
self.model_config.get_diff_sampling_param())
if self.default_sampling_params:
-logger.info(
-"Overwriting default completion sampling param with: %s",
-self.default_sampling_params)
+source = self.model_config.generation_config
+source = "model" if source == "auto" else source
+logger.info("Using default completion sampling params from %s: %s",
+source, self.default_sampling_params)

async def create_completion(
self,
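Both serving classes resolve the log line's source label the same way; a tiny standalone sketch of that mapping (the logger call itself is omitted, and the example path is hypothetical):

```python
def sampling_params_source(generation_config: str) -> str:
    """Label where the default sampling params came from for the new log
    message: "auto" means they were read from the model's own generation
    config, so it is reported as "model"; any other value is the folder path
    supplied via --generation-config. The "vllm" setting never reaches this
    code, since no defaults are loaded in that case."""
    return "model" if generation_config == "auto" else generation_config


assert sampling_params_source("auto") == "model"
assert sampling_params_source("/my/config/dir") == "/my/config/dir"
```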