5
5
from __future__ import annotations
6
6
7
7
import copyreg
8
+ import queue
8
9
from multiprocessing .reduction import ForkingPickler
9
10
10
11
import torch
21
22
22
23
23
24
def _rebuild_tensordict_files (flat_key_values , metadata_dict , is_shared : bool = False ):
25
+ _nt_values_and_keys = queue .Queue ()
26
+ _nt_lengths = queue .Queue ()
27
+ _nt_offsets = queue .Queue ()
28
+
24
29
def from_metadata (metadata = metadata_dict , prefix = None ):
30
+ metadata = dict (metadata )
31
+
32
+ _ = metadata .pop ("njt_values_start" , None )
33
+ _ = metadata .pop ("njt_lengths_start" , None )
34
+ _ = metadata .pop ("njt_offsets_start" , None )
35
+
25
36
non_tensor = metadata .pop ("non_tensors" )
26
37
leaves = metadata .pop ("leaves" )
27
38
cls = metadata .pop ("cls" )
@@ -36,22 +47,21 @@ def from_metadata(metadata=metadata_dict, prefix=None):
36
47
total_key = (key ,) if prefix is None else prefix + (key ,)
37
48
if total_key [- 1 ].startswith ("<NJT>" ):
38
49
nested_values = flat_key_values [total_key ]
39
- nested_lengths = None
50
+ total_key = total_key [:- 1 ] + total_key [- 1 ].replace ("<NJT>" , "" )
51
+ _nt_values_and_keys .put ((nested_values , total_key ))
40
52
continue
41
53
if total_key [- 1 ].startswith ("<NJT_LENGTHS>" ):
42
54
nested_lengths = flat_key_values [total_key ]
55
+ _nt_lengths .put (nested_lengths )
43
56
continue
44
57
elif total_key [- 1 ].startswith ("<NJT_OFFSETS" ):
45
58
offsets = flat_key_values [total_key ]
46
- key = key .replace ("<NJT_OFFSETS>" , "" )
47
- value = torch .nested .nested_tensor_from_jagged (
48
- nested_values , offsets = offsets , lengths = nested_lengths
49
- )
50
- del nested_values
51
- del nested_lengths
59
+ _nt_offsets .put (offsets )
60
+ continue
52
61
else :
53
62
value = flat_key_values [total_key ]
54
63
d [key ] = value
64
+
55
65
for k , v in metadata .items ():
56
66
# Each remaining key is a tuple pointing to a sub-tensordict
57
67
d [k ] = from_metadata (
@@ -64,7 +74,18 @@ def from_metadata(metadata=metadata_dict, prefix=None):
64
74
# result._is_shared = is_shared
65
75
return result
66
76
67
- return from_metadata ()
77
+ result = from_metadata ()
78
+ # Then assign the nested tensors
79
+ while not _nt_values_and_keys .empty ():
80
+ vals , key = _nt_values_and_keys .get ()
81
+ lengths = _nt_lengths .get ()
82
+ offsets = _nt_offsets .get ()
83
+ value = torch .nested .nested_tensor_from_jagged (
84
+ vals , offsets = offsets , lengths = lengths
85
+ )
86
+ result ._set_tuple (key , value , inplace = False , validated = True )
87
+
88
+ return result
68
89
69
90
70
91
def _rebuild_tensordict_files_shared (flat_key_values , metadata_dict ):
@@ -75,9 +96,18 @@ def _rebuild_tensordict_files_consolidated(
75
96
metadata ,
76
97
storage ,
77
98
):
99
+ _nt_values_and_keys = queue .Queue ()
100
+ _nt_lengths = queue .Queue ()
101
+ _nt_offsets = queue .Queue ()
102
+
78
103
def from_metadata (metadata = metadata , prefix = None ):
79
104
consolidated = {"storage" : storage , "metadata" : metadata }
80
105
metadata = dict (metadata )
106
+
107
+ _ = metadata .pop ("njt_values_start" , None )
108
+ _ = metadata .pop ("njt_lengths_start" , None )
109
+ _ = metadata .pop ("njt_offsets_start" , None )
110
+
81
111
non_tensor = metadata .pop ("non_tensors" )
82
112
leaves = metadata .pop ("leaves" )
83
113
cls = metadata .pop ("cls" )
@@ -99,31 +129,45 @@ def from_metadata(metadata=metadata, prefix=None):
99
129
value = value [: local_shape .numel ()]
100
130
value = value .view (local_shape )
101
131
if key .startswith ("<NJT>" ):
102
- nested_values = value
103
- nested_lengths = None
132
+ key = key .replace ("<NJT>" , "" )
133
+ if prefix :
134
+ total_key = prefix + (key ,)
135
+ else :
136
+ total_key = (key ,)
137
+ _nt_values_and_keys .put ((value , total_key ))
104
138
continue
105
139
elif key .startswith ("<NJT_LENGTHS>" ):
106
- nested_lengths = value
140
+ _nt_lengths . put ( value )
107
141
continue
108
142
elif key .startswith ("<NJT_OFFSETS>" ):
109
- offsets = value
110
- value = torch .nested .nested_tensor_from_jagged (
111
- nested_values , offsets = offsets , lengths = nested_lengths
112
- )
113
- key = key .replace ("<NJT_OFFSETS>" , "" )
143
+ _nt_offsets .put (value )
144
+ if _nt_offsets .qsize () > _nt_lengths .qsize ():
145
+ _nt_lengths .put (None )
146
+ continue
114
147
d [key ] = value
115
- for k , v in metadata .items ():
148
+ for key , val in metadata .items ():
116
149
# Each remaining key is a tuple pointing to a sub-tensordict
117
- d [k ] = from_metadata (
118
- v , prefix = prefix + (k ,) if prefix is not None else (k ,)
150
+ d [key ] = from_metadata (
151
+ val , prefix = prefix + (key ,) if prefix is not None else (key ,)
119
152
)
120
153
result = CLS_MAP [cls ]._from_dict_validated (d , ** cls_metadata )
121
154
if is_locked :
122
155
result = result .lock_ ()
123
156
result ._consolidated = consolidated
124
157
return result
125
158
126
- return from_metadata ()
159
+ result = from_metadata ()
160
+ # Then assign the nested tensors
161
+ while not _nt_values_and_keys .empty ():
162
+ vals , key = _nt_values_and_keys .get ()
163
+ lengths = _nt_lengths .get ()
164
+ offsets = _nt_offsets .get ()
165
+ value = torch .nested .nested_tensor_from_jagged (
166
+ vals , offsets = offsets , lengths = lengths
167
+ )
168
+ result ._set_tuple (key , value , inplace = False , validated = True )
169
+
170
+ return result
127
171
128
172
129
173
def _make_td (cls , state ):
0 commit comments