Skip to content

Commit c21839d

Browse files
author
Vincent Moens
committed
[Performance] Make _to_consolidated compatible with compile
ghstack-source-id: c69db5f
Pull Request resolved: #1041
1 parent ee49fc7 commit c21839d

File tree

2 files changed

+193
-21
lines changed

2 files changed

+193
-21
lines changed

benchmarks/common/h2d_test.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,20 +52,32 @@ def default_device():
5252
pytest.skip("CUDA/MPS is not available")
5353

5454

55-
@pytest.mark.parametrize("consolidated", [False, True])
55+
@pytest.mark.parametrize("consolidated,compiled", [[False,False], [True,False],[True,True]])
5656
@pytest.mark.skipif(
5757
TORCH_VERSION < version.parse("2.5.0"), reason="requires torch>=2.5"
5858
)
5959
class TestTo:
60-
def test_to(self, benchmark, consolidated, td, default_device):
60+
def test_to(self, benchmark, consolidated, td, default_device, compiled):
6161
if consolidated:
6262
td = td.consolidate()
63-
benchmark(lambda: td.to(default_device))
63+
def to(td):
64+
return td.to(default_device)
6465

65-
def test_to_njt(self, benchmark, consolidated, njt_td, default_device):
66+
if compiled:
67+
to = torch.compile(to)
68+
69+
benchmark(to, td)
70+
71+
def test_to_njt(self, benchmark, consolidated, njt_td, default_device, compiled):
6672
if consolidated:
6773
njt_td = njt_td.consolidate()
68-
benchmark(lambda: njt_td.to(default_device))
74+
def to(td):
75+
return td.to(default_device)
76+
77+
if compiled:
78+
to = torch.compile(to)
79+
80+
benchmark(to, njt_td)
6981

7082

7183
if __name__ == "__main__":

tensordict/base.py

Lines changed: 176 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -3521,9 +3521,10 @@ def _reduce_vals_and_metadata(self, *, dtype=NO_DEFAULT, requires_metadata):
35213521

35223522
flat_size = []
35233523
start = 0
3524+
sorting_index = 0
35243525

35253526
def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
3526-
nonlocal start
3527+
nonlocal start, sorting_index
35273528
n = value.element_size() * value.numel()
35283529
if need_padding:
35293530
pad = n % 8
@@ -3541,7 +3542,10 @@ def add_single_value(value, key, metadata_dict, dtype, shape, flat_size):
35413542
start,
35423543
stop,
35433544
pad,
3545+
flat_size[-1],
3546+
sorting_index,
35443547
)
3548+
sorting_index = sorting_index + 1
35453549
start = stop
35463550

35473551
def assign(
@@ -10441,6 +10445,7 @@ def to(self, *args, **kwargs) -> T:
1044110445
pin_memory=non_blocking_pin,
1044210446
num_threads=num_threads,
1044310447
non_blocking=non_blocking,
10448+
compilable=is_dynamo_compiling(),
1044410449
)
1044510450

1044610451
if non_blocking is None:
@@ -10498,14 +10503,49 @@ def to_pinmem(tensor, _to=to):
1049810503
self._sync_all()
1049910504
return result
1050010505

10501-
def _to_consolidated(self, *, device, pin_memory, num_threads, non_blocking):
10506+
def _to_consolidated(
10507+
self, *, device, pin_memory, num_threads, non_blocking, compilable
10508+
):
1050210509
if num_threads is None:
1050310510
# unspecified num_threads should mean 0
1050410511
num_threads = 0
10512+
1050510513
storage = self._consolidated["storage"]
10506-
if pin_memory:
10507-
storage = storage.pin_memory()
10508-
storage_cast = storage.to(device, non_blocking=True)
10514+
10515+
@torch.compiler.disable()
10516+
def to(storage):
10517+
if pin_memory:
10518+
storage = storage.pin_memory()
10519+
storage_cast = storage.to(device, non_blocking=True)
10520+
return storage_cast
10521+
10522+
storage_cast = to(storage)
10523+
10524+
if compilable:
10525+
result = self._to_consolidated_compile(
10526+
device=device, num_threads=num_threads, storage_cast=storage_cast
10527+
)
10528+
else:
10529+
result = self._to_consolidated_eager(
10530+
device=device, num_threads=num_threads, storage_cast=storage_cast
10531+
)
10532+
10533+
if non_blocking in (False, None):
10534+
if device.type == "cuda" and non_blocking is False:
10535+
# sending to CUDA force sync
10536+
cuda_device = device
10537+
elif storage.device.type == "cuda":
10538+
# sending from cuda: need sync unless intentionally not asked for
10539+
cuda_device = storage.device.type
10540+
else:
10541+
cuda_device = None
10542+
if cuda_device is not None:
10543+
torch.cuda.current_stream(cuda_device).synchronize()
10544+
10545+
return result
10546+
10547+
def _to_consolidated_eager(self, *, device, num_threads, storage_cast):
10548+
1050910549
untyped_storage = storage_cast.untyped_storage()
1051010550

1051110551
def set_(x):
@@ -10574,18 +10614,138 @@ def copy_dict(d):
1057410614
}
1057510615

1057610616
result._consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10577-
if non_blocking in (False, None):
10578-
if device.type == "cuda" and non_blocking is False:
10579-
# sending to CUDA force sync
10580-
cuda_device = device
10581-
elif storage.device.type == "cuda":
10582-
# sending from cuda: need sync unless intentionally not asked for
10583-
cuda_device = storage.device.type
10584-
else:
10585-
cuda_device = None
10586-
if cuda_device is not None:
10587-
torch.cuda.current_stream(cuda_device).synchronize()
10617+
return result
10618+
10619+
def _to_consolidated_compile(self, *, device, num_threads, storage_cast):
10620+
10621+
def get_tensors_length(metadata, lengths=None, pos=None, keys=None, prefix=()):
10622+
root = False
10623+
if lengths is None:
10624+
lengths = []
10625+
pos = []
10626+
keys = []
10627+
root = True
10628+
for k, v in metadata["leaves"].items():
10629+
lengths.append(v[-2])
10630+
pos.append(v[-1])
10631+
keys.append(prefix + (k,))
10632+
for k, d in metadata.items():
10633+
if "leaves" in d:
10634+
get_tensors_length(
10635+
d, lengths=lengths, pos=pos, keys=keys, prefix=prefix + (k,)
10636+
)
10637+
if root:
10638+
# l = torch.empty(len(lengths), dtype=torch.long)
10639+
# l[torch.as_tensor(pos)] = torch.as_tensor(lengths)
10640+
out0 = [
10641+
None,
10642+
] * len(pos)
10643+
out1 = [
10644+
None,
10645+
] * len(pos)
10646+
for p, l, k in zip(pos, lengths, keys):
10647+
out0[p] = k
10648+
out1[p] = l
10649+
return out0, out1
10650+
10651+
def split_storage(consolidated):
10652+
keys, splits = get_tensors_length(consolidated["metadata"])
10653+
return dict(zip(keys, consolidated["storage"].split(splits)))
10654+
10655+
if num_threads is None:
10656+
# unspecified num_threads should mean 0
10657+
num_threads = 0
10658+
10659+
_consolidated = {"storage": storage_cast}
10660+
if "metadata" in self._consolidated:
10661+
# faster than deepcopy
10662+
def copy_dict(d):
10663+
return {
10664+
k: v if not isinstance(v, dict) else copy_dict(v)
10665+
for k, v in d.items()
10666+
}
10667+
10668+
_consolidated["metadata"] = copy_dict(self._consolidated["metadata"])
10669+
10670+
slice_map = split_storage(_consolidated)
10671+
10672+
def view_as(src, dest):
10673+
return src.view(dest.dtype)[: dest.numel()].view(dest.shape)
1058810674

10675+
def set_(name, x):
10676+
if not isinstance(name, tuple):
10677+
name = (name,)
10678+
if x.is_nested:
10679+
from torch._subclasses.fake_tensor import FakeTensor
10680+
from torch._subclasses.functional_tensor import FunctionalTensor
10681+
from torch.nested._internal.nested_tensor import (
10682+
_tensor_symint_registry,
10683+
NestedTensor,
10684+
)
10685+
from torch.nested._internal.ops import extract_kwargs
10686+
10687+
if x.layout != torch.jagged:
10688+
raise RuntimeError(
10689+
"to(device) with nested tensors that do not have a jagged layout is not implemented yet. "
10690+
"Please raise an issue on GitHub."
10691+
)
10692+
kwargs = extract_kwargs(x)
10693+
values = x._values
10694+
lengths = x._lengths
10695+
offsets = x._offsets
10696+
storage_offsets = slice_map[
10697+
(
10698+
*name[:-1],
10699+
"<NJT_OFFSETS>" + name[-1],
10700+
)
10701+
]
10702+
kwargs["offsets"] = view_as(storage_offsets, offsets)
10703+
if lengths is not None:
10704+
storage_lengths = slice_map[
10705+
(
10706+
*name[:-1],
10707+
"<NJT_LENGTHS>" + name[-1],
10708+
)
10709+
]
10710+
kwargs["lengths"] = view_as(storage_lengths, lengths)
10711+
ragged_source = lengths
10712+
else:
10713+
ragged_source = offsets
10714+
new_thing = kwargs.get("lengths", kwargs.get("offsets"))
10715+
if isinstance(new_thing, (FakeTensor, FunctionalTensor)):
10716+
from torch._subclasses.functional_tensor import (
10717+
mb_unwrap_functional_tensor,
10718+
)
10719+
10720+
# Temporary hack until we have the union find
10721+
tgt = mb_unwrap_functional_tensor(new_thing)
10722+
src = mb_unwrap_functional_tensor(ragged_source)
10723+
tgt.nested_int_memo = src.nested_int_memo
10724+
else:
10725+
_tensor_symint_registry[new_thing] = _tensor_symint_registry[
10726+
ragged_source
10727+
]
10728+
10729+
storage_values = slice_map[
10730+
(
10731+
*name[:-1],
10732+
"<NJT_VALUES>" + name[-1],
10733+
)
10734+
]
10735+
return NestedTensor(
10736+
view_as(storage_values, values),
10737+
**kwargs,
10738+
)
10739+
return view_as(slice_map[name], x)
10740+
10741+
result = self._fast_apply(
10742+
set_,
10743+
device=torch.device(device),
10744+
num_threads=num_threads,
10745+
named=True,
10746+
nested_keys=True,
10747+
)
10748+
result._consolidated = _consolidated
1058910749
return result
1059010750

1059110751
def _sync_all(self):

0 commit comments

Comments
 (0)