Skip to content

Support non-string values in JSON keys from CLI #19471

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jun 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 17 additions & 17 deletions tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,32 @@
from vllm.platforms import current_platform


class TestConfig1:
class _TestConfig1:
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I added underscore to the names to silence the pytest warnings, since normally classes beginning with Test are supposed to be test suites.

pass


@dataclass
class TestConfig2:
class _TestConfig2:
a: int
"""docstring"""


@dataclass
class TestConfig3:
class _TestConfig3:
a: int = 1


@dataclass
class TestConfig4:
class _TestConfig4:
a: Union[Literal[1], Literal[2]] = 1
"""docstring"""


@pytest.mark.parametrize(("test_config", "expected_error"), [
(TestConfig1, "must be a dataclass"),
(TestConfig2, "must have a default"),
(TestConfig3, "must have a docstring"),
(TestConfig4, "must use a single Literal"),
(_TestConfig1, "must be a dataclass"),
(_TestConfig2, "must have a default"),
(_TestConfig3, "must have a docstring"),
(_TestConfig4, "must use a single Literal"),
])
def test_config(test_config, expected_error):
with pytest.raises(Exception, match=expected_error):
Expand All @@ -57,23 +57,23 @@ def test_compile_config_repr_succeeds():
assert 'inductor_passes' in val


def test_get_field():
@dataclass
class _TestConfigFields:
a: int
b: dict = field(default_factory=dict)
c: str = "default"

@dataclass
class TestConfig:
a: int
b: dict = field(default_factory=dict)
c: str = "default"

def test_get_field():
with pytest.raises(ValueError):
get_field(TestConfig, "a")
get_field(_TestConfigFields, "a")

b = get_field(TestConfig, "b")
b = get_field(_TestConfigFields, "b")
assert isinstance(b, Field)
assert b.default is MISSING
assert b.default_factory is dict

c = get_field(TestConfig, "c")
c = get_field(_TestConfigFields, "c")
assert isinstance(c, Field)
assert c.default == "default"
assert c.default_factory is MISSING
Expand Down
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,15 @@ def test_dict_args(parser):
"val5",
"--hf_overrides.key-7.key_8",
"val6",
# Test data type detection
"--hf_overrides.key9",
"100",
"--hf_overrides.key10",
"100.0",
"--hf_overrides.key11",
"true",
"--hf_overrides.key12.key13",
"null",
]
parsed_args = parser.parse_args(args)
assert parsed_args.model_name == "something.something"
Expand All @@ -286,6 +295,12 @@ def test_dict_args(parser):
"key-7": {
"key_8": "val6",
},
"key9": 100,
"key10": 100.0,
"key11": True,
"key12": {
"key13": None,
},
}


Expand Down
23 changes: 16 additions & 7 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1466,7 +1466,7 @@ def repl(match: re.Match) -> str:
pattern = re.compile(r"(?<=--)[^\.]*")

# Convert underscores to dashes and vice versa in argument names
processed_args = []
processed_args = list[str]()
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Improve the type annotations

for arg in args:
if arg.startswith('--'):
if '=' in arg:
Expand All @@ -1483,7 +1483,7 @@ def repl(match: re.Match) -> str:
else:
processed_args.append(arg)

def create_nested_dict(keys: list[str], value: str):
def create_nested_dict(keys: list[str], value: str) -> dict[str, Any]:
"""Creates a nested dictionary from a list of keys and a value.

For example, `keys = ["a", "b", "c"]` and `value = 1` will create:
Expand All @@ -1494,27 +1494,36 @@ def create_nested_dict(keys: list[str], value: str):
nested_dict = {key: nested_dict}
return nested_dict

def recursive_dict_update(original: dict, update: dict):
def recursive_dict_update(
original: dict[str, Any],
update: dict[str, Any],
):
"""Recursively updates a dictionary with another dictionary."""
for k, v in update.items():
if isinstance(v, dict) and isinstance(original.get(k), dict):
recursive_dict_update(original[k], v)
else:
original[k] = v

delete = set()
dict_args: dict[str, dict] = defaultdict(dict)
delete = set[int]()
dict_args = defaultdict[str, dict[str, Any]](dict)
for i, processed_arg in enumerate(processed_args):
if processed_arg.startswith("--") and "." in processed_arg:
if "=" in processed_arg:
processed_arg, value = processed_arg.split("=", 1)
processed_arg, value_str = processed_arg.split("=", 1)
if "." not in processed_arg:
# False positive, . was only in the value
continue
else:
value = processed_args[i + 1]
value_str = processed_args[i + 1]
delete.add(i + 1)

key, *keys = processed_arg.split(".")
try:
value = json.loads(value_str)
except json.decoder.JSONDecodeError:
value = value_str
Comment on lines +1522 to +1525
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is the main change in this PR which enables non-string values to be parsed correctly


# Merge all values with the same key into a single dict
arg_dict = create_nested_dict(keys, value)
recursive_dict_update(dict_args[key], arg_dict)
Expand Down