[Core] Fix memory profiling #11120


Closed
wants to merge 13 commits
2 changes: 1 addition & 1 deletion tests/entrypoints/llm/test_lazy_outlines.py
@@ -36,7 +36,7 @@ def run_lmfe(sample_regex):
llm = LLM(model="facebook/opt-125m",
enforce_eager=True,
guided_decoding_backend="lm-format-enforcer",
gpu_memory_utilization=0.6)
gpu_memory_utilization=0.3)
sampling_params = SamplingParams(temperature=0.8, top_p=0.95)
outputs = llm.generate(
prompts=[
44 changes: 42 additions & 2 deletions tests/test_utils.py
@@ -4,10 +4,13 @@
from functools import partial
from typing import AsyncIterator, Tuple

import cupy
import pytest
import torch

from vllm.utils import (FlexibleArgumentParser, StoreBoolean, deprecate_kwargs,
get_open_port, merge_async_iterators, supports_kw)
from vllm.utils import (FlexibleArgumentParser, MemorySnapshot, StoreBoolean,
deprecate_kwargs, get_open_port, memory_profiling,
merge_async_iterators, supports_kw)

from .utils import error_on_warning

@@ -270,3 +273,40 @@ def test_supports_kw(callable,kw_name,requires_kw_only,
requires_kw_only=requires_kw_only,
allow_var_kwargs=allow_var_kwargs
) == is_supported

def test_memory_profiling():
# Fake out some model loading + inference memory usage to test profiling

# Memory used by other processes will show up as cuda usage outside of torch
# Use cupy to emulate this
other_process_usage = cupy.zeros((1024, 1024, 512), dtype=cupy.float32)

# Take the initial memory snapshot
snapshot = MemorySnapshot()
snapshot.measure()

# Load the model (4GB)
model = torch.zeros(1024, 1024, 1024, dtype=torch.float32).to("cuda")

with memory_profiling(snapshot) as profile:
# Add some more "static" torch memory (1GB)
static_memory = \
torch.zeros(1024, 1024, 256, dtype=torch.float32).to("cuda")

# make a memory spike (1GB)
spike = torch.zeros(1024, 1024, 256, dtype=torch.float32).to("cuda")
del spike

# Add some extra non-torch memory (2GB)
extra_cuda = cupy.zeros((1024, 1024, 512), dtype=cupy.float32)

del model
del other_process_usage
del extra_cuda
del static_memory

# spike should be 1GB
assert profile.memory_spike_bytes == 1024 ** 3
# baseline should be model (4) + torch static memory (1) +
# extra non-torch (2) GB
assert profile.baseline_memory_bytes == 1024 ** 3 * 7
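
For reference, a quick sketch of the arithmetic behind the assertions above (a float32 element is 4 bytes; every size is taken from the test itself):

GiB = 1024 ** 3
model_bytes = 1024 * 1024 * 1024 * 4       # model tensor: 4 GiB
static_bytes = 1024 * 1024 * 256 * 4       # "static" torch tensor: 1 GiB
spike_bytes = 1024 * 1024 * 256 * 4        # transient tensor freed inside the block: 1 GiB
extra_cuda_bytes = 1024 * 1024 * 512 * 4   # non-torch (cupy) allocation: 2 GiB

# The spike is torch-only memory that peaked and was then released.
assert spike_bytes == 1 * GiB
# The baseline counts everything still resident after profiling, measured
# against the snapshot taken before the model was loaded.
assert model_bytes + static_bytes + extra_cuda_bytes == 7 * GiB
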
7 changes: 3 additions & 4 deletions tests/worker/test_profile.py
@@ -31,10 +31,6 @@ def test_gpu_memory_profiling():
is_driver_worker=True,
)

# Load the model so we can profile it
worker.init_device()
worker.load_model()

# Set 10GiB as the total gpu ram to be device-agnostic
def mock_mem_info():
current_usage = torch.cuda.memory_stats(
@@ -46,6 +42,9 @@ def mock_mem_info():

from unittest.mock import patch
with patch("torch.cuda.mem_get_info", side_effect=mock_mem_info):
# Load and profile the model
worker.init_device()
worker.load_model()
gpu_blocks, _ = worker.determine_num_available_blocks()

# Peak vram usage by torch should be 0.7077 GiB
8 changes: 3 additions & 5 deletions vllm/engine/arg_utils.py
@@ -478,11 +478,9 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
help='The fraction of GPU memory to be used for the model '
'executor, which can range from 0 to 1. For example, a value of '
'0.5 would imply 50%% GPU memory utilization. If unspecified, '
'will use the default value of 0.9. This is a global gpu memory '
'utilization limit, for example if 50%% of the gpu memory is '
'already used before vLLM starts and --gpu-memory-utilization is '
'set to 0.9, then only 40%% of the gpu memory will be allocated '
'to the model executor.')
'will use the default value of 0.9. To function properly, no other '
'processes should allocate memory on the gpu(s) while the engine '
'starts up.')
parser.add_argument(
'--num-gpu-blocks-override',
type=int,
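
To illustrate the new semantics with a short sketch (the model name and the 0.5 value are arbitrary): the fraction now describes the total memory budget the engine claims for itself on an otherwise idle GPU, so no other process should start allocating while the engine initializes.

from vllm import LLM

# Budget 50% of the GPU's total memory for this engine instance;
# weights, activations, and the KV cache all come out of this budget.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.5)
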
120 changes: 118 additions & 2 deletions vllm/utils.py
@@ -23,10 +23,12 @@
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections import UserDict, defaultdict
from collections.abc import Iterable, Mapping
from dataclasses import dataclass
from functools import lru_cache, partial, wraps
from typing import (TYPE_CHECKING, Any, AsyncGenerator, Awaitable, Callable,
Dict, Generic, Hashable, List, Literal, Optional,
OrderedDict, Set, Tuple, Type, TypeVar, Union, overload)
Dict, Generator, Generic, Hashable, List, Literal,
Optional, OrderedDict, Set, Tuple, Type, TypeVar, Union,
overload)
from uuid import uuid4

import numpy as np
@@ -1635,6 +1637,120 @@ def resolve_obj_by_qualname(qualname: str) -> Any:
return getattr(module, obj_name)


@dataclass
class MemorySnapshot:
"""Memory snapshot."""
cuda_memory_in_bytes: int = 0
torch_memory_in_bytes: int = 0
timestamp: float = 0.0

def measure(self):
self.cuda_memory_in_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
self.torch_memory_in_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
self.timestamp = time.time()

def __sub__(self, other: "MemorySnapshot") -> "MemorySnapshot":
"""support a - b"""
return MemorySnapshot(
cuda_memory_in_bytes=self.cuda_memory_in_bytes -
other.cuda_memory_in_bytes,
torch_memory_in_bytes=self.torch_memory_in_bytes -
other.torch_memory_in_bytes,
timestamp=self.timestamp - other.timestamp)


@dataclass
class MemoryProfilingResult:
"""Memory profiling result.

The memory in one GPU can be classified into 3 categories:
1. (marked by -) memory used by other processes.
2. (marked by +) memory used by torch in this process.
3. (marked by *) memory used in this process, but not by torch.

The torch API `torch.cuda.memory_stats()` measures category (2). It exposes the keys
- "allocated_bytes.all.current": the current torch memory usage
- "allocated_bytes.all.peak": the peak torch memory usage during profiling

There is no direct API for categories (1) and (3). The closest measurement is
`torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]`, which gives the total used
device memory, i.e. the sum of all three categories.

Because of these API limitations, profiling relies on the following assumptions:
- The memory used by other processes is constant during profiling.
- The memory used in this process, but not by torch, only grows during profiling. Memory allocated by NCCL is one example.
- The memory used by torch in this process can grow and shrink during profiling.

Given that, this profiler returns:
- The baseline memory increase during profiling: the difference in total cuda memory
allocated before and after profiling.
- The torch memory spike during profiling: the difference between the peak torch memory
usage and the post-profiling torch memory usage.

Illustration:
| cuda memory |
| Other procs | This process |
| | torch memory | |
Before profiling: | ----------- | ++++++ | ** |
During profiling (peak): | ----------- | +++++++++++++++++ | ***** |
After profiling: | ----------- | ++++++++++++ | ***** |

This profiler returns two values:
Baseline memory: | | ++++++ | *** |
Memory Spike | | +++++ | |

""" # noqa
memory_spike_bytes: int = 0
baseline_memory_bytes: int = 0
profile_time: float = 0.0
total_gpu_memory_bytes: int = 0


@contextlib.contextmanager
def memory_profiling(
pre_profile_snapshot: MemorySnapshot
) -> Generator[MemoryProfilingResult, None, None]:
"""Memory profiling context manager.

pre_profile_snapshot is a snapshot of memory before anything that we want to
measure is allocated.
"""
result = MemoryProfilingResult()
_, result.total_gpu_memory_bytes = torch.cuda.mem_get_info()
profile_start_time = time.time()

# Prepare to measure peak memory usage
torch.cuda.reset_peak_memory_stats()

# Yield to run the code under profile
yield result

# Clean up anything that would be GC'ed so we don't measure it as well
torch.cuda.synchronize()
gc.collect()
torch.cuda.empty_cache()

post_profile_snapshot = MemorySnapshot()
post_profile_snapshot.measure()

torch_peak = torch.cuda.memory_stats()["allocated_bytes.all.peak"]
result.memory_spike_bytes = (torch_peak -
post_profile_snapshot.torch_memory_in_bytes)
result.baseline_memory_bytes = (post_profile_snapshot -
pre_profile_snapshot).cuda_memory_in_bytes

result.profile_time = post_profile_snapshot.timestamp - profile_start_time

assert result.baseline_memory_bytes >= 0, (
"Error in memory profiling. "
f"Negative memory usage detected: {result.baseline_memory_bytes}. "
"This happens when the GPU memory was not properly cleaned up before "
"initializing the vLLM instance.")


def kill_process_tree(pid: int):
"""
Kills all descendant processes of the given pid by sending SIGKILL.
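
A minimal usage sketch of the new helpers, mirroring how the worker below consumes them (the tensor shapes here are illustrative only):

import torch

from vllm.utils import MemorySnapshot, memory_profiling

# Take the snapshot before anything we want to attribute to this process
# is allocated; the worker does this in init_device().
baseline = MemorySnapshot()
baseline.measure()

model = torch.zeros(1024, 1024, 64, dtype=torch.float32, device="cuda")

with memory_profiling(baseline) as result:
    # Run the workload to profile, e.g. a dummy forward pass.
    activations = torch.zeros(1024, 1024, 32, dtype=torch.float32,
                              device="cuda")
    del activations

# Memory still resident since the snapshot (torch and non-torch alike).
print(result.baseline_memory_bytes)
# Transient torch peak above the post-profiling torch usage.
print(result.memory_spike_bytes)
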
94 changes: 35 additions & 59 deletions vllm/worker/worker.py
@@ -1,7 +1,6 @@
"""A GPU worker class."""
import gc
import os
import time
from typing import Dict, List, Optional, Set, Tuple, Type, Union

import torch
@@ -22,6 +21,7 @@
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sequence import (ExecuteModelRequest, IntermediateTensors,
SequenceGroupMetadata, SequenceGroupMetadataDelta)
from vllm.utils import MemorySnapshot, memory_profiling
from vllm.worker.cache_engine import CacheEngine
from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner
from vllm.worker.model_runner import GPUModelRunnerBase, ModelRunner
@@ -111,6 +111,7 @@ def __init__(
torch_profiler_trace_dir, use_gzip=True))
else:
self.profiler = None
self._initial_memory_snapshot = MemorySnapshot()

def start_profile(self):
if self.profiler is None:
@@ -140,7 +141,15 @@ def init_device(self) -> None:
_check_if_gpu_supports_dtype(self.model_config.dtype)
gc.collect()
torch.cuda.empty_cache()
self.init_gpu_memory = torch.cuda.mem_get_info()[0]
self._initial_memory_snapshot.measure()

free, total = torch.cuda.mem_get_info()
pct_free = free / total
if pct_free < self.cache_config.gpu_memory_utilization:
raise ValueError(
f"Only {pct_free*100:.2f}% of vram free, cannot allocate"
f"{self.cache_config.gpu_memory_utilization*100:.2f}%")

else:
raise RuntimeError(
f"Not support device type: {self.device_config.device}")
@@ -188,37 +197,15 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
"""
# Profile the memory usage of the model and get the maximum number of
# cache blocks that can be allocated with the remaining free memory.
torch.cuda.empty_cache()
torch.cuda.reset_peak_memory_stats()

free_memory_pre_profile, total_gpu_memory = torch.cuda.mem_get_info()
start_time = time.time()

# Execute a forward pass with dummy inputs to profile the memory usage
# of the model.
self.model_runner.profile_run()
torch.cuda.synchronize()

self._assert_memory_footprint_increased_during_profiling()

# Get the peak memory allocation recorded by torch
peak_memory = torch.cuda.memory_stats()["allocated_bytes.all.peak"]

# Check for any memory left around that may have been allocated on the
# gpu outside of `torch`. NCCL operations, for example, can use a few
# GB during a forward pass
torch.cuda.empty_cache()
torch_allocated_bytes = torch.cuda.memory_stats(
)["allocated_bytes.all.current"]
total_allocated_bytes = torch.cuda.mem_get_info(
)[1] - torch.cuda.mem_get_info()[0]
non_torch_allocations = total_allocated_bytes - torch_allocated_bytes
if non_torch_allocations > 0:
peak_memory += non_torch_allocations

available_kv_cache_memory = (
total_gpu_memory * self.cache_config.gpu_memory_utilization -
peak_memory)
with memory_profiling(self._initial_memory_snapshot) as result:
# Execute a forward pass with dummy inputs
self.model_runner.profile_run()

target_gpu_utilization_bytes = result.total_gpu_memory_bytes * \
self.cache_config.gpu_memory_utilization
available_kv_cache_memory = (target_gpu_utilization_bytes -
result.baseline_memory_bytes -
result.memory_spike_bytes)

# Calculate the number of blocks that can be allocated with the
# profiled peak memory.
@@ -233,24 +220,23 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:
num_gpu_blocks = max(num_gpu_blocks, 0)
num_cpu_blocks = max(num_cpu_blocks, 0)

end_time = time.time()
logger.info(
"Memory profiling results: "
"duration=%.2f seconds, "
"total_gpu_memory=%.2fGiB, "
"initial_memory_usage=%.2fGiB, "
"peak_torch_memory=%.2fGiB, "
"memory_usage_post_profile=%.2fGiB, "
"non_torch_memory=%.2fGiB, "
"kv_cache_size=%.2fGiB, "
"gpu_memory_utilization=%.2f.", end_time - start_time,
total_gpu_memory / (1024**3),
(total_gpu_memory - free_memory_pre_profile) / (1024**3),
(peak_memory - non_torch_allocations) / (1024**3),
total_allocated_bytes / (1024**3),
non_torch_allocations / (1024**3),
"Memory profiling results:\n"
"duration %.2f seconds\n"
"total_gpu_memory %.2fGiB\n"
"gpu_memory_utilization %.2f\n"
"target_allocation %.2fGiB\n"
"baseline_memory %.2fGiB\n"
"max_inference_spike %.2fGiB\n"
"kv_cache_size %.2fGiB\n",
result.profile_time,
result.total_gpu_memory_bytes / (1024**3),
self.cache_config.gpu_memory_utilization,
target_gpu_utilization_bytes / (1024**3),
result.baseline_memory_bytes / (1024**3),
result.memory_spike_bytes / (1024**3),
available_kv_cache_memory / (1024**3),
self.cache_config.gpu_memory_utilization)
)

# Final cleanup
if self.model_runner.lora_manager:
@@ -259,16 +245,6 @@ def determine_num_available_blocks(self) -> Tuple[int, int]:

return num_gpu_blocks, num_cpu_blocks

def _assert_memory_footprint_increased_during_profiling(self):
# NOTE(woosuk): Here we assume that the other processes using the same
# GPU did not change their memory usage during the profiling.
free_gpu_memory, _ = torch.cuda.mem_get_info()
assert self.init_gpu_memory - free_gpu_memory > 0, (
"Error in memory profiling. "
f"Initial free memory {self.init_gpu_memory}, current free memory"
f" {free_gpu_memory}. This happens when the GPU memory was "
"not properly cleaned up before initializing the vLLM instance.")

def initialize_cache(self, num_gpu_blocks: int,
num_cpu_blocks: int) -> None:
"""Allocate GPU and CPU KV cache with the specified number of blocks.
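
For intuition, a small sketch of the KV-cache budget computed in determine_num_available_blocks above (the byte counts are invented for illustration):

GiB = 1024 ** 3

total_gpu_memory_bytes = 24 * GiB   # hypothetical 24 GiB device
gpu_memory_utilization = 0.9        # fraction of the device this engine may use
baseline_memory_bytes = 5 * GiB     # weights plus persistent non-torch memory
memory_spike_bytes = 1 * GiB        # transient peak during the dummy forward pass

target = total_gpu_memory_bytes * gpu_memory_utilization
available_kv_cache_memory = (target - baseline_memory_bytes -
                             memory_spike_bytes)
print(available_kv_cache_memory / GiB)  # 15.6 GiB left for KV cache blocks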