[Core] Log more GPU memory reservation info #4576

Open · wants to merge 4 commits into main
Changes from 3 commits
18 changes: 15 additions & 3 deletions vllm/utils.py
@@ -447,13 +447,25 @@ def is_pin_memory_available() -> bool:

 class CudaMemoryProfiler:

-    def __init__(self, device=None):
+    def __init__(self, device=None, capture_max_memory: bool = False):
+        """A context manager to measure memory usage on a given device.
+
+        If capture_max_memory is True, it measures the maximum memory usage
+        during the profiling. However, it can only measure GPU memory used by
+        torch tensor. If it is False, it measures the memory delta which also
+        includes non-torch tensor GPU memory usage.
+        """
         self.device = device
+        self.capture_max_memory = capture_max_memory

     def current_memory_usage(self) -> float:
         # Return the memory usage in bytes.
-        torch.cuda.reset_peak_memory_stats(self.device)
-        mem = torch.cuda.max_memory_allocated(self.device)
+        if self.capture_max_memory:
+            torch.cuda.reset_peak_memory_stats(self.device)
Collaborator Author:

This only measures the GPU memory used by "tensors", so it is inaccurate for CUDA graphs (which use memory outside tensors). I kept the original way because I thought it could be useful, but I am open to just removing the code.

Collaborator:

IIUC, the change made in this PR is incorrect because torch.cuda.mem_get_info does not consider the "free" GPU memory managed by the PyTorch caching allocator.

Once GPU memory is allocated (via cuda-malloc) by the PyTorch caching allocator, it is never cuda-freed unless the user forces it (e.g., with empty_cache). While this memory can be regarded as free by the PyTorch allocator because it is not used for any tensor, it is not regarded as free from the external point of view (e.g., by nvidia-smi or torch.cuda.mem_get_info). Therefore, the profiler cannot capture the GPU memory usage inside the PyTorch allocator and may thus under-estimate the memory usage.
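To illustrate the caching-allocator behavior described above, here is a minimal standalone sketch (not part of this PR): memory returned to the PyTorch allocator still looks "used" to torch.cuda.mem_get_info until torch.cuda.empty_cache() is called.

```python
import torch

device = torch.device("cuda:0")
free_before, _ = torch.cuda.mem_get_info(device)  # also initializes the CUDA context

x = torch.empty(256 * 1024 * 1024, dtype=torch.uint8, device=device)  # 256 MiB tensor
del x  # no live tensor remains, but the allocator caches the block

print(torch.cuda.memory_allocated(device))   # ~0 bytes backing live tensors
free_after, _ = torch.cuda.mem_get_info(device)
print(free_before - free_after)              # ~256 MiB still held by the allocator

torch.cuda.empty_cache()                     # hand cached blocks back to the driver
```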

Collaborator Author (@rkooo567), May 7, 2024:

Ah, interesting... so it is like cached memory in a filesystem.

One issue was that the original method couldn't take into account memory used by non-torch allocations (so it inaccurately reported CUDA graph memory usage). Let me see if I can find a way to take the cached buffer into account.

Collaborator Author (@rkooo567), May 8, 2024:

@WoosukKwon

Currently thinking about two different options:

  • Option 1: Keep the current memory context manager as it is. For the CUDA graph, measure the memory using a snapshot (via mem_get_info). For the KV cache and the model, it should be fine to keep measuring memory with the status quo (i.e., memory_allocated, because GPU memory is used just for tensors in these cases).
  • Option 2: Modify the memory context manager to use

        free_mem_from_torch = torch.cuda.memory_reserved(self.device) - torch.cuda.memory_allocated(self.device)  # reserved by the allocator but not backing tensors
        free, total = torch.cuda.mem_get_info(self.device)
        mem_usage = total - (free + free_mem_from_torch)

    This should count releasable memory held by the torch caching allocator as free memory (based on https://pytorch.org/docs/stable/notes/cuda.html#memory-management); see the sketch after this comment.

Any thoughts?
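For concreteness, a hedged sketch of what the Option 2 measurement could look like (the standalone function and its name are illustrative, not code from this PR):

```python
import torch


def current_memory_usage(device: torch.device) -> int:
    """Option 2 sketch: GPU memory in use, treating blocks cached by the
    PyTorch allocator (reserved but not backing any tensor) as free."""
    free_mem_from_torch = (torch.cuda.memory_reserved(device)
                           - torch.cuda.memory_allocated(device))
    free, total = torch.cuda.mem_get_info(device)
    return total - (free + free_mem_from_torch)
```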

Collaborator:

> For cuda graph, we measure the memory using snapshot (using mem_get_info).

What exactly do you mean by this?

Collaborator Author:

So keep the other parts as they are, and for the CUDA graph capture part, measure memory with:

    free, total = torch.cuda.mem_get_info(self.device)
    used_before = total - free
    cuda_capture()  # placeholder for the CUDA graph capture step
    free, total = torch.cuda.mem_get_info(self.device)
    used_after = total - free
    memory_used = used_after - used_before

+            mem = torch.cuda.max_memory_allocated(self.device)
+        else:
+            free, total = torch.cuda.mem_get_info(self.device)
+            mem = total - free
         return mem

     def __enter__(self):
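For context, a minimal usage sketch of the profiler as modified above. It assumes, as the worker.py diff below does, that the context manager exposes `consumed_memory`; the workload function is a stand-in.

```python
import torch
from vllm.utils import CudaMemoryProfiler


def some_gpu_work():
    # Stand-in workload so there is something to measure.
    return torch.zeros(1 << 24, device="cuda")  # ~64 MiB of float32


# Default mode: memory delta from torch.cuda.mem_get_info, which also catches
# non-torch allocations such as CUDA graph pools.
with CudaMemoryProfiler() as m:
    x = some_gpu_work()
print(f"reserved {m.consumed_memory / 2**30:.4f} GB")

# capture_max_memory=True: peak memory of torch tensors only.
with CudaMemoryProfiler(capture_max_memory=True) as m:
    y = some_gpu_work()
print(f"peak tensor usage {m.consumed_memory / 2**30:.4f} GB")
```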
6 changes: 4 additions & 2 deletions vllm/worker/model_runner.py
@@ -31,10 +31,12 @@
 _PAD_SLOT_ID = -1
 LORA_WARMUP_RANK = 8
 _BATCH_SIZE_ALIGNMENT = 8
+MAX_BATCH_SIZE_TO_CAPTURE = 256
 # Capture graphs for token size 1, 2, 4, 8, 16, 24, 32, 40, ..., 256.
 # NOTE: _get_graph_batch_size needs to be updated if this list is changed.
 _BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [
-    _BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)
+    _BATCH_SIZE_ALIGNMENT * i
+    for i in range(1, MAX_BATCH_SIZE_TO_CAPTURE // _BATCH_SIZE_ALIGNMENT + 1)
 ]
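As a quick sanity check (illustrative script, not part of the diff), the new comprehension reproduces the previously hard-coded range(1, 33) list when MAX_BATCH_SIZE_TO_CAPTURE is 256:

```python
_BATCH_SIZE_ALIGNMENT = 8
MAX_BATCH_SIZE_TO_CAPTURE = 256

old = [1, 2, 4] + [_BATCH_SIZE_ALIGNMENT * i for i in range(1, 33)]
new = [1, 2, 4] + [
    _BATCH_SIZE_ALIGNMENT * i
    for i in range(1, MAX_BATCH_SIZE_TO_CAPTURE // _BATCH_SIZE_ALIGNMENT + 1)
]
assert old == new                          # [1, 2, 4, 8, 16, 24, ..., 256]
assert new[-1] == MAX_BATCH_SIZE_TO_CAPTURE
```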


@@ -156,7 +158,7 @@ def __init__(
         # (max batch size to capture, max context len to capture / block size).
         self.graph_block_tables: torch.Tensor  # Set after initial profiling.

-    def load_model(self) -> None:
+    def load_model(self):
         with CudaMemoryProfiler() as m:
             self.model = get_model(
                 model_config=self.model_config,
31 changes: 25 additions & 6 deletions vllm/worker/worker.py
@@ -15,13 +15,17 @@
 from vllm.distributed.device_communicators import pynccl_utils
 from vllm.distributed.device_communicators.custom_all_reduce import (
     init_custom_ar)
+from vllm.logger import init_logger
 from vllm.lora.request import LoRARequest
 from vllm.model_executor import set_random_seed
 from vllm.sequence import SamplerOutput, SequenceGroupMetadata
+from vllm.utils import CudaMemoryProfiler
 from vllm.worker.cache_engine import CacheEngine
 from vllm.worker.model_runner import ModelRunner
 from vllm.worker.worker_base import WorkerBase

+logger = init_logger(__name__)
+

 class Worker(WorkerBase):
     """A worker class that executes (a partition of) the model on a GPU.
@@ -180,14 +184,26 @@ def initialize_cache(self, num_gpu_blocks: int,

     def _init_cache_engine(self):
         assert self.cache_config.num_gpu_blocks is not None
-        self.cache_engine = CacheEngine(self.cache_config, self.model_config,
-                                        self.parallel_config)
+        with CudaMemoryProfiler() as m:
+            self.cache_engine = CacheEngine(self.cache_config,
+                                            self.model_config,
+                                            self.parallel_config)
+        mem_usage = m.consumed_memory
+        unit, scale = "GB", float(2**30)
+        logger.info("GPU KV cache reserves %.4f %s GPU memory.",
+                    mem_usage / scale, unit)
         self.gpu_cache = self.cache_engine.gpu_cache
         self.model_runner.set_block_size(self.cache_engine.block_size)

     def _warm_up_model(self) -> None:
         if not self.model_config.enforce_eager:
-            self.model_runner.capture_model(self.gpu_cache)
+            with CudaMemoryProfiler() as m:
+                self.model_runner.capture_model(self.gpu_cache)
+            mem_usage = m.consumed_memory
+            unit, scale = "GB", float(2**30)
+            logger.info("Capturing cuda graph reserves %.4f %s GPU memory.",
Collaborator:

Should we use debug instead?

Collaborator Author (@rkooo567), May 7, 2024:

I feel like it would be useful to show the exact memory allocation by default, since I've seen some users not understand this part (like here: https://www.reddit.com/r/LocalLLaMA/comments/1bz3bn1/whats_up_with_vllm/).

+                        mem_usage / scale, unit)

         # Reset the seed to ensure that the random state is not affected by
         # the model initialization and profiling.
         set_random_seed(self.model_config.seed)

@@ -332,9 +348,12 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype):
 def raise_if_cache_size_invalid(num_gpu_blocks, block_size,
                                 max_model_len) -> None:
     if num_gpu_blocks <= 0:
-        raise ValueError("No available memory for the cache blocks. "
-                         "Try increasing `gpu_memory_utilization` when "
-                         "initializing the engine.")
+        raise ValueError(
+            "No available memory for the cache blocks. vLLM needs {} more GPU "
+            "blocks to allocate. Try increasing `gpu_memory_utilization` when "
+            "initializing the engine. Or increase `tensor_parallel_size`, which "
+            "shards model weights across GPUs. It gives more memory to "
+            "allocate kv cache blocks per GPU.".format(-num_gpu_blocks))
Comment on lines +351 to +356

Collaborator:

IIUC, the negative num_gpu_blocks does not give any useful information; it just means the memory profiling was inaccurate for some reason.

     max_seq_len = block_size * num_gpu_blocks
     if max_model_len > max_seq_len:
         raise ValueError(