Skip to content

Commit ee49fc7

Browse files
author
Vincent Moens
committed
[Performance] Faster clone
ghstack-source-id: 14d5586 Pull Request resolved: #1043
1 parent fd400af commit ee49fc7

File tree

10 files changed

+158
-46
lines changed

10 files changed

+158
-46
lines changed

benchmarks/common/h2d_test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,12 @@
77

88
import pytest
99
import torch
10+
from packaging import version
1011

1112
from tensordict import TensorDict
1213

14+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
15+
1316

1417
@pytest.fixture
1518
def td():
@@ -50,6 +53,9 @@ def default_device():
5053

5154

5255
@pytest.mark.parametrize("consolidated", [False, True])
56+
@pytest.mark.skipif(
57+
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
58+
)
5359
class TestTo:
5460
def test_to(self, benchmark, consolidated, td, default_device):
5561
if consolidated:

benchmarks/compile/compile_td_test.py

Lines changed: 26 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66

77
import pytest
88
import torch
9+
from packaging import version
910
from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
1011
from torch.utils._pytree import tree_map
1112

12-
TORCH_VERSION = torch.__version__
13+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
1314

1415

1516
@tensorclass
@@ -106,7 +107,9 @@ def get_flat_tc():
106107

107108

108109
# Tests runtime of a simple arithmetic op over a highly nested tensordict
109-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
110+
@pytest.mark.skipif(
111+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
112+
)
110113
@pytest.mark.parametrize("mode", ["compile", "eager"])
111114
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
112115
def test_compile_add_one_nested(mode, dict_type, benchmark):
@@ -128,7 +131,9 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):
128131

129132

130133
# Tests the speed of copying a nested tensordict
131-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
134+
@pytest.mark.skipif(
135+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
136+
)
132137
@pytest.mark.parametrize("mode", ["compile", "eager"])
133138
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
134139
def test_compile_copy_nested(mode, dict_type, benchmark):
@@ -150,7 +155,9 @@ def test_compile_copy_nested(mode, dict_type, benchmark):
150155

151156

152157
# Tests runtime of a simple arithmetic op over a flat tensordict
153-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
158+
@pytest.mark.skipif(
159+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
160+
)
154161
@pytest.mark.parametrize("mode", ["compile", "eager"])
155162
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
156163
def test_compile_add_one_flat(mode, dict_type, benchmark):
@@ -177,7 +184,9 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
177184
benchmark(func, td)
178185

179186

180-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
187+
@pytest.mark.skipif(
188+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
189+
)
181190
@pytest.mark.parametrize("mode", ["eager", "compile"])
182191
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
183192
def test_compile_add_self_flat(mode, dict_type, benchmark):
@@ -207,7 +216,9 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):
207216

208217

209218
# Tests the speed of copying a flat tensordict
210-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
219+
@pytest.mark.skipif(
220+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
221+
)
211222
@pytest.mark.parametrize("mode", ["compile", "eager"])
212223
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
213224
def test_compile_copy_flat(mode, dict_type, benchmark):
@@ -235,7 +246,9 @@ def test_compile_copy_flat(mode, dict_type, benchmark):
235246

236247

237248
# Tests the speed of assigning entries to an empty tensordict
238-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
249+
@pytest.mark.skipif(
250+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
251+
)
239252
@pytest.mark.parametrize("mode", ["compile", "eager"])
240253
@pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
241254
def test_compile_assign_and_add(mode, dict_type, benchmark):
@@ -264,7 +277,9 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
264277
# Tests the speed of assigning entries to a lazy stacked tensordict
265278

266279

267-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
280+
@pytest.mark.skipif(
281+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
282+
)
268283
@pytest.mark.skipif(
269284
torch.cuda.is_available(), reason="max recursion depth error with cuda"
270285
)
@@ -285,7 +300,9 @@ def test_compile_assign_and_add_stack(mode, benchmark):
285300

286301

287302
# Tests indexing speed
288-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
303+
@pytest.mark.skipif(
304+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
305+
)
289306
@pytest.mark.parametrize("mode", ["compile", "eager"])
290307
@pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
291308
@pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])

benchmarks/compile/tensordict_nn_test.py

Lines changed: 31 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99

1010
import pytest
1111
import torch
12+
13+
from packaging import version
1214
from tensordict import TensorDict, TensorDictParams
1315

1416
from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
1517

16-
sys.setrecursionlimit(10000)
18+
TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
1719

18-
TORCH_VERSION = torch.__version__
20+
sys.setrecursionlimit(10000)
1921

2022
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
2123

@@ -49,7 +51,9 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
4951
)
5052

5153

52-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
54+
@pytest.mark.skipif(
55+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
56+
)
5357
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
5458
def test_mod_add(mode, benchmark):
5559
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,7 +68,9 @@ def test_mod_add(mode, benchmark):
6468
benchmark(module, td)
6569

6670

67-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
71+
@pytest.mark.skipif(
72+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
73+
)
6874
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
6975
def test_mod_wrap(mode, benchmark):
7076
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -80,7 +86,9 @@ def test_mod_wrap(mode, benchmark):
8086
benchmark(module, td)
8187

8288

83-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
89+
@pytest.mark.skipif(
90+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
91+
)
8492
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
8593
def test_mod_wrap_and_backward(mode, benchmark):
8694
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -104,7 +112,9 @@ def module_exec(td):
104112
benchmark(module_exec, td)
105113

106114

107-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
115+
@pytest.mark.skipif(
116+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
117+
)
108118
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
109119
def test_seq_add(mode, benchmark):
110120
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -129,7 +139,9 @@ def delhidden(td):
129139
benchmark(module, td)
130140

131141

132-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
142+
@pytest.mark.skipif(
143+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
144+
)
133145
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
134146
def test_seq_wrap(mode, benchmark):
135147
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -161,7 +173,9 @@ def delhidden(td):
161173
benchmark(module, td)
162174

163175

164-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
176+
@pytest.mark.skipif(
177+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
178+
)
165179
@pytest.mark.slow
166180
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
167181
def test_seq_wrap_and_backward(mode, benchmark):
@@ -201,7 +215,9 @@ def module_exec(td):
201215
benchmark(module_exec, td)
202216

203217

204-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
218+
@pytest.mark.skipif(
219+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
220+
)
205221
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
206222
@pytest.mark.parametrize("functional", [False, True])
207223
def test_func_call_runtime(mode, functional, benchmark):
@@ -272,7 +288,9 @@ def call(x, td):
272288
benchmark(call, x)
273289

274290

275-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
291+
@pytest.mark.skipif(
292+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
293+
)
276294
@pytest.mark.slow
277295
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
278296
@pytest.mark.parametrize(
@@ -354,7 +372,9 @@ def call(x, td):
354372
benchmark(call_vmap, x, td)
355373

356374

357-
@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
375+
@pytest.mark.skipif(
376+
TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
377+
)
358378
@pytest.mark.slow
359379
@pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
360380
@pytest.mark.parametrize("plain_decorator", [None, False, True])

tensordict/_td.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
30093009
return all([value.is_contiguous() for _, value in self.items()])
30103010

30113011
def _clone(self, recurse: bool = True) -> T:
3012+
if recurse and self.device is not None:
3013+
return self._clone_recurse()
3014+
30123015
result = TensorDict._new_unsafe(
30133016
source={key: _clone_value(value, recurse) for key, value in self.items()},
30143017
batch_size=self.batch_size,

tensordict/base.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8123,6 +8123,52 @@ def cosh_(self) -> T:
81238123
torch._foreach_cosh_(self._values_list(True, True))
81248124
return self
81258125

8126+
def _clone_recurse(self) -> TensorDictBase: # noqa: D417
8127+
keys, vals = self._items_list(True, True)
8128+
foreach_vals = {}
8129+
iter_vals = {}
8130+
for key, val in zip(keys, vals):
8131+
if (
8132+
type(val) is torch.Tensor
8133+
and not val.requires_grad
8134+
and val.dtype not in (torch.bool,)
8135+
):
8136+
foreach_vals[key] = val
8137+
else:
8138+
iter_vals[key] = val
8139+
if foreach_vals:
8140+
foreach_vals = dict(
8141+
_zip_strict(
8142+
foreach_vals.keys(),
8143+
torch._foreach_add(tuple(foreach_vals.values()), 0),
8144+
)
8145+
)
8146+
if iter_vals:
8147+
iter_vals = dict(
8148+
_zip_strict(
8149+
iter_vals.keys(),
8150+
(
8151+
val.clone() if hasattr(val, "clone") else val
8152+
for val in iter_vals.values()
8153+
),
8154+
)
8155+
)
8156+
8157+
items = foreach_vals
8158+
items.update(iter_vals)
8159+
result = self._fast_apply(
8160+
lambda name, val: items.pop(name, None),
8161+
named=True,
8162+
nested_keys=True,
8163+
is_leaf=_NESTED_TENSORS_AS_LISTS,
8164+
propagate_lock=False,
8165+
filter_empty=True,
8166+
default=None,
8167+
)
8168+
if items:
8169+
result.update(items)
8170+
return result
8171+
81268172
def add(
81278173
self,
81288174
other: TensorDictBase | torch.Tensor,

tensordict/nn/cudagraphs.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,8 @@ def _call(
267267
"The output of the function must be a tensordict, a tensorclass or None. Got "
268268
f"type(out)={type(out)}."
269269
)
270+
if is_tensor_collection(out):
271+
out.lock_()
270272
self._out = out
271273
self.counter += 1
272274
if self._out_matches_in:
@@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
302304
torch._foreach_copy_(dests, srcs)
303305
torch.cuda.synchronize()
304306
self.graph.replay()
305-
if self._return_unchanged == "clone":
306-
result = self._out.clone()
307-
elif self._return_unchanged:
307+
if self._return_unchanged:
308308
result = self._out
309309
else:
310-
result = tree_map(
311-
lambda x: x.detach().clone() if x is not None else x,
312-
self._out,
310+
result = tree_unflatten(
311+
[
312+
out.clone() if hasattr(out, "clone") else out
313+
for out in self._out
314+
],
315+
self._out_struct,
313316
)
314317
return result
315318

@@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
340343
self.graph = torch.cuda.CUDAGraph()
341344
with torch.cuda.graph(self.graph):
342345
out = self.module(*self._args, **self._kwargs)
343-
self._out = out
346+
self._out, self._out_struct = tree_flatten(out)
344347
self.counter += 1
345348
# Check that there is not intersection between the indentity of inputs and outputs, otherwise warn
346349
# user.
@@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
356359
f"and the identity between input and output will not match anymore. "
357360
f"Make sure you don't rely on input-output identity further in the code."
358361
)
359-
if isinstance(self._out, torch.Tensor) or self._out is None:
360-
self._return_unchanged = (
361-
"clone" if self._out is not None else True
362-
)
362+
if not self._out:
363+
self._return_unchanged = True
363364
else:
365+
self._out = [
366+
out.lock_() if is_tensor_collection(out) else out
367+
for out in self._out
368+
]
364369
self._return_unchanged = False
365370
return this_out
366371

tensordict/utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2531,6 +2531,7 @@ def _check_inbuild():
25312531
else:
25322532

25332533
def _zip_strict(*iterables):
2534+
iterables = tuple(tuple(it) for it in iterables)
25342535
lengths = {len(it) for it in iterables}
25352536
if len(lengths) > 1:
25362537
raise ValueError("lengths of iterables differ.")

0 commit comments

Comments
 (0)