[Core] Log more GPU memory reservation info #4576
base: main
Changes from 2 commits
@@ -15,13 +15,17 @@
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import CudaMemoryProfiler
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker_base import WorkerBase

+logger = init_logger(__name__)
+

 class Worker(WorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.
@@ -180,14 +184,26 @@ def initialize_cache(self, num_gpu_blocks: int,

     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
-        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config)
+        with CudaMemoryProfiler() as m:
+            self.cache_engine = CacheEngine(self.cache_config,
+                                            self.model_config,
+                                            self.parallel_config)
+        mem_usage = m.consumed_memory
+        unit, scale = "GB", float(2**30)
+        logger.info("GPU KV cache reserves %.4f %s GPU memory.",
+                    mem_usage / scale, unit)
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.gpu_cache)
+            with CudaMemoryProfiler() as m:
+                self.model_runner.capture_model(self.gpu_cache)
+            mem_usage = m.consumed_memory
+            unit, scale = "GB", float(2**30)
+            logger.info("Capturing cuda graph reserves %.4f %s GPU memory.",
+                        mem_usage / scale, unit)
+
         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

Review comment: Should we use …

Review comment: I feel like it'd be useful to show the exact memory allocation by default, since I've seen some users not understanding this part (like here: https://www.reddit.com/r/LocalLLaMA/comments/1bz3bn1/whats_up_with_vllm/).
@@ -332,9 +348,12 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                 max_model_len) -> None:
     if num_gpu_blocks <= 0:
-        raise ValueError("No available memory for the cache blocks. "
-                         "Try increasing `gpu_memory_utilization` when "
-                         "initializing the engine.")
+        raise ValueError(
+            "No available memory for the cache blocks. vLLM needs {} more GPU "
+            "blocks to allocate. Try increasing `gpu_memory_utilization` when "
+            "initializing the engine. Or increase `tensor_parallel_size`, "
+            "which shards model weights across GPUs. It gives more memory to "
+            "allocate kv cache blocks per GPU.".format(-num_gpu_blocks))
     max_seq_len = block_size * num_gpu_blocks
     if max_model_len > max_seq_len:
         raise ValueError(

Comment on lines +351 to +356

Review comment: IIUC, the negative …
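The diff uses `vllm.utils.CudaMemoryProfiler` as a context manager that exposes `consumed_memory`. The discussion below notes that it only measures tensor memory, which suggests something along the lines of the following sketch (an assumption for illustration only; the class name here is hypothetical and the real implementation may differ):

```python
import gc

import torch


class TensorMemoryProfiler:
    """Hypothetical stand-in for a CudaMemoryProfiler-style context manager.

    Measures memory growth as seen by PyTorch's caching allocator
    (i.e., tensor allocations only), which is exactly the limitation
    discussed in the review comments below.
    """

    def __enter__(self):
        torch.cuda.reset_peak_memory_stats()
        self._start = torch.cuda.memory_allocated()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.consumed_memory = torch.cuda.max_memory_allocated() - self._start
        gc.collect()


# Usage mirrors the diff above:
# with TensorMemoryProfiler() as m:
#     build_kv_cache()  # placeholder for CacheEngine construction
# print(m.consumed_memory / 2**30, "GB")
```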
Review comment: This only measures the GPU memory used by tensors, so it is inaccurate for CUDA graph capture (which uses memory outside of tensors). I kept the original way because I thought it could be useful, but I am open to just removing the code.
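To make the distinction concrete, here is a hedged, self-contained sketch (not vLLM code; `measure_both` and its exact behavior are illustrative assumptions): `torch.cuda.max_memory_allocated()` tracks only tensor allocations made through PyTorch's caching allocator, while `torch.cuda.mem_get_info()` reports device-wide free memory, which also moves when CUDA graph capture or other non-torch code reserves memory.

```python
import torch


def measure_both(fn) -> None:
    """Report the GPU cost of `fn` two ways (illustrative sketch only).

    1) Growth of tensor allocations tracked by PyTorch's caching allocator.
    2) Drop in device-wide free memory reported by the CUDA driver, which
       also captures non-tensor memory such as CUDA graph pools.
    """
    torch.cuda.synchronize()
    torch.cuda.reset_peak_memory_stats()
    alloc_before = torch.cuda.memory_allocated()
    free_before, _ = torch.cuda.mem_get_info()

    fn()

    torch.cuda.synchronize()
    tensor_bytes = torch.cuda.max_memory_allocated() - alloc_before
    free_after, _ = torch.cuda.mem_get_info()
    device_bytes = free_before - free_after

    gib = float(2**30)
    print(f"tensor allocations: {tensor_bytes / gib:.4f} GiB")
    print(f"device-wide usage:  {device_bytes / gib:.4f} GiB")
```

For something like CUDA graph capture, the second number can be noticeably larger than the first, which is the inaccuracy described above.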
Review comment: IIUC, the change made in this PR is incorrect because `torch.cuda.mem_get_info` does not consider the "free" GPU memory managed by the PyTorch caching allocator. Once GPU memory is allocated (via `cudaMalloc`) by the PyTorch caching allocator, the memory is never `cudaFree`'d unless the user enforces it (e.g., by `empty_cache`). While this memory can be regarded as free by the PyTorch allocator because it's not used for any tensor, it is not regarded as free from the external point of view (e.g., in nvidia-smi or `torch.cuda.mem_get_info`). Therefore, the profiler cannot capture the GPU memory usage inside the PyTorch allocator, and thus may under-estimate the memory usage.
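A small, hedged illustration of this under-estimation (a standalone example, not vLLM code): after a tensor is freed, its memory stays cached inside PyTorch's allocator, so a later allocation of similar size barely changes the free memory reported by `torch.cuda.mem_get_info()`, even though the measured code did consume that memory.

```python
import torch

GIB = float(2**30)


def driver_free_gib() -> float:
    free_bytes, _total = torch.cuda.mem_get_info()
    return free_bytes / GIB


# Warm up the caching allocator with ~1 GiB, then drop the tensor.
# The allocator keeps the segment cached rather than cudaFree-ing it.
x = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")
del x

before = driver_free_gib()
y = torch.empty(1 << 30, dtype=torch.uint8, device="cuda")  # served from the cache
after = driver_free_gib()

# Roughly zero: the driver never saw this allocation, so a profiler based on
# mem_get_info would under-report it.
print(f"usage seen via mem_get_info: {before - after:.4f} GiB")
print(f"usage seen by the allocator: {torch.cuda.memory_allocated() / GIB:.4f} GiB")
```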
Review comment: Ah, interesting... so it is like cached memory in a filesystem. One issue was that the original method couldn't take into account memory used by non-torch allocations (so it inaccurately reported CUDA graph memory usage). Let me see if I can find a way to take that buffer into account.
Review comment: @WoosukKwon Currently thinking about 2 different options. This should include the releasable memory from the torch caching allocator in the free memory (based on https://pytorch.org/docs/stable/notes/cuda.html#memory-management). Any thoughts?
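The options themselves are not shown above, but one hedged reading of "include the releasable memory ... in the free memory" is to add the allocator's cached-but-unused bytes to the driver-reported free memory, since that portion could be returned via `empty_cache()`. The helper name below is hypothetical and not part of the PR:

```python
import torch


def effectively_free_bytes() -> int:
    """Driver-reported free memory plus releasable allocator cache (sketch).

    memory_reserved() - memory_allocated() is memory the caching allocator
    holds but no tensor currently uses; torch.cuda.empty_cache() could hand
    it back to the driver, so it is counted as free here. The exact
    accounting may differ from whatever the PR ends up doing.
    """
    free_bytes, _total = torch.cuda.mem_get_info()
    releasable = torch.cuda.memory_reserved() - torch.cuda.memory_allocated()
    return free_bytes + releasable
```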
Review comment: What exactly do you mean by this?
Review comment: So keep the other part as it is. And for the CUDA graph capture part, we measure memory by …