Reduce memory overhead of CUDA graphs #205
Labels: performance (make things faster, always)
Comments
Some code with some results:

```python
# Copyright 2022 Lefebvre Sarrut
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from typing import Callable, Union

import torch

# module-level pool of pre-allocated static input buffers, shared across captures
default_size = 10 * 8 * 24 * 64
tensor_pool = list()
reused_inputs = set()


def cuda_graphs_wrapper(
    model: Callable,
    inputs: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
    copy_outputs: bool = False,
    pool: tuple[int, int] = torch.cuda.graph_pool_handle(),
):
    """
    From torchdynamo
    """
    assert isinstance(inputs, (list, tuple)), f"inputs is of type {type(inputs)} instead of list"
    is_tensor_reused = list()
    for index, input_tensor in enumerate(inputs):
        is_tensor_reused.append(input_tensor.data_ptr() in reused_inputs)
        if index == len(tensor_pool):
            # grow the shared pool with a new pre-allocated buffer
            t = torch.zeros((default_size,), dtype=input_tensor.dtype, device="cuda")
            tensor_pool.append(t)
        else:
            # recycle an existing buffer from the pool
            t = tensor_pool[index]
            assert len(t.storage()) >= input_tensor.numel(), (
                f"Tensors in pool are too small, please increase it: "
                f"{len(t.storage())} < {input_tensor.numel()}"
            )
        t.resize_(input_tensor.shape)

    # required warmup, not just for perf but for correctness
    torch.cuda.synchronize()
    stream = torch.cuda.Stream()
    stream.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(stream):
        # 2 rounds, 1 to build the model (triton kernels, casting, etc.),
        # and 1 for warmup
        for _ in range(2):
            model(*inputs)
    stream.synchronize()
    torch.cuda.current_stream().wait_stream(stream)
    torch.cuda.synchronize()

    # remember which input addresses the captured graph will read from
    reused_inputs.clear()
    for i in inputs:
        reused_inputs.add(i.data_ptr())

    # record
    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph, stream=stream, pool=pool):
        static_inputs = tensor_pool[: len(inputs)]
        static_outputs = model(*static_inputs)
    if not isinstance(static_outputs, (list, tuple)):
        static_outputs = (static_outputs,)

    def run(*cg_inputs):
        assert isinstance(cg_inputs, (list, tuple)), f"inputs is of unexpected type: {type(cg_inputs)}"
        assert len(is_tensor_reused) == len(cg_inputs), f"{len(is_tensor_reused)} != {len(cg_inputs)}"
        # assert len(static_inputs) == len(cg_inputs), f"{len(static_inputs)} != {len(cg_inputs)}"
        # cuda graph can only read data from the same address
        for src, dst, src_already_set in zip(cg_inputs, tensor_pool, is_tensor_reused):
            if not src_already_set:  # some tensors are reused from call to call, so we don't need to copy them
                dst.resize_(src.shape)
                dst.copy_(src)
        graph.replay()
        if copy_outputs:
            return [x.clone() for x in static_outputs]
        else:
            return static_outputs

    return run
```

Results on T5 notebook:
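For context, a minimal usage sketch of the wrapper above; the toy model, shapes, and variable names are illustrative assumptions (not from the repository) and a CUDA device is required:

```python
import torch

# hypothetical toy module standing in for the real decoder call; shapes are illustrative
toy_model = torch.nn.Linear(64, 64).cuda().eval()
example_input = torch.randn(8, 64, device="cuda")

# capture once: warmup runs, then the graph is recorded against the shared tensor_pool
runner = cuda_graphs_wrapper(toy_model, [example_input])

# the wrapper always returns outputs as a tuple; later calls copy new data into the
# pooled static buffer (only when the input address changed) and replay the graph
(first_out,) = runner(example_input)
(second_out,) = runner(torch.randn(8, 64, device="cuda"))
```

Because `tensor_pool` and `reused_inputs` are module-level globals, every capture made through this wrapper shares the same static input buffers instead of allocating its own.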
This was referenced Dec 17, 2022
The current CUDA graph wrapper creates a static input and a static output tensor per call to the model.
In the decoder this can create a lot of tensors; we may want to limit those allocations and try to recycle them.
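For reference, a minimal sketch of the per-call pattern described above (the helper name and structure are hypothetical, not the project's actual wrapper): each captured graph pins its own private static input/output buffers, which is what the shared `tensor_pool` in the comment above tries to avoid.

```python
import torch

def capture_with_private_buffers(model, example_input):
    # hypothetical helper, for illustration: every wrapped call keeps its own
    # static input/output tensors alive, so N wrapped decoder steps pin N sets
    # of CUDA buffers for as long as their graphs exist
    static_input = example_input.clone()

    # warmup on a side stream before capture
    s = torch.cuda.Stream()
    s.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(s):
        model(static_input)
    torch.cuda.current_stream().wait_stream(s)
    torch.cuda.synchronize()

    graph = torch.cuda.CUDAGraph()
    with torch.cuda.graph(graph):
        static_output = model(static_input)

    def run(new_input):
        # copy into this graph's private buffer, then replay
        static_input.copy_(new_input)
        graph.replay()
        return static_output

    return run
```

Recycling buffers from a shared pool, as in the snippet above, lets several captures reuse the same static input storage instead of each holding its own.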