
Commit ebec8ee

Author: Vincent Moens (committed)
[Performance] Faster clone
ghstack-source-id: 60f8c7d
Pull Request resolved: #1043
1 parent fe6db77, commit ebec8ee

File tree: 9 files changed (+114, -39 lines)

benchmarks/common/h2d_test.py

Lines changed: 6 additions & 0 deletions
@@ -7,9 +7,12 @@
 
 import pytest
 import torch
+from packaging import version
 
 from tensordict import TensorDict
 
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
+
 
 @pytest.fixture
 def td():
@@ -50,6 +53,9 @@ def default_device():
 
 
 @pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
+)
 class TestTo:
     def test_to(self, benchmark, consolidated, td, default_device):
         if consolidated:
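Note on the new version gate: `version.parse(version.parse(torch.__version__).base_version)` parses the version string twice on purpose. The inner parse exposes `base_version`, which drops pre-release and local segments (nightly `.devYYYYMMDD` tags, `+cu121` build suffixes), and the outer parse turns the cleaned string back into a comparable `Version` object. A minimal sketch of the behavior (the version string below is illustrative, not pinned to this commit):

    from packaging import version

    raw = "2.5.0.dev20240601+cu121"  # what torch.__version__ can look like on a nightly
    clean = version.parse(raw).base_version  # "2.5.0": pre-release/local parts dropped
    TORCH_VERSION = version.parse(clean)

    # Without the stripping, a 2.5 nightly would compare as < 2.5.0 and be skipped.
    assert version.parse(raw) < version.parse("2.5.0")
    assert TORCH_VERSION >= version.parse("2.5.0")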

benchmarks/compile/compile_td_test.py

Lines changed: 26 additions & 9 deletions
@@ -6,10 +6,11 @@
 
 import pytest
 import torch
+from packaging import version
 from tensordict import LazyStackedTensorDict, tensorclass, TensorDict
 from torch.utils._pytree import tree_map
 
-TORCH_VERSION = torch.__version__
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
 @tensorclass
@@ -106,7 +107,9 @@ def get_flat_tc():
 
 
 # Tests runtime of a simple arithmetic op over a highly nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_add_one_nested(mode, dict_type, benchmark):
@@ -128,7 +131,9 @@ def test_compile_add_one_nested(mode, dict_type, benchmark):
 
 
 # Tests the speed of copying a nested tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_nested(mode, dict_type, benchmark):
@@ -150,7 +155,9 @@ def test_compile_copy_nested(mode, dict_type, benchmark):
 
 
 # Tests runtime of a simple arithmetic op over a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_one_flat(mode, dict_type, benchmark):
@@ -177,7 +184,9 @@ def test_compile_add_one_flat(mode, dict_type, benchmark):
     benchmark(func, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 def test_compile_add_self_flat(mode, dict_type, benchmark):
@@ -207,7 +216,9 @@ def test_compile_add_self_flat(mode, dict_type, benchmark):
 
 
 # Tests the speed of copying a flat tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_copy_flat(mode, dict_type, benchmark):
@@ -235,7 +246,9 @@ def test_compile_copy_flat(mode, dict_type, benchmark):
 
 
 # Tests the speed of assigning entries to an empty tensordict
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "pytree"])
 def test_compile_assign_and_add(mode, dict_type, benchmark):
@@ -264,7 +277,9 @@ def test_compile_assign_and_add(mode, dict_type, benchmark):
 # Tests the speed of assigning entries to a lazy stacked tensordict
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.skipif(
     torch.cuda.is_available(), reason="max recursion depth error with cuda"
 )
@@ -285,7 +300,9 @@ def test_compile_assign_and_add_stack(mode, benchmark):
 
 
 # Tests indexing speed
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["compile", "eager"])
 @pytest.mark.parametrize("dict_type", ["tensordict", "tensorclass", "pytree"])
 @pytest.mark.parametrize("index_type", ["tensor", "slice", "int"])

benchmarks/compile/tensordict_nn_test.py

Lines changed: 31 additions & 11 deletions
@@ -9,13 +9,15 @@
 
 import pytest
 import torch
+
+from packaging import version
 from tensordict import TensorDict, TensorDictParams
 
 from tensordict.nn import TensorDictModule as Mod, TensorDictSequential as Seq
 
-sys.setrecursionlimit(10000)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
-TORCH_VERSION = torch.__version__
+sys.setrecursionlimit(10000)
 
 DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
@@ -49,7 +51,9 @@ def mlp(device, depth=2, num_cells=32, feature_dim=3):
     )
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -64,7 +68,9 @@ def test_mod_add(mode, benchmark):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -80,7 +86,9 @@ def test_mod_wrap(mode, benchmark):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_mod_wrap_and_backward(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -104,7 +112,9 @@ def module_exec(td):
     benchmark(module_exec, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_add(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -129,7 +139,9 @@ def delhidden(td):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap(mode, benchmark):
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -161,7 +173,9 @@ def delhidden(td):
     benchmark(module, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 def test_seq_wrap_and_backward(mode, benchmark):
@@ -201,7 +215,9 @@ def module_exec(td):
     benchmark(module_exec, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("functional", [False, True])
 def test_func_call_runtime(mode, functional, benchmark):
@@ -272,7 +288,9 @@ def call(x, td):
     benchmark(call, x)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize(
@@ -354,7 +372,9 @@ def call(x, td):
     benchmark(call_vmap, x, td)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.slow
 @pytest.mark.parametrize("mode", ["eager", "compile", "compile-overhead"])
 @pytest.mark.parametrize("plain_decorator", [None, False, True])

tensordict/_td.py

Lines changed: 3 additions & 0 deletions
@@ -3009,6 +3009,9 @@ def is_contiguous(self) -> bool:
         return all([value.is_contiguous() for _, value in self.items()])
 
     def _clone(self, recurse: bool = True) -> T:
+        if recurse and self.device is not None:
+            return self._clone_recurse()
+
         result = TensorDict._new_unsafe(
             source={key: _clone_value(value, recurse) for key, value in self.items()},
             batch_size=self.batch_size,
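The early return above is the dispatch point for the fast path: a recursive `clone()` on a tensordict whose `device` is known is rerouted to the `_clone_recurse()` method added in `tensordict/base.py` below, where all leaves can be cloned in one batched call. A rough usage sketch of when the fast path would trigger:

    import torch
    from tensordict import TensorDict

    # device is not None, so clone() (recurse=True) can take the batched path
    td = TensorDict(
        {"a": torch.randn(3), "b": torch.zeros(3)}, batch_size=[3], device="cpu"
    )
    td2 = td.clone()
    assert td2["a"] is not td["a"] and (td2["a"] == td["a"]).all()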

tensordict/base.py

Lines changed: 20 additions & 0 deletions
@@ -8123,6 +8123,26 @@ def cosh_(self) -> T:
         torch._foreach_cosh_(self._values_list(True, True))
         return self
 
+    def _clone_recurse(self) -> TensorDictBase:  # noqa: D417
+        keys, vals = self._items_list(True, True)
+        if all(type(val) is torch.Tensor and not val.requires_grad for val in vals):
+            vals = torch._foreach_add(vals, 0)
+        else:
+            vals = (val.clone() if hasattr(val, "clone") else val for val in vals)
+        items = dict(zip(keys, vals))
+        result = self._fast_apply(
+            lambda name, val: items.pop(name, None),
+            named=True,
+            nested_keys=True,
+            is_leaf=_NESTED_TENSORS_AS_LISTS,
+            propagate_lock=False,
+            filter_empty=True,
+            default=None,
+        )
+        if items:
+            result.update(items)
+        return result
+
     def add(
         self,
         other: TensorDictBase | torch.Tensor,
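The core of the speedup is `torch._foreach_add(vals, 0)`: adding a zero scalar across a flat list of tensors materializes fresh copies of all of them in a handful of fused kernel launches, instead of one `clone()` call per leaf. Leaves that require grad (where autograd should record a genuine `clone`) or that are not plain tensors fall back to per-item cloning. A self-contained sketch of the idea, independent of tensordict internals:

    import torch

    def fast_clone_leaves(vals):
        """Copy a flat list of leaves, batching the common all-tensor case (sketch)."""
        if all(type(v) is torch.Tensor and not v.requires_grad for v in vals):
            # x + 0 yields a fresh tensor with the same values: one fused
            # foreach op replaces len(vals) separate clone() calls.
            return torch._foreach_add(vals, 0)
        return [v.clone() if hasattr(v, "clone") else v for v in vals]

    leaves = [torch.randn(4) for _ in range(8)]
    copies = fast_clone_leaves(leaves)
    assert all(c is not v and torch.equal(c, v) for c, v in zip(copies, leaves))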

tensordict/nn/cudagraphs.py

Lines changed: 16 additions & 11 deletions
@@ -267,6 +267,8 @@ def _call(
                 "The output of the function must be a tensordict, a tensorclass or None. Got "
                 f"type(out)={type(out)}."
             )
+        if is_tensor_collection(out):
+            out.lock_()
         self._out = out
         self.counter += 1
         if self._out_matches_in:
@@ -302,14 +304,15 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                    torch._foreach_copy_(dests, srcs)
                torch.cuda.synchronize()
                self.graph.replay()
-                if self._return_unchanged == "clone":
-                    result = self._out.clone()
-                elif self._return_unchanged:
+                if self._return_unchanged:
                    result = self._out
                else:
-                    result = tree_map(
-                        lambda x: x.detach().clone() if x is not None else x,
-                        self._out,
+                    result = tree_unflatten(
+                        [
+                            out.clone() if hasattr(out, "clone") else out
+                            for out in self._out
+                        ],
+                        self._out_struct,
                    )
                return result
 
@@ -340,7 +343,7 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                self.graph = torch.cuda.CUDAGraph()
                with torch.cuda.graph(self.graph):
                    out = self.module(*self._args, **self._kwargs)
-                self._out = out
+                self._out, self._out_struct = tree_flatten(out)
                self.counter += 1
                # Check that there is no intersection between the identity of inputs and outputs,
                # otherwise warn the user.
@@ -356,11 +359,13 @@ def _call(*args: torch.Tensor, **kwargs: torch.Tensor):
                        f"and the identity between input and output will not match anymore. "
                        f"Make sure you don't rely on input-output identity further in the code."
                    )
-                if isinstance(self._out, torch.Tensor) or self._out is None:
-                    self._return_unchanged = (
-                        "clone" if self._out is not None else True
-                    )
+                if not self._out:
+                    self._return_unchanged = True
                else:
+                    self._out = [
+                        out.lock_() if is_tensor_collection(out) else out
+                        for out in self._out
+                    ]
                    self._return_unchanged = False
                return this_out
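With this change, the module keeps the captured output as a flat list of leaves plus its pytree spec (`tree_flatten`), so each replay can clone the leaves with a plain list comprehension and rebuild the nested structure with `tree_unflatten`, rather than re-traversing it with `tree_map` on every call; tensor collections among the outputs are also locked so the cached graph output cannot be mutated in place. The flatten/clone/unflatten round trip in isolation (a sketch, with an arbitrary output structure):

    import torch
    from torch.utils._pytree import tree_flatten, tree_unflatten

    out = {"loss": torch.tensor(1.0), "aux": (torch.zeros(2), torch.ones(2))}

    leaves, spec = tree_flatten(out)  # flat leaf list + structure description
    cloned = [x.clone() if hasattr(x, "clone") else x for x in leaves]
    rebuilt = tree_unflatten(cloned, spec)  # same nesting, fresh tensors

    assert rebuilt["loss"] is not out["loss"]
    assert torch.equal(rebuilt["loss"], out["loss"])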

test/test_compile.py

Lines changed: 10 additions & 4 deletions
@@ -33,7 +33,7 @@
 
 from torch.utils._pytree import SUPPORTED_NODES, tree_map
 
-TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
 
@@ -53,7 +53,9 @@ def func(x, y):
     funcv_c(x, y)
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTD:
     def test_tensor_output(self, mode):
@@ -340,7 +342,9 @@ class MyClass:
     c: Any = None
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestTC:
     def test_tc_tensor_output(self, mode):
@@ -579,7 +583,9 @@ def locked_op(tc):
     assert (tc_op == tc_op_c).all()
 
 
-@pytest.mark.skipif(TORCH_VERSION < "2.4", reason="requires torch>=2.4")
+@pytest.mark.skipif(
+    TORCH_VERSION < version.parse("2.4.0"), reason="requires torch>=2.4"
+)
 @pytest.mark.parametrize("mode", [None, "reduce-overhead"])
 class TestNN:
     def test_func(self, mode):

test/test_distributed.py

Lines changed: 1 addition & 3 deletions
@@ -108,9 +108,7 @@ def test_fsdp_module(self, tmpdir):
 
 
 # not using TorchVersion to make the comparison work with dev
-TORCH_VERSION = version.parse(
-    ".".join(map(str, version.parse(torch.__version__).release))
-)
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
 @pytest.mark.skipif(

test/test_tensordict.py

Lines changed: 1 addition & 1 deletion
@@ -92,7 +92,7 @@
     _has_h5py = True
 except ImportError:
     _has_h5py = False
-TORCH_VERSION = version.parse(torch.__version__).base_version
+TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 _has_onnx = importlib.util.find_spec("onnxruntime", None) is not None
