Skip to content

[Core] Add update_config RPC method #20095

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from vllm.compilation.backends import VllmBackend
from vllm.config import (LoadConfig, ModelConfig, PoolerConfig, VllmConfig,
config, get_field)
config, get_field, update_config)
from vllm.model_executor.layers.pooler import PoolingType
from vllm.platforms import current_platform

Expand Down Expand Up @@ -79,6 +79,34 @@ def test_get_field():
assert c.default_factory is MISSING


@dataclass
class _TestNestedConfig:
    """Dummy config with a nested dataclass field, used to exercise the
    recursive-merge path of `update_config`."""
    # default_factory so every instance gets its own inner config object
    # (a shared mutable default would leak state between tests)
    a: _TestConfigFields = field(
        default_factory=lambda: _TestConfigFields(a=0))


def test_update_config():
    """Cover flat overrides, nested overrides, and rejection of bad input."""
    # A flat field is overridden directly.
    base = _TestConfigFields(a=0)
    assert update_config(base, {"a": 42}).a == 42
    # Overriding a field that does not exist is rejected.
    with pytest.raises(AssertionError):
        update_config(base, {"nonexistent": 1})
    # A dataclass instance replaces a nested field wholesale.
    replacement = _TestConfigFields(a=1, c="new_value")
    updated = update_config(_TestNestedConfig(), {"a": replacement})
    assert updated.a == replacement
    # A dict override is merged into the nested dataclass field.
    merged = update_config(_TestNestedConfig(), {"a": {"c": "new_value"}})
    assert merged.a.c == "new_value"
    # Any other type for a nested dataclass field is invalid.
    with pytest.raises(AssertionError):
        update_config(_TestNestedConfig(), {"a": "new_value"})


@pytest.mark.parametrize(
("model_id", "expected_runner_type", "expected_task"),
[
Expand Down
16 changes: 14 additions & 2 deletions tests/v1/worker/test_gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,16 +433,28 @@ def rnd_stride_order():
assert all(not kv.is_contiguous() for kv in model_runner.kv_caches)


def test_update_config(model_runner):
    """The runner applies allowed overrides and rejects unknown configs."""
    # Overriding an allowed config (load_config) takes effect in place.
    overrides = {"load_config": {"load_format": "dummy"}}
    model_runner.update_config(overrides)
    assert model_runner.load_config.load_format == "dummy"
    # Unknown config names are rejected up front.
    with pytest.raises(AssertionError):
        model_runner.update_config({"do_not_exist_config": "dummy"})


def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
    """model_runner loads model + weights in one go; model_runner_2 loads
    dummy weights first and then swaps in the real weights in place."""
    model_runner.load_model()
    real_format = model_runner_2.load_config.load_format
    model_runner_2.update_config({"load_config": {"load_format": "dummy"}})
    model_runner_2.load_model()  # initial model load with dummy weights
    # Dummy weights must differ from the real ones.
    assert str(model_runner.get_model().state_dict()) != str(
        model_runner_2.get_model().state_dict())
    # Restore the original load format and reload the real weights in place.
    model_runner_2.update_config(
        {"load_config": {
            "load_format": real_format
        }})
    model_runner_2.load_model()
    assert str(model_runner.get_model().state_dict()) == str(
        model_runner_2.get_model().state_dict())
Expand Down
21 changes: 20 additions & 1 deletion vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
ConfigType = type[DataclassInstance]
HfOverrides = Union[dict, Callable[[type], type]]
else:
DataclassInstance = Any
PlacementGroup = Any
PretrainedConfig = Any
ExecutorBase = Any
Expand All @@ -87,7 +88,7 @@
"vllm.model_executor.models")

logger = init_logger(__name__)

DataclassInstanceT = TypeVar("DataclassInstanceT", bound=DataclassInstance)
ConfigT = TypeVar("ConfigT", bound=ConfigType)

TaskOption = Literal["auto", "generate", "embedding", "embed", "classify",
Expand Down Expand Up @@ -4842,3 +4843,21 @@ def get_layers_from_vllm_config(vllm_config: VllmConfig,
vllm_config.compilation_config.static_forward_context.items()
if isinstance(layer, layer_type)
}


def update_config(config: DataclassInstanceT,
                  overrides: dict[str, Any]) -> DataclassInstanceT:
    """Return a copy of `config` with `overrides` applied.

    Each key in `overrides` must name an existing field of `config`
    (enforced by assertion). When the current field value is a dataclass
    and the override is a dict, the override is merged recursively into
    the nested dataclass; a dataclass override replaces the field
    wholesale; any other override type for a dataclass field is rejected.
    The input `config` is not mutated — a new instance is returned via
    `dataclasses.replace`, so dataclass validation still runs.
    """
    updates: dict[str, Any] = {}
    for field_name, value in overrides.items():
        assert hasattr(
            config, field_name), f"{type(config)} has no field `{field_name}`"
        current_value = getattr(config, field_name)
        if is_dataclass(current_value) and not is_dataclass(value):
            # Dict overrides are merged into the nested dataclass rather
            # than clobbering the field with a raw dict.
            assert isinstance(value, dict), (
                f"Overrides to {type(config)}.{field_name} must be a dict"
                f" or {type(current_value)}, but got {type(value)}")
            value = update_config(
                current_value,  # type: ignore[type-var]
                value)
        updates[field_name] = value
    return replace(config, **updates)
12 changes: 11 additions & 1 deletion vllm/v1/worker/gpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from vllm.attention.layer import Attention
from vllm.compilation.counter import compilation_counter
from vllm.config import (CompilationLevel, VllmConfig,
get_layers_from_vllm_config)
get_layers_from_vllm_config, update_config)
from vllm.distributed.eplb.eplb_state import EplbState
from vllm.distributed.kv_transfer import (get_kv_transfer_group,
has_kv_transfer_group)
Expand Down Expand Up @@ -1721,6 +1721,16 @@ def generate_draft_token_ids(
draft_token_ids.append(drafter_output.tolist())
return draft_token_ids

def update_config(self, overrides: dict[str, Any]) -> None:
    """Apply nested config overrides to this model runner in place.

    Only `load_config` and `model_config` may be overridden for now —
    other configs (e.g. parallel_config) are not safely updatable at
    runtime, so they are rejected by assertion. Each override value is
    merged through `vllm.config.update_config`, which rebuilds the config
    via `dataclasses.replace` so validation still runs.
    """
    allowed_config_names = {"load_config", "model_config"}
    for config_name, config_overrides in overrides.items():
        assert config_name in allowed_config_names, \
            f"Config `{config_name}` not supported. " \
            f"Allowed configs: {allowed_config_names}"
        setattr(self, config_name,
                update_config(getattr(self, config_name), config_overrides))

def load_model(self) -> None:
logger.info("Starting to load model %s...", self.model_config.model)
with DeviceMemoryProfiler() as m: # noqa: SIM117
Expand Down
5 changes: 4 additions & 1 deletion vllm/v1/worker/gpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
"""A GPU worker class."""
import gc
import os
from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Any, Optional

import torch
import torch.distributed
Expand Down Expand Up @@ -184,6 +184,9 @@ def load_model(self) -> None:
with context:
self.model_runner.load_model()

def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the underlying model runner."""
    runner = self.model_runner
    runner.update_config(overrides)

@torch.inference_mode()
def determine_available_memory(self) -> int:
"""Profiles the peak memory usage of the model to determine how much
Expand Down
12 changes: 11 additions & 1 deletion vllm/v1/worker/tpu_model_runner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import bisect
import dataclasses
import gc
import time
from typing import TYPE_CHECKING, Optional, cast
from typing import TYPE_CHECKING, Any, Optional, cast
from unittest.mock import patch

import numpy as np
Expand Down Expand Up @@ -1104,6 +1105,15 @@ def concat_lists(input_lists):

return model_runner_output

def update_config(self, overrides: dict[str, Any]) -> None:
    """Apply config overrides to this model runner in place.

    `overrides` maps a config attribute name (e.g. "load_config") to a
    dict of field overrides for that config. Like the GPU runner's
    update_config, a dict override for a field whose current value is a
    dataclass is merged into that nested dataclass (one level deep)
    instead of clobbering the field with a raw dict. Configs are rebuilt
    with `dataclasses.replace`, so the original objects are not mutated.

    Raises:
        ValueError: if `config_name` is not an attribute of the runner.
    """
    for config_name, config_overrides in overrides.items():
        try:
            config = getattr(self, config_name)
        except AttributeError as exc:
            raise ValueError(f"Unknown config {config_name}") from exc
        processed: dict[str, Any] = {}
        for field_name, value in config_overrides.items():
            current = getattr(config, field_name, None)
            if dataclasses.is_dataclass(current) and isinstance(value, dict):
                # Merge into the nested dataclass; passing the raw dict to
                # replace() would silently set the field to a dict.
                value = dataclasses.replace(current, **value)
            processed[field_name] = value
        setattr(self, config_name, dataclasses.replace(config, **processed))

def load_model(self) -> None:
self.device = self.device_config.device

Expand Down
5 changes: 4 additions & 1 deletion vllm/v1/worker/tpu_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""A TPU worker class."""
import os
from typing import Optional
from typing import Any, Optional

import torch
import torch.distributed
Expand Down Expand Up @@ -259,6 +259,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
def load_model(self) -> None:
self.model_runner.load_model()

def update_config(self, overrides: dict[str, Any]) -> None:
    """Forward config overrides to the underlying model runner."""
    runner = self.model_runner
    runner.update_config(overrides)

def compile_or_warm_up_model(self) -> None:
if not self.model_config.enforce_eager:
self.model_runner.capture_model()
Expand Down