[Core] Add reload_weights RPC method #20096

Status: Open · wants to merge 3 commits into base: main

tests/v1/worker/test_gpu_model_runner.py (1 addition, 1 deletion)

@@ -443,7 +443,7 @@ def test_load_model_weights_inplace(dist_init, model_runner, model_runner_2):
     assert str(model_runner.get_model().state_dict()) != str(
         model_runner_2.get_model().state_dict())
     model_runner_2.load_config.load_format = original_load_format
-    model_runner_2.load_model() # Load real weights inplace
+    model_runner_2.reload_weights() # Load real weights inplace
     assert str(model_runner.get_model().state_dict()) == str(
         model_runner_2.get_model().state_dict())

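A side note on the test's idiom: it compares str(state_dict()) because comparing the dicts directly does not work for tensors (tensor == is elementwise, and taking the truth value of a multi-element result raises). A toy illustration in plain torch, separate from the vLLM fixtures:

import torch

a = torch.nn.Linear(2, 2)
b = torch.nn.Linear(2, 2)
b.load_state_dict(a.state_dict())

# a.state_dict() == b.state_dict() would raise: dict equality invokes
# tensor ==, which returns a tensor of bools rather than a bool.
assert str(a.state_dict()) == str(b.state_dict())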

vllm/v1/worker/gpu_model_runner.py (8 additions, 11 deletions)

@@ -1697,17 +1697,9 @@ def load_model(self) -> None:
         with DeviceMemoryProfiler() as m:  # noqa: SIM117
             time_before_load = time.perf_counter()
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                self.model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info(
-                    "Model was already initialized. Loading weights inplace..."
-                )
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            self.model = model_loader.load_model(
+                vllm_config=self.vllm_config, model_config=self.model_config)
Comment on lines +1700 to +1702 · Contributor · severity: high

This change alters the behavior of load_model. Previously, if the model was already loaded (hasattr(self, "model")), it would reload the weights in-place. Now, load_model will always reload the entire model from scratch. While this simplifies load_model and the new reload_weights method correctly encapsulates the in-place loading, this is a significant behavior change.

This could be an issue if other parts of the codebase call load_model expecting it to only reload weights. For instance, the test test_load_model_weights_inplace in tests/v1/worker/test_gpu_model_runner.py appears to rely on the old behavior by calling load_model to perform what the test describes as an in-place weight load.

If this behavior change is intentional, I recommend updating the relevant tests to use the new reload_weights method to accurately reflect the new API design. Otherwise, the previous conditional logic should be restored to avoid breaking existing functionality.

Collaborator (Author) replied:
This is intentional, and the test is now updated.
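
Concretely, the split contract can be sketched with a toy stand-in (illustrative only; not vLLM's GPUModelRunner, which delegates both paths to model_loader):

class ToyRunner:
    def load_model(self) -> None:
        # Always builds from scratch now; the hasattr(self, "model")
        # branch is gone.
        self.model = {"weight": 0.0}

    def reload_weights(self) -> None:
        # In-place refresh; only valid once a model exists.
        assert hasattr(self, "model"), "call load_model() first"
        self.model["weight"] = 1.0

runner = ToyRunner()
runner.load_model()      # from scratch
runner.reload_weights()  # in place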

             if has_step_pooler(self.model):
                 self.input_batch.logits_processing_needs_token_ids = True
             if self.lora_config:
@@ -1729,6 +1721,11 @@ def load_model(self) -> None:
                     time_after_load - time_before_load)
         prepare_communication_buffer_for_model(self.model)
 
+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     def save_tensorized_model(
             self,
             tensorizer_config: "TensorizerConfig",
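
The point of reloading "inplace" is that parameter tensors keep their identity and storage, so anything holding references to them (CUDA graphs, communication buffers) stays valid. A minimal torch sketch of the idea; vLLM's model_loader.load_weights additionally handles sharding and quantized checkpoints:

import torch

def reload_inplace(model: torch.nn.Module,
                   new_state: dict[str, torch.Tensor]) -> None:
    # Copy into the existing storage rather than rebuilding the module.
    with torch.no_grad():
        for name, param in model.named_parameters():
            param.copy_(new_state[name])

model = torch.nn.Linear(4, 2)
ptr = model.weight.data_ptr()
reload_inplace(model, {k: torch.randn_like(v)
                       for k, v in model.state_dict().items()})
assert model.weight.data_ptr() == ptr  # same storage, new values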

vllm/v1/worker/gpu_worker.py (20 additions, 17 deletions)

@@ -3,6 +3,7 @@
"""A GPU worker class."""
import gc
import os
from contextlib import AbstractContextManager, nullcontext
from typing import TYPE_CHECKING, Optional

import torch
@@ -112,6 +113,19 @@ def wake_up(self, tags: Optional[list[str]] = None) -> None:
             buffer.data.copy_(self._sleep_saved_buffers[name].data)
         self._sleep_saved_buffers = {}
 
+    def _maybe_get_memory_pool_context(self,
+                                       tag: str) -> AbstractContextManager:
+        if self.vllm_config.model_config.enable_sleep_mode:
+            allocator = CuMemAllocator.get_instance()
+            if tag == "weights":
+                assert allocator.get_current_usage() == 0, (
+                    "Sleep mode can only be "
+                    "used for one instance per process.")
+            context = allocator.use_memory_pool(tag=tag)
+        else:
+            context = nullcontext()
+        return context
+
     def initialize_cache(self, num_gpu_blocks: int,
                          num_cpu_blocks: int) -> None:
         self.cache_config.num_gpu_blocks = num_gpu_blocks
@@ -172,18 +186,13 @@ def init_device(self):
     # FIXME(youkaichao & ywang96): Use TorchDispatchMode instead of memory pool
     # to hijack tensor allocation.
     def load_model(self) -> None:
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            assert allocator.get_current_usage() == 0, (
-                "Sleep mode can only be "
-                "used for one instance per process.")
-            context = allocator.use_memory_pool(tag="weights")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="weights"):
             self.model_runner.load_model()
 
+    def reload_weights(self) -> None:
+        with self._maybe_get_memory_pool_context(tag="weights"):
+            self.model_runner.reload_weights()
+
     @torch.inference_mode()
     def determine_available_memory(self) -> int:
         """Profiles the peak memory usage of the model to determine how much
@@ -240,13 +249,7 @@ def get_kv_cache_spec(self) -> dict[str, KVCacheSpec]:

     def initialize_from_config(self, kv_cache_config: KVCacheConfig) -> None:
         """Allocate GPU KV cache with the specified kv_cache_config."""
-        if self.vllm_config.model_config.enable_sleep_mode:
-            allocator = CuMemAllocator.get_instance()
-            context = allocator.use_memory_pool(tag="kv_cache")
-        else:
-            from contextlib import nullcontext
-            context = nullcontext()
-        with context:
+        with self._maybe_get_memory_pool_context(tag="kv_cache"):
             self.model_runner.initialize_kv_cache(kv_cache_config)
 
     def compile_or_warm_up_model(self) -> None:
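
The extracted helper is the standard conditional-context-manager pattern: hand back a real context manager when the feature is on and contextlib.nullcontext() otherwise, so every call site reduces to a single with statement. A generic, self-contained sketch (the temporary file is a placeholder resource, not CuMemAllocator):

import tempfile
from contextlib import AbstractContextManager, nullcontext

def maybe_resource(enabled: bool) -> AbstractContextManager:
    # Real setup/teardown only when enabled; a no-op otherwise,
    # so callers never branch.
    return tempfile.TemporaryFile() if enabled else nullcontext()

with maybe_resource(enabled=False):
    print("body runs either way")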

vllm/v1/worker/tpu_model_runner.py (8 additions, 10 deletions)

@@ -995,16 +995,9 @@ def load_model(self) -> None:
         else:
             # model = get_model(vllm_config=self.vllm_config)
             model_loader = get_model_loader(self.load_config)
-            if not hasattr(self, "model"):
-                logger.info("Loading model from scratch...")
-                model = model_loader.load_model(
-                    vllm_config=self.vllm_config,
-                    model_config=self.model_config)
-            else:
-                logger.info("Model was already initialized. \
-                    Loading weights inplace...")
-                model_loader.load_weights(self.model,
-                                          model_config=self.model_config)
+            logger.info("Loading model from scratch...")
+            model = model_loader.load_model(vllm_config=self.vllm_config,
+                                            model_config=self.model_config)
         if self.lora_config is not None:
             model = self.load_lora_model(model, self.model_config,
                                          self.scheduler_config,
@@ -1019,6 +1012,11 @@ def load_model(self) -> None:
         self.model = model
         self.sampler = TPUSampler()
 
+    def reload_weights(self) -> None:
+        model_loader = get_model_loader(self.load_config)
+        logger.info("Reloading weights inplace...")
+        model_loader.load_weights(self.model, model_config=self.model_config)
+
     @torch.no_grad()
     def _dummy_run(self, num_tokens: int) -> None:
         if self.is_multimodal_model:

vllm/v1/worker/tpu_worker.py (3 additions)

@@ -248,6 +248,9 @@ def add_lora(self, lora_request: LoRARequest) -> bool:
     def load_model(self) -> None:
         self.model_runner.load_model()
 
+    def reload_weights(self) -> None:
+        self.model_runner.reload_weights()
+
     def compile_or_warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
             self.model_runner.capture_model()
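
With the worker method in place on both backends, a client can trigger an in-place refresh across all workers through the engine's RPC fan-out. A usage sketch, assuming vLLM's LLM.collective_rpc dispatches by method name to each worker (the RPC path this PR's title refers to); the model name is a placeholder:

from vllm import LLM

llm = LLM(model="facebook/opt-125m")

# ...an external process updates the checkpoint at the configured
# load path, e.g. after an RLHF training step...

# Each worker re-runs model_loader.load_weights() on its live model.
llm.collective_rpc("reload_weights")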