[core] cudagraph output with tensor weak reference #9724
Conversation
👋 Hi! Thank you for contributing to the vLLM project. Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can do one of these:
LGTM! Very clever implementation!
@youkaichao why do only 1, 2, 3 change to 6 and not 0, 1, 2, 3?
@alexm-neuralmagic because:
the 3 GiB graph will modify the buffer of the 2 GiB graph and also the 1 GiB graph, but not the 4 GiB buffer.
To correctly use cudagraph, we have to keep the input and output buffers alive for as long as the graph may be replayed, copy new data into the input buffers before each replay, and read results out of the output buffers afterwards, because a captured graph always reads from and writes to the same fixed memory addresses.
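A minimal sketch of that pattern (the buffer names and shapes are purely illustrative, not taken from this PR):

```python
import torch

# Static buffers: a captured graph always reads from and writes to the same addresses.
static_input = torch.zeros(8, device="cuda")

g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_output = static_input + 1   # recorded into the graph, not run eagerly

# To run the graph on new data: copy into the fixed input buffer, replay,
# then read the result out of the fixed output buffer.
static_input.copy_(torch.arange(8, dtype=torch.float32, device="cuda"))
g.replay()
result = static_output.clone()
print(result)
```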
Say we have a simple graph that does a series of operations, and we want to use cudagraph to record the operations at different sizes:
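The example script itself is not reproduced here; the following is a rough reconstruction of the setup it describes. The sizes, the capture order, and the weak_ref_tensor placeholder are assumptions chosen to line up with the discussion above, not code copied from the PR.

```python
import torch

use_weak_ref = False  # flip to True to keep only weak references to the outputs

# Placeholder for the helper introduced by this PR: it returns a tensor that
# aliases the output's memory without owning it (the real name and
# implementation live in the PR, not reproduced here).
def weak_ref_tensor(t: torch.Tensor) -> torch.Tensor:
    raise NotImplementedError

GiB = 1024 ** 3
torch.zeros(1, device="cuda")            # initialize the CUDA context before capture
pool = torch.cuda.graph_pool_handle()    # one memory pool shared by all graphs
graphs, outputs = [], []

# Capture from the largest output to the smallest (4, 3, 2, 1 GiB), so that
# graph 0 owns the 4 GiB output and graph 1 the 3 GiB output, matching the
# indices discussed in the comments above.
for size_gib in (4, 3, 2, 1):
    g = torch.cuda.CUDAGraph()
    with torch.cuda.graph(g, pool=pool):
        # stand-in for "a series of operations": produce a size_gib GiB output
        out = torch.full((size_gib * GiB,), size_gib, dtype=torch.uint8, device="cuda")
    graphs.append(g)
    outputs.append(weak_ref_tensor(out) if use_weak_ref else out)
    del out

print(f"allocated after capture: {torch.cuda.memory_allocated() / GiB:.1f} GiB")

graphs[1].replay()                       # replay only graph 1 (the 3 GiB graph)
for i, o in enumerate(outputs):
    print(f"graph {i}: first element = {o[0].item()}")
```

With strong references (use_weak_ref = False), all four output buffers stay reserved after capture; the point of the PR is that weak references let the shared pool recycle them.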
The output is:
To understand the output:
Since we know we will only execute one graph at a time, can we somehow make the outputs of the 4 cudagraphs share the same buffer as well?
This is not possible in normal PyTorch, because we have to hold the output buffer for every output, and as long as we hold a reference to a tensor produced by a cudagraph, that part of memory stays reserved and will not be recycled.
With this PR, we introduce a tensor weak reference, so that we can take a weak reference to the output tensor. This way, cudagraph can also recycle and reuse the output buffers.
When we change use_weak_ref = True in the code, we will get:

Note that the 4 graphs use 4 GiB of memory in total.
Although we only replay graph 1, we can see that the outputs of graphs 2 and 3 also change. This is because they share the same buffer. It does not matter here, because we never execute two cudagraphs concurrently.
To understand the output:
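The original bullet-point explanation is not included here. As a rough way to make the buffer sharing concrete, here is a sketch (assuming the outputs list from the script above, captured with use_weak_ref = True) that prints each output's address range so overlaps are visible:

```python
GiB = 1024 ** 3
for i, o in enumerate(outputs):
    start = o.data_ptr()
    end = start + o.numel() * o.element_size()
    print(f"graph {i}: addresses [{start:#x}, {end:#x}), {o.numel() / GiB:.0f} GiB")
# Per the discussion above, the 3 GiB output overlapped the 1 GiB and 2 GiB
# outputs but not the 4 GiB one, which is why replaying graph 1 changed
# outputs 1, 2 and 3 while output 0 stayed untouched.
```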