Skip to content

Commit 298d6d1

Browse files
author
Vincent Moens
committed
[Refactor] Put values, lengths and offsets of NJTs together in storage
ghstack-source-id: f3d8ed7
Pull Request resolved: #1023
1 parent e4174f1 commit 298d6d1

File tree

4 files changed

+176
-94
lines changed

4 files changed

+176
-94
lines changed

tensordict/_reductions.py

Lines changed: 64 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -5,6 +5,7 @@
55
from __future__ import annotations
66

77
import copyreg
8+
import queue
89
from multiprocessing.reduction import ForkingPickler
910

1011
import torch
@@ -21,7 +22,17 @@
2122

2223

2324
def _rebuild_tensordict_files(flat_key_values, metadata_dict, is_shared: bool = False):
25+
_nt_values_and_keys = queue.Queue()
26+
_nt_lengths = queue.Queue()
27+
_nt_offsets = queue.Queue()
28+
2429
def from_metadata(metadata=metadata_dict, prefix=None):
30+
metadata = dict(metadata)
31+
32+
_ = metadata.pop("njt_values_start", None)
33+
_ = metadata.pop("njt_lengths_start", None)
34+
_ = metadata.pop("njt_offsets_start", None)
35+
2536
non_tensor = metadata.pop("non_tensors")
2637
leaves = metadata.pop("leaves")
2738
cls = metadata.pop("cls")
@@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
3647
total_key = (key,) if prefix is None else prefix + (key,)
3748
if total_key[-1].startswith("<NJT>"):
3849
nested_values = flat_key_values[total_key]
39-
nested_lengths = None
50+
total_key = total_key[:-1] + total_key[-1].replace("<NJT>", "")
51+
_nt_values_and_keys.put((nested_values, total_key))
4052
continue
4153
if total_key[-1].startswith("<NJT_LENGTHS>"):
4254
nested_lengths = flat_key_values[total_key]
55+
_nt_lengths.put(nested_lengths)
4356
continue
4457
elif total_key[-1].startswith("<NJT_OFFSETS"):
4558
offsets = flat_key_values[total_key]
46-
key = key.replace("<NJT_OFFSETS>", "")
47-
value = torch.nested.nested_tensor_from_jagged(
48-
nested_values, offsets=offsets, lengths=nested_lengths
49-
)
50-
del nested_values
51-
del nested_lengths
59+
_nt_offsets.put(offsets)
60+
continue
5261
else:
5362
value = flat_key_values[total_key]
5463
d[key] = value
64+
5565
for k, v in metadata.items():
5666
# Each remaining key is a tuple pointing to a sub-tensordict
5767
d[k] = from_metadata(
@@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
6474
# result._is_shared = is_shared
6575
return result
6676

67-
return from_metadata()
77+
result = from_metadata()
78+
# Then assign the nested tensors
79+
while not _nt_values_and_keys.empty():
80+
vals, key = _nt_values_and_keys.get()
81+
lengths = _nt_lengths.get()
82+
offsets = _nt_offsets.get()
83+
value = torch.nested.nested_tensor_from_jagged(
84+
vals, offsets=offsets, lengths=lengths
85+
)
86+
result._set_tuple(key, value, inplace=False, validated=True)
87+
88+
return result
6889

6990

7091
def _rebuild_tensordict_files_shared(flat_key_values, metadata_dict):
@@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
7596
metadata,
7697
storage,
7798
):
99+
_nt_values_and_keys = queue.Queue()
100+
_nt_lengths = queue.Queue()
101+
_nt_offsets = queue.Queue()
102+
78103
def from_metadata(metadata=metadata, prefix=None):
79104
consolidated = {"storage": storage, "metadata": metadata}
80105
metadata = dict(metadata)
106+
107+
_ = metadata.pop("njt_values_start", None)
108+
_ = metadata.pop("njt_lengths_start", None)
109+
_ = metadata.pop("njt_offsets_start", None)
110+
81111
non_tensor = metadata.pop("non_tensors")
82112
leaves = metadata.pop("leaves")
83113
cls = metadata.pop("cls")
@@ -99,31 +129,45 @@ def from_metadata(metadata=metadata, prefix=None):
99129
value = value[: local_shape.numel()]
100130
value = value.view(local_shape)
101131
if key.startswith("<NJT>"):
102-
nested_values = value
103-
nested_lengths = None
132+
key = key.replace("<NJT>", "")
133+
if prefix:
134+
total_key = prefix + (key,)
135+
else:
136+
total_key = (key,)
137+
_nt_values_and_keys.put((value, total_key))
104138
continue
105139
elif key.startswith("<NJT_LENGTHS>"):
106-
nested_lengths = value
140+
_nt_lengths.put(value)
107141
continue
108142
elif key.startswith("<NJT_OFFSETS>"):
109-
offsets = value
110-
value = torch.nested.nested_tensor_from_jagged(
111-
nested_values, offsets=offsets, lengths=nested_lengths
112-
)
113-
key = key.replace("<NJT_OFFSETS>", "")
143+
_nt_offsets.put(value)
144+
if _nt_offsets.qsize() > _nt_lengths.qsize():
145+
_nt_lengths.put(None)
146+
continue
114147
d[key] = value
115-
for k, v in metadata.items():
148+
for key, val in metadata.items():
116149
# Each remaining key is a tuple pointing to a sub-tensordict
117-
d[k] = from_metadata(
118-
v, prefix=prefix + (k,) if prefix is not None else (k,)
150+
d[key] = from_metadata(
151+
val, prefix=prefix + (key,) if prefix is not None else (key,)
119152
)
120153
result = CLS_MAP[cls]._from_dict_validated(d, **cls_metadata)
121154
if is_locked:
122155
result = result.lock_()
123156
result._consolidated = consolidated
124157
return result
125158

126-
return from_metadata()
159+
result = from_metadata()
160+
# Then assign the nested tensors
161+
while not _nt_values_and_keys.empty():
162+
vals, key = _nt_values_and_keys.get()
163+
lengths = _nt_lengths.get()
164+
offsets = _nt_offsets.get()
165+
value = torch.nested.nested_tensor_from_jagged(
166+
vals, offsets=offsets, lengths=lengths
167+
)
168+
result._set_tuple(key, value, inplace=False, validated=True)
169+
170+
return result
127171

128172

129173
def _make_td(cls, state):

0 commit comments

Comments
 (0)