Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Log more GPU memory reservation info #4576

Open
wants to merge 4 commits into
base: main
Choose a base branch
from

Conversation

rkooo567
Copy link
Collaborator

@rkooo567 rkooo567 commented May 3, 2024

Currently we logs the GPU memory usage from model loading, but not for kv cache or cuda graph capture. This PR adds exact GPU memory reservation info.

It also improves error messages when there's not enough gpu for kv caches.


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@rkooo567
Copy link
Collaborator Author

rkooo567 commented May 3, 2024

cc @mgoin (for measuring gpu memory)

torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
if self.capture_max_memory:
torch.cuda.reset_peak_memory_stats(self.device)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This only measures the GPU memory used by "tensor", so it is inaccurate for cuda graph (which uses memory outside tensor). I kept the original way because i thought it could be useful. but I am open to just remove code

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the change made in this PR is incorrect because torch.cuda.mem_get_info does not consider the "free" GPU memory managed by PyTorch caching allocator.

Once the GPU memory is allocated (via cuda-malloc) by PyTorch caching allocator, the memory is never cuda-freed unless the user enforces it (e.g., by empty_cache). While this memory can be regarded free by PyTorch allocator because it's not used for any tensor, it's not regarded free from the external point of view (e.g., in nvidia-smi or torch.cuda.mem_get_info). Therefore, the profiler cannot capture the GPU memory usage inside the PyTorch allocator, and thus may under-estimate the memory usage.

Copy link
Collaborator Author

@rkooo567 rkooo567 May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah interesting... so it is like cached memory in fs.

one issue was that using the original method couldn't take into account of memory used by non-torch (so it inaccurately reported cuda graph mem usage). Let me see if I can find a way to take into account of buffer.

Copy link
Collaborator Author

@rkooo567 rkooo567 May 8, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@WoosukKwon

Currently thinking about 2 different options;

  • Option 1: Keep the current memory context manager as it is. For cuda graph, we measure the memory using snapshot (using mem_get_info). for kv cache and model, it should be fine measuring the memory using the status quo (a.k.a., memory_allocated because gpu memory is used just for tensors in these cases)
  • Option 2: Modify the memory context manager to use
free_mem_from_torch = torch.cuda.memory_reserved (total mem used by allocator) - torch.cuda.memory_allocated (total mem used by tensor). 
free, total = torch.cuda.mem_get_info(self.device)
mem_usage = total - (free + free_mem_from_torch)

This should include releasable memory from torch cache allocator to the free memory (based on https://pytorch.org/docs/stable/notes/cuda.html#memory-management)

Any thoughts?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For cuda graph, we measure the memory using snapshot (using mem_get_info).

What do you exactly mean by this?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So keep other part as it is. And for cuda graph capture part, we measure memoroy by

free, total = torch.cuda.mem_get_info(self.device)
used = total - free
cuda_capture()
free, total = torch.cuda.mem_get_info(self.device)
used_after = total - free
memory_used = used_after - used

Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@rkooo567 Thanks for the PR! This information would be useful. Please check out my comments.

Comment on lines +351 to +356
raise ValueError(
"No available memory for the cache blocks. vLLM needs {} more GPU "
"blocks to allocate. Try increasing `gpu_memory_utilization` when "
"initializing the engine. Or increase `tensor_parallel_size`, which"
"shards model weights across GPUs. It gives more memory to "
"allocate kv cache blocks per GPU.".format(-num_gpu_blocks))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the negative num_gpu_blocks does not give any useful information; it just means the memory profiling was inaccurate for some reason.

torch.cuda.reset_peak_memory_stats(self.device)
mem = torch.cuda.max_memory_allocated(self.device)
if self.capture_max_memory:
torch.cuda.reset_peak_memory_stats(self.device)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IIUC, the change made in this PR is incorrect because torch.cuda.mem_get_info does not consider the "free" GPU memory managed by PyTorch caching allocator.

Once the GPU memory is allocated (via cuda-malloc) by PyTorch caching allocator, the memory is never cuda-freed unless the user enforces it (e.g., by empty_cache). While this memory can be regarded free by PyTorch allocator because it's not used for any tensor, it's not regarded free from the external point of view (e.g., in nvidia-smi or torch.cuda.mem_get_info). Therefore, the profiler cannot capture the GPU memory usage inside the PyTorch allocator, and thus may under-estimate the memory usage.

self.model_runner.capture_model(self.gpu_cache)
mem_usage = m.consumed_memory
unit, scale = "GB", float(2**30)
logger.info("Capturing cuda graph reserves %.4f %s GPU memory.",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we use debug instead?

Copy link
Collaborator Author

@rkooo567 rkooo567 May 7, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

feel like it'd be useful to show exact memory allocation by default since I've seen some users not understanding this part (like here; https://www.reddit.com/r/LocalLLaMA/comments/1bz3bn1/whats_up_with_vllm/)

Copy link

This pull request has been automatically marked as stale because it has not had any activity within 90 days. It will be automatically closed if no further activity occurs within 30 days. Leave a comment if you feel this pull request should remain open. Thank you!

@github-actions github-actions bot added the stale label Oct 28, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants