[Frontend] Support override generation config in args (vllm-project#12409)

Signed-off-by: liuyanyi <wolfsonliu@163.com>
liuyanyi authored Jan 29, 2025
1 parent d93bf4d commit ff7424f
Showing 3 changed files with 100 additions and 8 deletions.
70 changes: 70 additions & 0 deletions tests/test_config.py
@@ -281,3 +281,73 @@ def test_uses_mrope(model_id, uses_mrope):
     )
 
     assert config.uses_mrope == uses_mrope
+
+
+def test_generation_config_loading():
+    model_id = "Qwen/Qwen2.5-1.5B-Instruct"
+
+    # When generation_config is set to None, the default generation config
+    # will not be loaded.
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               generation_config=None)
+    assert model_config.get_diff_sampling_param() == {}
+
+    # When generation_config is set to "auto", the default generation config
+    # should be loaded.
+    model_config = ModelConfig(model_id,
+                               task="auto",
+                               tokenizer=model_id,
+                               tokenizer_mode="auto",
+                               trust_remote_code=False,
+                               seed=0,
+                               dtype="float16",
+                               generation_config="auto")
+
+    correct_generation_config = {
+        "repetition_penalty": 1.1,
+        "temperature": 0.7,
+        "top_p": 0.8,
+        "top_k": 20,
+    }
+
+    assert model_config.get_diff_sampling_param() == correct_generation_config
+
+    # The generation config can be overridden by the user.
+    override_generation_config = {"temperature": 0.5, "top_k": 5}
+
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        generation_config="auto",
+        override_generation_config=override_generation_config)
+
+    override_result = correct_generation_config.copy()
+    override_result.update(override_generation_config)
+
+    assert model_config.get_diff_sampling_param() == override_result
+
+    # When generation_config is set to None and override_generation_config
+    # is set, the override_generation_config should be used directly.
+    model_config = ModelConfig(
+        model_id,
+        task="auto",
+        tokenizer=model_id,
+        tokenizer_mode="auto",
+        trust_remote_code=False,
+        seed=0,
+        dtype="float16",
+        generation_config=None,
+        override_generation_config=override_generation_config)
+
+    assert model_config.get_diff_sampling_param() == override_generation_config
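Note: the expected values above come from the generation_config.json that Qwen/Qwen2.5-1.5B-Instruct ships on the Hugging Face Hub, so the test needs network access (or a warm local cache). A typical targeted invocation would be `pytest tests/test_config.py::test_generation_config_loading`.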
13 changes: 11 additions & 2 deletions vllm/config.py
@@ -165,6 +165,8 @@ class ModelConfig:
             `logits_processors` extra completion argument. Defaults to None,
             which allows no processors.
         generation_config: Configuration parameter file for generation.
+        override_generation_config: Override the generation config with the
+            given config.
     """
 
     def compute_hash(self) -> str:
@@ -225,6 +227,7 @@ def __init__(
         logits_processor_pattern: Optional[str] = None,
         generation_config: Optional[str] = None,
         enable_sleep_mode: bool = False,
+        override_generation_config: Optional[Dict[str, Any]] = None,
     ) -> None:
         self.model = model
         self.tokenizer = tokenizer
@@ -368,6 +371,7 @@ def __init__(
         self.logits_processor_pattern = logits_processor_pattern
 
         self.generation_config = generation_config
+        self.override_generation_config = override_generation_config or {}
 
         self._verify_quantization()
         self._verify_cuda_graph()
@@ -904,8 +908,13 @@ def get_diff_sampling_param(self) -> Dict[str, Any]:
"""
if self.generation_config is None:
# When generation_config is not set
return {}
config = self.try_get_generation_config()
config = {}
else:
config = self.try_get_generation_config()

# Overriding with given generation config
config.update(self.override_generation_config)

available_params = [
"repetition_penalty",
"temperature",
25 changes: 19 additions & 6 deletions vllm/engine/arg_utils.py
@@ -195,6 +195,7 @@ class EngineArgs:
     kv_transfer_config: Optional[KVTransferConfig] = None
 
     generation_config: Optional[str] = None
+    override_generation_config: Optional[Dict[str, Any]] = None
     enable_sleep_mode: bool = False
 
     calculate_kv_scales: Optional[bool] = None
@@ -936,12 +937,23 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
             type=nullable_str,
             default=None,
             help="The folder path to the generation config. "
-            "Defaults to None, will use the default generation config in vLLM. "
-            "If set to 'auto', the generation config will be automatically "
-            "loaded from model. If set to a folder path, the generation config "
-            "will be loaded from the specified folder path. If "
-            "`max_new_tokens` is specified, then it sets a server-wide limit "
-            "on the number of output tokens for all requests.")
+            "Defaults to None, in which case no generation config is loaded "
+            "and vLLM defaults are used. If set to 'auto', the generation "
+            "config is loaded from the model path. If set to a folder path, "
+            "the generation config is loaded from the specified folder. If "
+            "`max_new_tokens` is specified in the generation config, it sets "
+            "a server-wide limit on the number of output tokens for all "
+            "requests.")
+
+        parser.add_argument(
+            "--override-generation-config",
+            type=json.loads,
+            default=None,
+            help="Overrides or sets generation config in JSON format. "
+            "e.g. ``{\"temperature\": 0.5}``. If used with "
+            "--generation-config=auto, the override parameters are merged "
+            "with the default config from the model. If --generation-config "
+            "is None (the default), only the override parameters are used.")
 
         parser.add_argument("--enable-sleep-mode",
                             action="store_true",
@@ -1002,6 +1014,7 @@ def create_model_config(self) -> ModelConfig:
             override_pooler_config=self.override_pooler_config,
             logits_processor_pattern=self.logits_processor_pattern,
             generation_config=self.generation_config,
+            override_generation_config=self.override_generation_config,
             enable_sleep_mode=self.enable_sleep_mode,
         )
 
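End to end, the new flag can be exercised either with the OpenAI-compatible server, e.g. `vllm serve Qwen/Qwen2.5-1.5B-Instruct --generation-config auto --override-generation-config '{"temperature": 0.5, "top_k": 5}'`, or programmatically. The snippet below is a sketch, not part of the commit; it assumes a vLLM build that includes this change and that the model config can be fetched from the Hugging Face Hub.

from vllm.engine.arg_utils import EngineArgs

engine_args = EngineArgs(
    model="Qwen/Qwen2.5-1.5B-Instruct",
    generation_config="auto",
    override_generation_config={"temperature": 0.5, "top_k": 5},
)
model_config = engine_args.create_model_config()
# Expected: model defaults merged with the overrides, i.e.
# {"repetition_penalty": 1.1, "temperature": 0.5, "top_p": 0.8, "top_k": 5}
print(model_config.get_diff_sampling_param())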