[V1] Reuse V0's memory_profiling util for gpu worker memory profiling #19312
Changes from all commits: c23e73b, b13d767, 83247be, eb6cbdd, 2dfc26d, 4f0d18c
@@ -22,7 +22,7 @@
 from vllm.model_executor import set_random_seed
 from vllm.platforms import current_platform
 from vllm.sequence import IntermediateTensors
-from vllm.utils import GiB_bytes
+from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling
 from vllm.v1.kv_cache_interface import KVCacheConfig, KVCacheSpec
 from vllm.v1.outputs import ModelRunnerOutput
 from vllm.v1.utils import report_usage_stats
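For context on the two new imports: the rest of the diff only touches a small surface of the V0 utilities. Below is a minimal sketch of that surface, assuming the V0 semantics carry over unchanged; it requires a CUDA device, and weights_memory=0 is just a placeholder (the PR passes the model runner's measured weights size).

```python
from vllm.utils import GiB_bytes, MemorySnapshot, memory_profiling

# Baseline GPU memory snapshot; the worker reads .free_memory and .total_memory.
baseline = MemorySnapshot()
print(f"free: {baseline.free_memory / GiB_bytes:.2f} GiB, "
      f"total: {baseline.total_memory / GiB_bytes:.2f} GiB")

# Context manager wrapped around the dummy profile run; the worker reads
# .after_profile.free_memory and .non_kv_cache_memory from the yielded result.
with memory_profiling(baseline, weights_memory=0) as profile_result:
    pass  # the PR runs self.model_runner.profile_run() in here

print(f"free after profiling: "
      f"{profile_result.after_profile.free_memory / GiB_bytes:.2f} GiB")
print(f"non-KV-cache memory: "
      f"{profile_result.non_kv_cache_memory / GiB_bytes:.2f} GiB")
```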
@@ -130,20 +130,22 @@ def init_device(self):
             _check_if_gpu_supports_dtype(self.model_config.dtype)
             gc.collect()
             torch.cuda.empty_cache()
-            self.init_gpu_memory, total_gpu_memory = torch.cuda.mem_get_info()
-            requested_memory = (total_gpu_memory *
-                                self.cache_config.gpu_memory_utilization)
-            if self.init_gpu_memory < requested_memory:
+            # take current memory snapshot
+            self.init_snapshot = MemorySnapshot()
+            self.requested_memory = (self.init_snapshot.total_memory *
+                                     self.cache_config.gpu_memory_utilization)
+            if self.init_snapshot.free_memory < self.requested_memory:
                 GiB = lambda b: round(b / GiB_bytes, 2)
                 raise ValueError(
-                    f"Free memory on device ({GiB(self.init_gpu_memory)}/"
-                    f"{GiB(total_gpu_memory)} GiB) on startup is less than "
-                    f"desired GPU memory utilization "
+                    f"Free memory on device "
+                    f"({GiB(self.init_snapshot.free_memory)}/"
+                    f"{GiB(self.init_snapshot.total_memory)} GiB) on startup "
+                    f"is less than desired GPU memory utilization "
                     f"({self.cache_config.gpu_memory_utilization}, "
-                    f"{GiB(requested_memory)} GiB). Decrease GPU memory "
+                    f"{GiB(self.requested_memory)} GiB). Decrease GPU memory "
                     f"utilization or reduce GPU memory used by other processes."
                 )

         else:
             raise RuntimeError(
                 f"Not support device type: {self.device_config.device}")
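For a quick sanity check of the new startup guard, here is a small worked example with made-up numbers (vLLM's default gpu_memory_utilization is 0.9; the device size and usage below are illustrative):

```python
GiB = 1 << 30  # same value as vllm.utils.GiB_bytes

# Illustrative: an 80 GiB device where another process already holds ~6 GiB.
total_memory = 80 * GiB
free_memory = 74 * GiB
gpu_memory_utilization = 0.9

# Same arithmetic as the new init_device() check, using the MemorySnapshot
# values instead of a raw torch.cuda.mem_get_info() call.
requested_memory = total_memory * gpu_memory_utilization  # 72 GiB budget
if free_memory < requested_memory:
    raise ValueError("Free memory on device is less than the requested budget")

# 74 GiB >= 72 GiB, so startup proceeds; with utilization 0.95 (76 GiB) it would fail.
```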
@@ -192,57 +194,39 @@ def determine_available_memory(self) -> int:
         """
         torch.cuda.empty_cache()
         torch.cuda.reset_peak_memory_stats()
+        GiB = lambda b: b / GiB_bytes

-        _, total_gpu_memory = torch.cuda.mem_get_info()
         # Execute a forward pass with dummy inputs to profile the memory usage
         # of the model.
-        self.model_runner.profile_run()
+        with memory_profiling(
+                self.init_snapshot,
+                weights_memory=int(
+                    self.model_runner.model_memory_usage)) as profile_result:
+            self.model_runner.profile_run()

-        free_gpu_memory, _ = torch.cuda.mem_get_info()
+        free_gpu_memory = profile_result.after_profile.free_memory
         # NOTE(woosuk): Here we assume that the other processes using the same
         # GPU did not change their memory usage during the profiling.
-        assert self.init_gpu_memory > free_gpu_memory, (
+        assert self.init_snapshot.free_memory > free_gpu_memory, (
             "Error in memory profiling. "
-            f"Initial free memory {self.init_gpu_memory/GiB_bytes} GiB, "
-            f"current free memory {free_gpu_memory/GiB_bytes} GiB. "
-            f"This happens when the GPU memory was not properly cleaned up "
-            f"before initializing the vLLM instance.")
-
-        # Get the peak memory allocation recorded by torch
-        peak_torch_memory = torch.cuda.memory_stats(
-        )["allocated_bytes.all.peak"]
-
-        # Check for any memory left around that may have been allocated on the
-        # gpu outside of `torch`. NCCL operations, for example, can use a few
-        # GB during a forward pass.
-        torch.cuda.empty_cache()
-        torch_allocated_bytes = torch.cuda.memory_stats(
-        )["allocated_bytes.all.current"]
-
-        # Reset after emptying torch cache
-        free_gpu_memory = torch.cuda.mem_get_info()[0]
+            f"Initial free memory {GiB(self.init_snapshot.free_memory)} GiB, "
+            f"current free memory {GiB(free_gpu_memory)} GiB. "
+            "This happens when other processes sharing the same container "
+            "release GPU memory while vLLM is profiling during initialization. "
+            "To fix this, ensure consistent GPU memory allocation or "
+            "isolate vLLM in its own container.")
+        available_kv_cache_memory = self.requested_memory \
+            - profile_result.non_kv_cache_memory

-        # Total forward allocation (current) is equal to the diff in free memory
-        fwd_alloc_bytes = self.init_gpu_memory - free_gpu_memory
-        # We assume current non-torch allocation is equal to peak
-        non_torch_alloc_bytes = max(0, fwd_alloc_bytes - torch_allocated_bytes)
-        # Total forward allocation (peak) is peak torch + non-torch
-        peak_memory = peak_torch_memory + non_torch_alloc_bytes
-
-        available_kv_cache_memory = (
-            total_gpu_memory * self.cache_config.gpu_memory_utilization -
-            peak_memory)
-
-        GiB = lambda b: b / GiB_bytes
         logger.debug(
             "Initial free memory: %.2f GiB, free memory: %.2f GiB, "
-            "total GPU memory: %.2f GiB", GiB(self.init_gpu_memory),
-            GiB(free_gpu_memory), GiB(total_gpu_memory))
-        logger.debug(
-            "Peak torch memory: %.2f GiB, non-torch forward-pass memory: "
-            "%.2f GiB, available KVCache memory: %.2f GiB",
-            GiB(peak_torch_memory), GiB(non_torch_alloc_bytes),
+            "requested GPU memory: %.2f GiB",
+            GiB(self.init_snapshot.free_memory), GiB(free_gpu_memory),
+            GiB(self.requested_memory))
+        logger.debug(profile_result)
+        logger.info("Available KV cache memory: %.2f GiB",
             GiB(available_kv_cache_memory))
+        gc.collect()

         return int(available_kv_cache_memory)
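To make the change in the KV-cache budget concrete, here is a rough before/after with made-up numbers. Everything below is illustrative: the real values come from the CUDA memory info and the profiling result, and the assumption that non_kv_cache_memory covers the same components as the old peak-based measurement comes from the V0 utility rather than from this diff.

```python
GiB = 1 << 30  # same value as vllm.utils.GiB_bytes

# Illustrative numbers for a single worker on an 80 GiB device.
total_memory = 80 * GiB
gpu_memory_utilization = 0.9
requested_memory = total_memory * gpu_memory_utilization      # 72 GiB budget

# Old accounting (removed code): measure the forward-pass peak directly.
peak_torch_memory = 20 * GiB      # allocated_bytes.all.peak (weights + activations)
non_torch_alloc_bytes = 2 * GiB   # e.g. NCCL buffers outside the torch allocator
old_available = requested_memory - (peak_torch_memory + non_torch_alloc_bytes)

# New accounting (this PR): memory_profiling() reports a single
# non_kv_cache_memory figure relative to the init-time snapshot.
non_kv_cache_memory = 22 * GiB    # assumed to cover the same components
new_available = requested_memory - non_kv_cache_memory

assert old_available == new_available == 50 * GiB
```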
Review comments

On logger.info("Available KV cache memory: ..."):
@ywang96 kept one info-level log here, which I think is useful for users.

On gc.collect():
It is some proactive final cleanup after the profile run. It probably doesn't matter much, but it is there in case we left some objects around; the gc performance overhead here should not matter either.

On the debug-level memory log:
Should we make this log info as well? It would be very useful.

Reply: I guess this message could be confusing to generic end users, so I'm keeping it at debug for now; developers can turn it on through VLLM_LOGGING_LEVEL=DEBUG. It's always easier to flip it later if too many issues complain about this part.

Reply: Generally speaking, we now try to reduce the amount of startup server logs as much as possible so that they are less confusing to end users, and IMO it makes sense to keep this kind of information at the DEBUG level.
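Given that the breakdown stays at DEBUG level, one way to surface it when debugging is to raise the log level before vLLM is imported. A sketch; the model name is only an example, and it assumes VLLM_LOGGING_LEVEL is picked up when vLLM's logging is configured at import time.

```python
import os

# Must be set before importing vllm so the logging config picks it up.
os.environ["VLLM_LOGGING_LEVEL"] = "DEBUG"

from vllm import LLM

# Engine startup will now emit the debug lines from determine_available_memory(),
# including the profile_result breakdown, alongside the info-level
# "Available KV cache memory" line.
llm = LLM(model="facebook/opt-125m", gpu_memory_utilization=0.9)
```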