
Commit 8a9de10

Vincent Moens committed

[Performance] Make _to_consolidated compatible with compile

ghstack-source-id: bb7f342
Pull Request resolved: #1041

1 parent ee49fc7 commit 8a9de10

File tree

4 files changed: +217 -22 lines

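For orientation, here is a minimal sketch of the usage pattern this commit targets: moving a consolidated TensorDict to another device from inside a torch.compile'd function. The tensor contents and device choice below are illustrative assumptions, not part of the diff.

import torch
from tensordict import TensorDict

# Illustrative data; any TensorDict with a consolidated storage would do.
td = TensorDict({"a": torch.randn(3), "b": {"c": torch.randn(3)}}, batch_size=[])
td = td.consolidate()  # one flat storage plus metadata describing each leaf

device = "cuda" if torch.cuda.is_available() else "cpu"

@torch.compile
def move(td):
    # With this commit, the consolidated path takes a compile-aware branch
    # instead of the eager-only storage manipulation.
    return td.to(device)

out = move(td)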

benchmarks/common/h2d_test.py

Lines changed: 33 additions & 5 deletions

@@ -14,6 +14,12 @@
 TORCH_VERSION = version.parse(version.parse(torch.__version__).base_version)
 
 
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
+
+
 @pytest.fixture
 def td():
     return TensorDict(
@@ -52,20 +58,42 @@ def default_device():
         pytest.skip("CUDA/MPS is not available")
 
 
-@pytest.mark.parametrize("consolidated", [False, True])
+@pytest.mark.parametrize(
+    "consolidated,compiled", [[False, False], [True, False], [True, True]]
+)
 @pytest.mark.skipif(
     TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
 )
 class TestTo:
-    def test_to(self, benchmark, consolidated, td, default_device):
+    def test_to(self, benchmark, consolidated, td, default_device, compiled):
         if consolidated:
             td = td.consolidate()
-        benchmark(lambda: td.to(default_device))
 
-    def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
+        def to(td):
+            return td.to(default_device)
+
+        if compiled:
+            to = torch.compile(to)
+
+        for _ in range(3):
+            to(td)
+
+        benchmark(to, td)
+
+    def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled):
         if consolidated:
             njt_td = njt_td.consolidate()
-        benchmark(lambda: njt_td.to(default_device))
+
+        def to(td):
+            return td.to(default_device)
+
+        if compiled:
+            to = torch.compile(to)
+
+        for _ in range(3):
+            to(njt_td)
+
+        benchmark(to, njt_td)
 
 
 if __name__ == "__main__":
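As a side note, the warm-up-then-measure structure added above can be mimicked outside pytest-benchmark roughly as follows; the names, tensor sizes, and timing code are assumptions for illustration only.

import time

import torch
from tensordict import TensorDict

device = "cuda" if torch.cuda.is_available() else "cpu"
td = TensorDict(
    {"a": torch.randn(1024), "b": torch.randn(1024)}, batch_size=[]
).consolidate()

def to(td):
    return td.to(device)

to = torch.compile(to)

# Warm-up calls, mirroring the three calls in the benchmark above, so that
# compilation time is excluded from the measurement.
for _ in range(3):
    to(td)

start = time.perf_counter()
to(td)
print(f"steady-state call: {time.perf_counter() - start:.6f}s")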

benchmarks/compile/compile_td_test.py

Lines changed: 6 additions & 0 deletions

@@ -23,6 +23,12 @@ class MyTensorClass:
     f: torch.Tensor
 
 
+@pytest.fixture(autouse=True, scope="function")
+def empty_compiler_cache():
+    torch._dynamo.reset_code_caches()
+    yield
+
+
 # Functions
 def add_one(td):
     return td + 1
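A brief note on the empty_compiler_cache fixture added in both benchmark files (my reading, not stated in the commit): clearing dynamo's code caches before each test keeps one benchmark's compiled artifacts from being reused by the next. A manual equivalent between runs would be:

import torch

torch._dynamo.reset_code_caches()  # drop cached compiled code objects, as the fixture does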

tensordict/base.py

Lines changed: 170 additions & 17 deletions

@@ -72,7 +72,7 @@
     _split_tensordict,
     _td_fields,
     _unravel_key_to_tuple,
-    _zip_strict,
+    _zip_strict, _to_escape_compile,
     cache,
     convert_ellipsis_to_idx,
     DeviceType,
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
 
         flat_size = []
         start = 0
+        sorting_index = 0
 
         def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
-            nonlocal start
+            nonlocal start, sorting_index
             n = value.element_size() * value.numel()
             if need_padding:
                 pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
                 start,
                 stop,
                 pad,
+                flat_size[-1],
+                sorting_index,
             )
+            sorting_index = sorting_index + 1
             start = stop
 
         def assign(
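The two fields appended to each leaf's metadata entry here, its padded byte length (flat_size[-1]) and a running sorting_index, are what later lets the compile path cut the flat storage back into per-leaf slices: get_tensors_length, added further down in this diff, reads them back as v[-2] and v[-1]. A hedged illustration of that bookkeeping, with made-up values:

import torch

lengths = [16, 8, 24]                     # v[-2]: padded byte length of each leaf
pos = [2, 0, 1]                           # v[-1]: sorting_index, i.e. storage order
keys = [("b",), ("a",), ("nested", "c")]  # keys in metadata traversal order

# Reorder keys/lengths into storage order, as get_tensors_length does.
out_keys = [None] * len(pos)
out_lengths = [None] * len(pos)
for p, n, k in zip(pos, lengths, keys):
    out_keys[p] = k
    out_lengths[p] = n

# Split the flat byte storage into one slice per leaf, keyed by tensordict key.
storage = torch.zeros(sum(lengths), dtype=torch.uint8)
slice_map = dict(zip(out_keys, storage.split(out_lengths)))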
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
                 pin_memory=non_blocking_pin,
                 num_threads=num_threads,
                 non_blocking=non_blocking,
+                compilable=is_dynamo_compiling(),
             )
 
         if non_blocking is None:
@@ -10498,14 +10503,42 @@ def to_pinmem(tensor, _to=to):
             self._sync_all()
         return result
 
-    def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
+    def _to_consolidated(
+        self, *, device, pin_memory, num_threads, non_blocking, compilable
+    ):
         if num_threads is None:
             # unspecified num_threads should mean 0
             num_threads = 0
+
         storage = self._consolidated["storage"]
-        if pin_memory:
-            storage = storage.pin_memory()
-        storage_cast = storage.to(device, non_blocking=True)
+
+        storage_cast = _to_escape_compile(storage, device, pin_memory=pin_memory)
+
+        if compilable:
+            result = self._to_consolidated_compile(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+        else:
+            result = self._to_consolidated_eager(
+                device=device, num_threads=num_threads, storage_cast=storage_cast
+            )
+
+        if non_blocking in (False, None):
+            if device.type == "cuda" and non_blocking is False:
+                # sending to CUDA force sync
+                cuda_device = device
+            elif storage.device.type == "cuda":
+                # sending from cuda: need sync unless intentionally not asked for
+                cuda_device = storage.device.type
+            else:
+                cuda_device = None
+            if cuda_device is not None:
+                torch.cuda.current_stream(cuda_device).synchronize()
+
+        return result
+
+    def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
+
         untyped_storage = storage_cast.untyped_storage()
 
         def set_(x):
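The dispatch added above routes to a compile-friendly implementation when the call is being traced (compilable=is_dynamo_compiling() at the call site in to()) and keeps the eager path otherwise. A minimal sketch of that pattern, not tensordict's actual code:

import torch

def to_device(x: torch.Tensor, device: torch.device) -> torch.Tensor:
    if torch.compiler.is_compiling():
        # Traced by dynamo: keep to a graph-friendly branch.
        return x.to(device)
    # Eager: free to use asynchronous copies, pinned memory, etc.
    return x.to(device, non_blocking=True)

moved = torch.compile(to_device)(torch.randn(8), torch.device("cpu"))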
@@ -10574,18 +10607,138 @@ def copy_dict(d):
                 }
 
             result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
-        if non_blocking in (False, None):
-            if device.type == "cuda" and non_blocking is False:
-                # sending to CUDA force sync
-                cuda_device = device
-            elif storage.device.type == "cuda":
-                # sending from cuda: need sync unless intentionally not asked for
-                cuda_device = storage.device.type
-            else:
-                cuda_device = None
-            if cuda_device is not None:
-                torch.cuda.current_stream(cuda_device).synchronize()
+        return result
+
+    def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
+
+        def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
+            root = False
+            if lengths is None:
+                lengths = []
+                pos = []
+                keys = []
+                root = True
+            for k, v in metadata["leaves"].items():
+                lengths.append(v[-2])
+                pos.append(v[-1])
+                keys.append(prefix + (k,))
+            for k, d in metadata.items():
+                if "leaves" in d:
+                    get_tensors_length(
+                        d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
+                    )
+            if root:
+                # l = torch.empty(len(lengths), dtype=torch.long)
+                # l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
+                out0 = [
+                    None,
+                ] * len(pos)
+                out1 = [
+                    None,
+                ] * len(pos)
+                for p, l, k in zip(pos, lengths, keys):
+                    out0[p] = k
+                    out1[p] = l
+                return out0, out1
+
+        def split_storage(consolidated):
+            keys, splits = get_tensors_length(consolidated["metadata"])
+            return dict(zip(keys, consolidated["storage"].split(splits)))
+
+        if num_threads is None:
+            # unspecified num_threads should mean 0
+            num_threads = 0
+
+        _consolidated = {"storage": storage_cast}
+        if "metadata" in self._consolidated:
+            # faster than deepcopy
+            def copy_dict(d):
+                return {
+                    k: v if not isinstance(v, dict) else copy_dict(v)
+                    for k, v in d.items()
+                }
+
+            _consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
+
+        slice_map = split_storage(_consolidated)
+
+        def view_as(src, dest):
+            return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
 
+        def set_(name, x):
+            if not isinstance(name, tuple):
+                name = (name,)
+            if x.is_nested:
+                from torch._subclasses.fake_tensor import FakeTensor
+                from torch._subclasses.functional_tensor import FunctionalTensor
+                from torch.nested._internal.nested_tensor import (
+                    _tensor_symint_registry,
+                    NestedTensor,
+                )
+                from torch.nested._internal.ops import extract_kwargs
+
+                if x.layout != torch.jagged:
+                    raise RuntimeError(
+                        "to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
+                        "Please raise an issue on GitHub."
+                    )
+                kwargs = extract_kwargs(x)
+                values = x._values
+                lengths = x._lengths
+                offsets = x._offsets
+                storage_offsets = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_OFFSETS>" + name[-1],
+                    )
+                ]
+                kwargs["offsets"] = view_as(storage_offsets, offsets)
+                if lengths is not None:
+                    storage_lengths = slice_map[
+                        (
+                            *name[:-1],
+                            "<NJT_LENGTHS>" + name[-1],
+                        )
+                    ]
+                    kwargs["lengths"] = view_as(storage_lengths, lengths)
+                    ragged_source = lengths
+                else:
+                    ragged_source = offsets
+                new_thing = kwargs.get("lengths", kwargs.get("offsets"))
+                if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
+                    from torch._subclasses.functional_tensor import (
+                        mb_unwrap_functional_tensor,
+                    )
+
+                    # Temporary hack until we have the union find
+                    tgt = mb_unwrap_functional_tensor(new_thing)
+                    src = mb_unwrap_functional_tensor(ragged_source)
+                    tgt.nested_int_memo = src.nested_int_memo
+                else:
+                    _tensor_symint_registry[new_thing] = _tensor_symint_registry[
+                        ragged_source
+                    ]
+
+                storage_values = slice_map[
+                    (
+                        *name[:-1],
+                        "<NJT_VALUES>" + name[-1],
+                    )
+                ]
+                return NestedTensor(
+                    view_as(storage_values, values),
+                    **kwargs,
+                )
+            return view_as(slice_map[name], x)
+
+        result = self._fast_apply(
+            set_,
+            device=torch.device(device),
+            num_threads=num_threads,
+            named=True,
+            nested_keys=True,
+        )
+        result._consolidated = _consolidated
         return result
 
     def _sync_all(self):
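The core of _to_consolidated_compile above is the split-and-view reconstruction: each slice of the cast flat byte storage is reinterpreted as the destination dtype, trimmed to the element count, and reshaped. A self-contained sketch of view_as with placeholder tensors:

import torch

def view_as(src, dest):
    # Same expression as in the diff: byte slice -> dtype view -> trim -> reshape.
    return src.view(dest.dtype)[: dest.numel()].view(dest.shape)

flat_slice = torch.arange(32, dtype=torch.uint8)  # stand-in for one (padded) storage slice
dest = torch.zeros(2, 3, dtype=torch.float32)     # leaf being reconstructed (24 bytes)

rebuilt = view_as(flat_slice, dest)
assert rebuilt.shape == dest.shape and rebuilt.dtype == dest.dtype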

tensordict/utils.py

Lines changed: 8 additions & 0 deletions

@@ -2694,3 +2694,11 @@ def _rebuild_njt_from_njt(x, values, offsets, lengths):
         values,
         **kwargs,
     )
+
+
+@torch.compiler.disable()
+def _to_escape_compile(storage, device, pin_memory):
+    if pin_memory:
+        storage = storage.pin_memory()
+    storage_cast = storage.to(device, non_blocking=True)
+    return storage_cast
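The decorator here is the key piece: torch.compiler.disable() keeps the pinning and the non-blocking host-to-device copy out of any enclosing compiled graph, so they still run eagerly when _to_consolidated is reached under torch.compile. A hedged, standalone sketch of that escape-hatch pattern (move_storage is a made-up name, not the library's API):

import torch

@torch.compiler.disable()
def move_storage(storage, device, pin_memory=False):
    # Runs eagerly even if the caller is being compiled.
    if pin_memory:
        storage = storage.pin_memory()
    return storage.to(device, non_blocking=True)

@torch.compile
def f(x):
    # The call below is treated as an opaque eager call by the compiler.
    return move_storage(x, "cpu") + 1

f(torch.randn(4))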
