[Performance] Faster clone #1043

Merged 12 commits on Oct 16, 2024
6 changes: 6 additions & 0 deletions benchmarks/common/h2d_test.py
@@ -7,9 +7,12 @@

import pytest
import torch
from packaging import version

from tensordict import TensorDict

TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@pytest.fixture
def td():
@@ -50,6 +53,9 @@ def default_device():


@pytest.mark.parametrize("consolidated", [False, True])
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
)
class TestTo:
def test_to(self, benchmark, consolidated, td, default_device):
if consolidated:
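A note on the gating idiom these benchmark files now share: raw version strings compare lexicographically (so "2.10" sorts before "2.4"), and torch dev builds carry suffixes like `2.5.0a0+git...` that break naive comparisons. A minimal sketch of the idiom, assuming only that `packaging` is installed; the `requires_torch` helper is hypothetical:

```python
# Sketch of the version gate used in these benchmarks (assumes `packaging`).
import torch
from packaging import version

# base_version strips pre-release/local parts, e.g. "2.5.0a0+gitabc" -> "2.5.0"
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

def requires_torch(min_version: str) -> bool:
    # Hypothetical helper: parsed Version objects compare numerically,
    # unlike the raw-string comparison this PR replaces below.
    return TORCH_VERSION >= version.parse(min_version)
```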
35 changes: 26 additions & 9 deletions benchmarks/compile/compile_td_test.py
@@ -6,10 +6,11 @@

import pytest
import torch
from packaging import version
from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
from torch.utils._pytree import tree_map

TORCH_VERSION = torch.__version__
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)


@tensorclass
@@ -106,7 +107,9 @@ def get_flat_tc():


# Tests runtime of a simple arithmetic op over a highly nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_add_one_nested(mode, dict_type, benchmark):
@@ -128,7 +131,9 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):


# Tests the speed of copying a nested tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_nested(mode, dict_type, benchmark):
@@ -150,7 +155,9 @@ def test_compile_copy_nested(mode, dict_type, benchmark):


# Tests runtime of a simple arithmetic op over a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_one_flat(mode, dict_type, benchmark):
@@ -177,7 +184,9 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
benchmark(func, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
def test_compile_add_self_flat(mode, dict_type, benchmark):
@@ -207,7 +216,9 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):


# Tests the speed of copying a flat tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_copy_flat(mode, dict_type, benchmark):
@@ -235,7 +246,9 @@ def test_compile_copy_flat(mode, dict_type, benchmark):


# Tests the speed of assigning entries to an empty tensordict
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
def test_compile_assign_and_add(mode, dict_type, benchmark):
@@ -264,7 +277,9 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
# Tests the speed of assigning entries to a lazy stacked tensordict


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.skipif(
torch.cuda.is_available(), reason="max recursion depth error with cuda"
)
@@ -285,7 +300,9 @@ def test_compile_assign_and_add_stack(mode, benchmark):


# Tests indexing speed
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["compile", "eager"])
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])
42 changes: 31 additions & 11 deletions benchmarks/compile/tensordict_nn_test.py
@@ -9,13 +9,15 @@

import pytest
import torch

from packaging import version
from tensordict import TensorDict, TensorDictParams

from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq

sys.setrecursionlimit(10000)
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)

TORCH_VERSION = torch.__version__
sys.setrecursionlimit(10000)

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

@@ -49,7 +51,9 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,7 +68,9 @@ def test_mod_add(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -80,7 +86,9 @@ def test_mod_wrap(mode, benchmark):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_mod_wrap_and_backward(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -104,7 +112,9 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_add(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -129,7 +139,9 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap(mode, benchmark):
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -161,7 +173,9 @@ def delhidden(td):
benchmark(module, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
def test_seq_wrap_and_backward(mode, benchmark):
@@ -201,7 +215,9 @@ def module_exec(td):
benchmark(module_exec, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("functional", [False, True])
def test_func_call_runtime(mode, functional, benchmark):
@@ -272,7 +288,9 @@ def call(x, td):
benchmark(call, x)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize(
@@ -354,7 +372,9 @@ def call(x, td):
benchmark(call_vmap, x, td)


@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
@pytest.mark.skipif(
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
)
@pytest.mark.slow
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
@pytest.mark.parametrize("plain_decorator", [None, False, True])
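For readers unfamiliar with the fixture these tests lean on: `benchmark` comes from pytest-benchmark and times a callable over many rounds. A minimal sketch of the pattern, with an illustrative module; the warm-up call mirrors what the tests above do, so one-off costs such as compilation stay out of the measurement:

```python
# Minimal pytest-benchmark sketch mirroring the tests above (illustrative names).
import torch
from tensordict import TensorDict
from tensordict.nn import TensorDictModule as Mod

def test_mod_add_sketch(benchmark):
    td = TensorDict({"a": torch.zeros(3, 4)}, batch_size=[3, 4])
    module = Mod(lambda x: x + 1, in_keys=["a"], out_keys=["a"])
    module(td)  # warm-up, so one-off costs are excluded from the timing
    benchmark(module, td)  # timed over many rounds by the fixture
```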
3 changes: 3 additions & 0 deletions tensordict/_td.py
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
return all([value.is_contiguous() for _, value in self.items()])

def _clone(self, recurse: bool = True) -> T:
if recurse and self.device is not None:
return self._clone_recurse()

result = TensorDict._new_unsafe(
source={key: _clone_value(value, recurse) for key, value in self.items()},
batch_size=self.batch_size,
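This change is the heart of the PR: when `recurse=True` and the tensordict has a non-None device (so every leaf is known to live on that device), `_clone` dispatches to the batched `_clone_recurse` path defined below instead of cloning leaves one by one. A sketch of the user-visible behaviour, which is unchanged:

```python
# Sketch: clone() semantics are unchanged; only the dispatch differs.
import torch
from tensordict import TensorDict

td = TensorDict(
    {"a": torch.randn(3), "nested": {"b": torch.randn(3, 4)}},
    batch_size=[3],
    device="cpu",  # non-None device -> batched _clone_recurse path
)
clone = td.clone()  # calls _clone(recurse=True) under the hood
assert clone["a"].data_ptr() != td["a"].data_ptr()  # fresh storage
assert (clone["nested", "b"] == td["nested", "b"]).all()  # same values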
46 changes: 46 additions & 0 deletions tensordict/base.py
@@ -8123,6 +8123,52 @@ def cosh_(self) -> T:
torch._foreach_cosh_(self._values_list(True, True))
return self

def _clone_recurse(self) -> TensorDictBase: # noqa: D417
keys, vals = self._items_list(True, True)
foreach_vals = {}
iter_vals = {}
for key, val in zip(keys, vals):
if (
type(val) is torch.Tensor
and not val.requires_grad
and val.dtype not in (torch.bool,)
):
foreach_vals[key] = val
else:
iter_vals[key] = val
if foreach_vals:
foreach_vals = dict(
_zip_strict(
foreach_vals.keys(),
torch._foreach_add(tuple(foreach_vals.values()), 0),
)
)
if iter_vals:
iter_vals = dict(
_zip_strict(
iter_vals.keys(),
(
val.clone() if hasattr(val, "clone") else val
for val in iter_vals.values()
),
)
)

items = foreach_vals
items.update(iter_vals)
result = self._fast_apply(
lambda name, val: items.pop(name, None),
named=True,
nested_keys=True,
is_leaf=_NESTED_TENSORS_AS_LISTS,
propagate_lock=False,
filter_empty=True,
default=None,
)
if items:
result.update(items)
return result

def add(
self,
other: TensorDictBase | torch.Tensor,
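The key trick in `_clone_recurse` is batching: `torch._foreach_add(tensors, 0)` materializes fresh copies of a whole list of leaves in a few fused kernel launches, where per-leaf `.clone()` pays one launch each; bool and gradient-tracking tensors take the `.clone()` fallback, as the diff shows. A standalone sketch of the trick, noting that `_foreach_add` is a private torch API, used here the same way the PR uses it:

```python
# Sketch of the batched clone: one _foreach call replaces N .clone() calls.
import torch

vals = [torch.randn(1024) for _ in range(100)]
clones = torch._foreach_add(vals, 0)  # adds 0, i.e. copies, in fused kernels

assert all(c.data_ptr() != v.data_ptr() for c, v in zip(clones, vals))  # new storage
assert all(torch.equal(c, v) for c, v in zip(clones, vals))  # identical values
```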
27 changes: 16 additions & 11 deletions tensordict/nn/cudagraphs.py
@@ -267,6 +267,8 @@ def _call(
"The output of the function must be a tensordict, a tensorclass or None. Got "
f"type(out)={type(out)}."
)
if is_tensor_collection(out):
out.lock_()
self._out = out
self.counter += 1
if self._out_matches_in:
@@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
torch._foreach_copy_(dests, srcs)
torch.cuda.synchronize()
self.graph.replay()
if self._return_unchanged == "clone":
result = self._out.clone()
elif self._return_unchanged:
if self._return_unchanged:
result = self._out
else:
result = tree_map(
lambda x: x.detach().clone() if x is not None else x,
self._out,
result = tree_unflatten(
[
out.clone() if hasattr(out, "clone") else out
for out in self._out
],
self._out_struct,
)
return result

@@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
self.graph = torch.cuda.CUDAGraph()
with torch.cuda.graph(self.graph):
out = self.module(*self._args, **self._kwargs)
self._out = out
self._out, self._out_struct = tree_flatten(out)
self.counter += 1
# Check that there is no intersection between the identities of the inputs
# and outputs, otherwise warn the user.
@@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
f"and the identity between input and output will not match anymore. "
f"Make sure you don't rely on input-output identity further in the code."
)
if isinstance(self._out, torch.Tensor) or self._out is None:
self._return_unchanged = (
"clone" if self._out is not None else True
)
if not self._out:
self._return_unchanged = True
else:
self._out = [
out.lock_() if is_tensor_collection(out) else out
for out in self._out
]
self._return_unchanged = False
return this_out

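The cudagraphs change replaces a per-replay `tree_map` over the captured output with a flatten-once, unflatten-per-replay scheme: the output structure is recorded at capture time with `tree_flatten`, and each replay clones the leaves and rebuilds the structure with `tree_unflatten`. A standalone sketch of the mechanics; `torch.utils._pytree` is a private API and the example data is illustrative:

```python
# Sketch of the flatten-once / clone-and-unflatten-per-replay pattern.
import torch
from torch.utils._pytree import tree_flatten, tree_unflatten

captured = {"loss": torch.tensor(1.0), "aux": (torch.tensor(2.0), "tag")}
leaves, spec = tree_flatten(captured)  # done once, at graph-capture time

# Per replay: clone what is clonable, pass the rest through, rebuild the tree.
fresh = [x.clone() if hasattr(x, "clone") else x for x in leaves]
result = tree_unflatten(fresh, spec)

assert result["loss"].data_ptr() != captured["loss"].data_ptr()
assert result["aux"][1] == "tag"
```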
1 change: 1 addition & 0 deletions tensordict/utils.py
@@ -2531,6 +2531,7 @@ def _check_inbuild():
else:

def _zip_strict(*iterables):
iterables = tuple(tuple(it) for it in iterables)
lengths = {len(it) for it in iterables}
if len(lengths) > 1:
raise ValueError("lengths of iterables differ.")
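Finally, the one-line `utils.py` fix: the fallback `_zip_strict` (used on Pythons without `zip(strict=True)`, i.e. before 3.10) takes `len()` of its arguments, which raises `TypeError` for generators; materializing each iterable into a tuple first makes the length check valid and keeps the elements available for `zip`. This matters here because `_clone_recurse` above passes a generator expression to `_zip_strict`. A sketch of the fixed fallback in isolation:

```python
# Standalone copy of the fixed fallback, for illustration.
def _zip_strict(*iterables):
    iterables = tuple(tuple(it) for it in iterables)  # materialize generators
    lengths = {len(it) for it in iterables}  # len() now valid for every input
    if len(lengths) > 1:
        raise ValueError("lengths of iterables differ.")
    return zip(*iterables)

# Before the fix this raised TypeError (generators have no len()):
pairs = list(_zip_strict((c for c in "ab"), [1, 2]))
assert pairs == [("a", 1), ("b", 2)]
```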