
[core] cudagraph output with tensor weak reference #9724

Merged · 4 commits merged into vllm-project:main on Oct 27, 2024

Conversation

@youkaichao (Member) commented Oct 26, 2024

To use cudagraph correctly, we have to (a minimal single-graph sketch follows this list):

  1. capture the graph
     • prepare the input buffer
     • run the graph once under capture (buffers created and destroyed inside the cudagraph are recorded and will be replayed later)
     • save the output buffer
  2. replay the graph
     • copy new data into the input buffer
     • replay the graph
     • read the result from the output buffer

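Here is that minimal sketch (toy shapes; assumes any CUDA-enabled PyTorch build):

import torch

# 1. capture: prepare the input buffer, run once under capture,
#    and save the output buffer
static_in = torch.zeros(8, device="cuda")
g = torch.cuda.CUDAGraph()
with torch.cuda.graph(g):
    static_out = static_in + 1  # allocated from the graph's memory pool

# 2. replay: copy new data into the input buffer, replay,
#    then read the result from the output buffer
static_in.copy_(torch.arange(8, device="cuda", dtype=torch.float32))
g.replay()
print(static_out)  # tensor([1., 2., ..., 8.], device='cuda:0')
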
Say we have a simple graph that does a series of operations, and we want to use cudagraph to record those operations at different sizes:

import torch

GiB = 1024 * 1024 * 1024 // 4  # number of float32 elements (4 bytes each) in 1 GiB

# 4 GiB input buffer
input_buffer = torch.randn((4, GiB), device="cuda")

# output buffer sizes:
# 4 GiB, 3 GiB, 2 GiB, 1 GiB
sizes = [4, 3, 2, 1]
graphs = []
outputs = []

def report_memory(prefix):
    free, total = torch.cuda.mem_get_info()
    used = total - free
    print(f"{prefix}: Used: {used / 1024 / 1024} MB, Free: {free / 1024 / 1024} MB, Total: {total / 1024 / 1024} MB")

report_memory("Before capture")

from vllm.utils import weak_ref_tensor

use_weak_ref = False

pool = torch.cuda.graph_pool_handle()
for size in sizes:
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, pool=pool):
        x = input_buffer[:size]
        out = x.clone()
        out += 1
        out += 2
        out += 3
        if use_weak_ref:
            out = weak_ref_tensor(out)
        outputs.append(out)
        del x, out
    graphs.append(graph)

report_memory("After capture")

input_buffer.zero_()
graphs[1].replay()  # outputs[1] becomes 0 + 1 + 2 + 3 = 6
for i in range(4):
    print(f"{i=}, {outputs[i][-1][0]=}")

The output is:

Before capture: Used: 4627.0625 MB, Free: 76489.625 MB, Total: 81116.6875 MB
After capture: Used: 15029.0625 MB, Free: 66087.625 MB, Total: 81116.6875 MB
i=0, outputs[i][-1][0]=tensor(0., device='cuda:0')
i=1, outputs[i][-1][0]=tensor(6., device='cuda:0')
i=2, outputs[i][-1][0]=tensor(0., device='cuda:0')
i=3, outputs[i][-1][0]=tensor(0., device='cuda:0')

To understand the output:

Step 1: Capture the Graph
---------------------------
1. Prepare Input Buffer (4 GiB)

       ┌─────────────────────────────────────────────────────────┐
Input: │                  input_buffer (4 GiB)                   │
       └─────────────────────────────────────────────────────────┘

2. Record Graph Operations
   └─ The graph captures a series of operations on input slices of sizes: 4 GiB, 3 GiB, 2 GiB, 1 GiB
       ┌───────────┐   ┌───────────┐   ┌───────────┐   ┌───────────┐
       │ Size 4GiB │   │ Size 3GiB │   │ Size 2GiB │   │ Size 1GiB │
       └───────────┘   └───────────┘   └───────────┘   └───────────┘
   └─ For each size, the graph allocates an output buffer and records the +1, +2, +3 operations.

3. Save Output Buffers
   └─ Each cudagraph's output is saved in its own separate buffer.

Memory Report After Capture:
   Used: 15029.0625 MB, Free: 66087.625 MB, Total: 81116.6875 MB
   └─ Usage grew by 15029 − 4627 ≈ 10402 MB ≈ 10 GiB, matching the (4 + 3 + 2 + 1) GiB of separate output buffers.

Step 2: Replay the Graph
--------------------------
1. Copy data into the Input Buffer
   └─ Before replay, fresh data is written into the input buffer (here: input_buffer.zero_()).

2. Replay Graph Operations
   └─ Only `graphs[1]` is replayed
       └─ The graph performs +1, +2, and +3 on its input slice, so the output value is 0 + 1 + 2 + 3 = 6.

3. Read data from Output Buffer
   └─ Result is saved and displayed for each graph output:
       ┌────────────────┐
       │ outputs[0] = 0 │   # Slice size 4 GiB: replay not executed
       └────────────────┘
       ┌────────────────┐
       │ outputs[1] = 6 │   # Slice size 3 GiB: replay executed
       └────────────────┘
       ┌────────────────┐
       │ outputs[2] = 0 │   # Slice size 2 GiB: replay not executed
       └────────────────┘
       ┌────────────────┐
       │ outputs[3] = 0 │   # Slice size 1 GiB: replay not executed
       └────────────────┘

Since we know we will only execute one graph at a time, can we somehow make the outputs of the four cudagraphs share the same buffer as well?

This is not possible in plain PyTorch: to read the results after replay, we must hold a reference to every output tensor, and as long as we hold a strong reference to a tensor from the cudagraph pool, that memory stays reserved and is never recycled.
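
A small experiment that illustrates the reservation (a sketch; the final pointer comparison assumes the caching allocator reuses the freed block, which is its usual behavior):

import torch

pool = torch.cuda.graph_pool_handle()

g1 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g1, pool=pool):
    out1 = torch.ones(1024, device="cuda")
ptr1 = out1.data_ptr()
del out1  # drop the only strong reference: the pool block becomes reusable

g2 = torch.cuda.CUDAGraph()
with torch.cuda.graph(g2, pool=pool):
    out2 = torch.ones(1024, device="cuda")

# the second capture reuses the block the first one released; keeping
# out1 alive instead would have forced the pool to allocate a new block
print(out2.data_ptr() == ptr1)  # expected: True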

With this PR, we introduce a tensor weak reference, so that we can keep a weak reference to each output tensor instead of a strong one. This way, cudagraph can recycle and reuse the output buffer while we retain a handle for reading the results.
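
For intuition, here is a hypothetical check of the semantics of weak_ref_tensor (introduced by this PR): the returned tensor aliases the original's memory without owning it.

import torch
from vllm.utils import weak_ref_tensor

t = torch.ones(1024, device="cuda")
w = weak_ref_tensor(t)  # same data pointer, but no ownership of the storage
assert w.data_ptr() == t.data_ptr()

del t  # the allocator may now recycle the storage
# reading w afterwards is only safe while the memory has not been reused;
# with cudagraphs this works out, because every replay rewrites exactly
# the same addresses before we read them back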

When we set use_weak_ref = True in the code above, we get:

Before capture: Used: 4627.0625 MB, Free: 76489.625 MB, Total: 81116.6875 MB
After capture: Used: 8885.0625 MB, Free: 72231.625 MB, Total: 81116.6875 MB
i=0, outputs[i][-1][0]=tensor(0., device='cuda:0')
i=1, outputs[i][-1][0]=tensor(6., device='cuda:0')
i=2, outputs[i][-1][0]=tensor(6., device='cuda:0')
i=3, outputs[i][-1][0]=tensor(6., device='cuda:0')

Note that the four graphs now use only about 4 GiB of output memory in total (8885 − 4627 ≈ 4258 MB), the size of the largest output, instead of 4 + 3 + 2 + 1 = 10 GiB.

Although we only replay graph 1, the outputs of graphs 2 and 3 change as well, because they share the same buffer. This does not matter here, because we never execute two cudagraphs concurrently; the sketch after the diagrams below shows the one contract this imposes.

To understand the output:

1. Prepare Input Buffer (4 GiB)

       ┌─────────────────────────────────────────────────────────┐
Input: │                  input_buffer (4 GiB)                   │
       └─────────────────────────────────────────────────────────┘

2. Shared Output Buffer

          ┌─────────────────────────────────────────────────────────┐
          │                      Shared Output Buffer               │
          │                (Used sequentially by each graph)        │
          └─────────────────────────────────────────────────────────┘

   └─ All four outputs start at the same base address and overlap:

          ┌─────────────────────────────────────────────────────────┐
          │                    Output for Size 4 GiB                │
          ├───────────────────────────────────────────────┬─────────┘
          │                 Output for Size 3 GiB         │
          ├───────────────────────────────┬───────────────┘
          │         Output for Size 2 GiB │
          ├───────────────┬───────────────┘
          │ 1 GiB output  │
          └───────────────┘

After replaying graphs[1]:

┌────────────────┐
│ outputs[0] = 0 │   # Slice size 4 GiB: replay not executed
└────────────────┘
┌────────────────┐
│ outputs[1] = 6 │   # Slice size 3 GiB: replay executed
└────────────────┘
┌────────────────┐
│ outputs[2] = 6 │   # Slice size 2 GiB: updated due to the shared buffer
└────────────────┘
┌────────────────┐
│ outputs[3] = 6 │   # Slice size 1 GiB: updated due to the shared buffer
└────────────────┘
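
The one contract this sharing imposes: read (or copy out) a graph's output before replaying any other graph. A hypothetical continuation of the script above:

graphs[2].replay()
result = outputs[2].clone()  # copy out before another replay reuses the buffer
graphs[0].replay()           # overwrites the storage that outputs[2] aliases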

@youkaichao changed the title from [draft] weakref tensors to [core] cudagraph output with tensor weak reference on Oct 26, 2024
@youkaichao added the ready label (ONLY add when PR is ready to merge/full CI is needed) on Oct 27, 2024
@WoosukKwon (Collaborator) left a comment:
LGTM! Very clever implementation!

@youkaichao merged commit 8549c82 into vllm-project:main on Oct 27, 2024
91 checks passed
@youkaichao deleted the weakref_tensor branch on October 27, 2024 at 07:19
@alexm-neuralmagic (Collaborator):

@youkaichao why do only outputs 1, 2, 3 change to 6, and not 0, 1, 2, 3?

@youkaichao (Member, Author):

@alexm-neuralmagic because, as noted above, "Slice size 4 GiB: replay not executed."

The 3 GiB graph overwrites the buffers of the 2 GiB and 1 GiB graphs (they lie entirely within its first 3 GiB), but not the tail of the 4 GiB buffer, which is where the printed element of outputs[0] lives.
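
A quick hypothetical check (run after the use_weak_ref = True version of the script) makes this concrete:

# all four outputs alias the same base address (see the diagram above)
base = outputs[0].data_ptr()
print(all(outputs[i].data_ptr() == base for i in range(4)))  # True

# the element we print for outputs[0] sits exactly 3 GiB past the base,
# just beyond the region that the 3 GiB graph writes
offset = outputs[0][-1].data_ptr() - base
print(offset == 3 * 1024**3)  # True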
