diff --git a/vllm/utils.py b/vllm/utils.py
index 9cdf623379516..5b94067cec777 100644
--- a/vllm/utils.py
+++ b/vllm/utils.py
@@ -3,6 +3,7 @@
 import socket
 import subprocess
 import uuid
+import gc
 from platform import uname
 from typing import List, Tuple, Union
 from packaging.version import parse, Version
@@ -309,3 +310,27 @@ def create_kv_caches_with_random(
             f"Does not support value cache of type {cache_dtype}")
         value_caches.append(value_cache)
     return key_caches, value_caches
+
+
+class measure_cuda_memory:
+
+    def __init__(self, device=None):
+        self.device = device
+
+    def current_memory_usage(self) -> float:
+        # Return the memory usage in bytes.
+        torch.cuda.reset_peak_memory_stats(self.device)
+        mem = torch.cuda.max_memory_allocated(self.device)
+        return mem
+
+    def __enter__(self):
+        self.initial_memory = self.current_memory_usage()
+        # This allows us to call methods of the context manager if needed
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        self.final_memory = self.current_memory_usage()
+        self.consumed_memory = self.final_memory - self.initial_memory
+
+        # Force garbage collection
+        gc.collect()
diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py
index aff8ebc903623..b01f865f1bb03 100644
--- a/vllm/worker/model_runner.py
+++ b/vllm/worker/model_runner.py
@@ -21,7 +21,7 @@
 from vllm.lora.worker_manager import LRUCacheWorkerLoRAManager
 from vllm.lora.layers import LoRAMapping
 from vllm.lora.request import LoRARequest
-from vllm.utils import in_wsl
+from vllm.utils import in_wsl, measure_cuda_memory
 
 logger = init_logger(__name__)
 
@@ -85,11 +85,17 @@ def __init__(
             self.model_config.enforce_eager = True
 
     def load_model(self) -> None:
-        self.model = get_model(self.model_config,
-                               self.device_config,
-                               lora_config=self.lora_config,
-                               parallel_config=self.parallel_config,
-                               scheduler_config=self.scheduler_config)
+        with measure_cuda_memory() as m:
+            self.model = get_model(self.model_config,
+                                   self.device_config,
+                                   lora_config=self.lora_config,
+                                   parallel_config=self.parallel_config,
+                                   scheduler_config=self.scheduler_config)
+
+        self.model_memory_usage = m.consumed_memory
+        logger.info(
+            f"Loading model weights took {self.model_memory_usage / float(2**30):.4f} GB"
+        )
 
         vocab_size = self.model.config.vocab_size
 
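
For reference, a minimal usage sketch of the measure_cuda_memory context manager introduced above. It assumes the patch is applied and a CUDA device is available; the tensor allocation is only an illustrative stand-in for the get_model() call that the patch wraps in ModelRunner.load_model().

```python
import torch

from vllm.utils import measure_cuda_memory

# Measure how much CUDA memory a block of code leaves allocated.
# In the patch above this wraps get_model(); here we allocate an
# arbitrary half-precision tensor purely for illustration.
with measure_cuda_memory() as m:
    weights = torch.empty(1024, 1024, dtype=torch.float16, device="cuda")

# consumed_memory is the difference, in bytes, between the allocation
# counters sampled on __enter__ and __exit__.
print(f"Allocation took {m.consumed_memory / float(2**30):.4f} GB")
```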