Reduce memory overhead of CUDA graphs #205

Closed
pommedeterresautee opened this issue Dec 12, 2022 · 1 comment · Fixed by #225

@pommedeterresautee (Member)

The current CUDA graph wrapper creates one static input and one static output tensor per wrapped call to the model.
In the decoder, that can add up to a large number of tensors; we may want to limit those allocations and recycle the buffers instead.
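
For reference, a minimal sketch (an assumption, simplified from the torchdynamo-style wrapper) of the current pattern, where each wrapped model allocates its own static tensors:

```python
import torch

# Minimal sketch (assumption: simplified from the torchdynamo-style wrapper)
# of the current pattern: every wrapped model gets its own static input
# copies, so a decoder calling many wrapped graphs accumulates many tensors.
def naive_cuda_graphs_wrapper(model, inputs):
    static_inputs = [torch.zeros_like(x) for x in inputs]  # fresh allocations per wrapper

    # warmup on a side stream before capture
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        model(*inputs)
    torch.cuda.current_stream().wait_stream(stream)

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_outputs = model(*static_inputs)  # outputs are also held for the graph's lifetime

    def run(*new_inputs):
        for dst, src in zip(static_inputs, new_inputs):
            dst.copy_(src)  # a graph replays reads from fixed addresses
        graph.replay()
        return static_outputs

    return run
```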

@pommedeterresautee added the performance (make things faster, always) label on Dec 12, 2022
@pommedeterresautee self-assigned this on Dec 12, 2022
@pommedeterresautee (Member Author)

Some code with some results:

```python
#  Copyright 2022 Lefebvre Sarrut
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#

from typing import Callable, Union

import torch

# Shared across all wrapped models: static input buffers are allocated once
# and recycled between CUDA graphs instead of being re-created per wrapper.
default_size = 10 * 8 * 24 * 64  # max number of elements a pooled buffer can hold
tensor_pool = list()  # one reusable static input buffer per input position
reused_inputs = set()  # data_ptr()s of tensors already backed by the pool

def cuda_graphs_wrapper(
    model: Callable,
    inputs: Union[list[torch.Tensor], tuple[torch.Tensor]],
    copy_outputs: bool = False,
    pool: tuple[int, int] = torch.cuda.graph_pool_handle(),
):
    """
    From torchdynamo
    """
    assert isinstance(inputs, (list, tuple)), f"inputs is of type {type(inputs)} instead of list"
    # grow the pool lazily so there is one reusable buffer per input position,
    # and note which inputs already live in the pool (those need no copy)
    is_tensor_reused = list()
    for index, input_tensor in enumerate(inputs):
        is_tensor_reused.append(input_tensor.data_ptr() in reused_inputs)
        if index == len(tensor_pool):
            t = torch.zeros((default_size,), dtype=input_tensor.dtype, device="cuda")
            tensor_pool.append(t)
        else:
            t = tensor_pool[index]
        assert len(t.storage()) >= input_tensor.numel(), (
            f"Tensors in pool are too small, please increase their size: "
            f"{len(t.storage())} < {input_tensor.numel()}"
        )
        t.resize_(input_tensor.shape)  # metadata-only resize, storage is kept

    # required warmup, not just for perf but for correctness
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        # 2 rounds, 1 to build the model (triton kernels, casting, etc.),
        # and 1 for warmup
        for _ in range(2):
            model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # remember the addresses of these inputs: a later wrapped model receiving
    # tensors at the same addresses can skip copying them before replay
    reused_inputs.clear()
    for i in inputs:
        reused_inputs.add(i.data_ptr())

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, pool=pool):
        static_inputs = tensor_pool[: len(inputs)]  # recycled buffers, not fresh per-model allocations
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*cg_inputs):
        assert len(is_tensor_reused) == len(cg_inputs), f"{len(is_tensor_reused)} != {len(cg_inputs)}"
        # a CUDA graph always reads its inputs from the same addresses,
        # so new data must be copied into the pooled static buffers
        for src, dst, src_already_set in zip(cg_inputs, tensor_pool, is_tensor_reused):
            if not src_already_set:  # some tensors are reused from call to call, so we don't need to copy them
                dst.resize_(src.shape)
                dst.copy_(src)

        graph.replay()
        if copy_outputs:
            # clone to detach from the static output buffers, which the next replay overwrites
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
```

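For context, a hypothetical usage sketch (the `lin` module and shapes are illustrative, not from the notebook):

```python
import torch

# Hypothetical usage (assumed model and shapes): the wrapped callable copies
# new data into the pooled static buffers, replays the graph, and returns the
# graph's static outputs.
lin = torch.nn.Linear(64, 64).cuda().eval()

x = torch.randn(8, 64, device="cuda")
run_lin = cuda_graphs_wrapper(lin, [x])

y = torch.randn(8, 64, device="cuda")
out = run_lin(y)     # y is copied into the recycled static input, then the graph replays
print(out[0].shape)  # torch.Size([8, 64])
```
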
Results on the T5 notebook:

| num_beams | optimization | torch.cuda.memory_allocated | torch.cuda.memory_reserved | torch.cuda.max_memory_reserved |
|-----------|--------------|-----------------------------|----------------------------|--------------------------------|
| 1         | with         | 0.186218 GB                 | 0.595703 GB                | 1.277344 GB                    |
| 1         | without      | 0.193425 GB                 | 0.601562 GB                | 1.277344 GB                    |
| 5         | with         | 0.301963 GB                 | 0.714844 GB                | 1.294922 GB                    |
| 5         | without      | 0.381844 GB                 | 0.796875 GB                | 1.294922 GB                    |
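
Presumably these figures were collected with something like the following (an assumption; the measurement cell itself is not shown here):

```python
import torch

def report_memory() -> None:
    # Assumed helper matching the metric names above; not from the notebook.
    gib = 1024 ** 3
    print(f"torch.cuda.memory_allocated: {torch.cuda.memory_allocated() / gib:.6f}GB")
    print(f"torch.cuda.memory_reserved: {torch.cuda.memory_reserved() / gib:.6f}GB")
    print(f"torch.cuda.max_memory_reserved: {torch.cuda.max_memory_reserved() / gib:.6f}GB")
```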
