# However, to achieve this you would need to write a complicated collate
# function that makes sure that every modality is aggregated properly.

import torch


def collate_dict_fn(dict_list):
    final_dict = {}
    for key in dict_list[0].keys():
        # Gather this key's value from every sample in the batch...
        final_dict[key] = []
        for single_dict in dict_list:
            final_dict[key].append(single_dict[key])
        # ...then stack them along a new leading batch dimension
        final_dict[key] = torch.stack(final_dict[key], dim=0)
    return final_dict


###############################################################################
# dataloader = DataLoader(DictDataset(), collate_fn=collate_dict_fn)
#
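###############################################################################
# As a quick sanity check, here is a hedged sketch of how the collate function
# could be plugged into a ``DataLoader``. ``ToyDictDataset`` is a hypothetical
# stand-in for the dict-returning dataset discussed above, not part of the
# original tutorial.

from torch.utils.data import DataLoader, Dataset


class ToyDictDataset(Dataset):
    """Hypothetical dataset returning a dict of tensors per sample."""

    def __len__(self):
        return 8

    def __getitem__(self, idx):
        return {"image": torch.zeros(3, 4), "label": torch.zeros(1)}


dataloader = DataLoader(ToyDictDataset(), batch_size=4, collate_fn=collate_dict_fn)
batch = next(iter(dataloader))
print(batch["image"].shape)  # torch.Size([4, 3, 4])
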
from torchrl.data import TensorDict
from torchrl.data.tensordict.tensordict import (
    UnsqueezedTensorDict,
    _ViewedTensorDict,
    PermutedTensorDict,
)


###############################################################################
# TensorDict is a data structure indexed by either keys or numerical indices.
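###############################################################################
# For instance, a minimal sketch of both indexing modes (``td`` and ``"x"``
# are illustrative names, not from the original tutorial):

td = TensorDict({"x": torch.zeros(3, 4)}, batch_size=[3])
print(td["x"].shape)  # key-based indexing returns the stored tensor
print(td[0])  # numerical indexing selects along the batch dimension
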
# does not work
try:
    tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4, 5])
except RuntimeError:
    print("caramba!")

###############################################################################
a = torch.zeros(3, 4)
b = TensorDict(
    {
        "c": torch.zeros(3, 4, 5, dtype=torch.int32),
        "d": torch.zeros(3, 4, 5, 6, dtype=torch.float32),
    },
    batch_size=[3, 4, 5],
)
tensordict = TensorDict({"a": a, "b": b}, batch_size=[3, 4])
print(tensordict)
# The ``update`` method can be used to update a TensorDict with another one
# (or with a dict):

tensordict.update({"a": torch.ones((3, 4, 5)), "d": 2 * torch.ones((3, 4, 2))})
# Also works with tensordict.update(TensorDict({"a": torch.ones((3, 4, 5)),
# "c": torch.ones((3, 4, 2))}, batch_size=[3, 4]))
print(f"a is now equal to 1: {(tensordict['a'] == 1).all()}")
# but it must be shared across tensors. Indeed, you cannot have items that don't
# share the batch size inside the same TensorDict:

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
print(f"Our TensorDict is of size {tensordict.shape}")

###############################################################################
tensordict = TensorDict({}, [10])
for i in range(2):
    tensordict[i] = TensorDict({"a": torch.randn(3, 4)}, [])
# Entries that were never written remain zero-valued, as this check shows
assert (tensordict[9]["a"] == torch.zeros((3, 4))).all()
tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)

###############################################################################
# Devices
# than the original item.

tensordict_clone = tensordict.clone()
print(
    f"Content is identical ({(tensordict['a'] == tensordict_clone['a']).all()}) but duplicated ({tensordict['a'] is not tensordict_clone['a']})"
)

###############################################################################
# **Slicing and Indexing**
# to the original tensordict as well as the desired index such that tensor
# modifications can be achieved easily.

tensordict = TensorDict(
    {"a": torch.zeros(3, 4, 5), "b": torch.zeros(3, 4)}, batch_size=[3, 4]
)
# A SubTensorDict keeps track of the original one: it does not copy the
# original data in memory
subtd = tensordict.get_sub_tensordict((slice(None), torch.tensor([1, 3])))
tensordict.fill_("a", -1)
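# Because the sub-tensordict is a view on the original data, the ``fill_``
# call above is visible through it (a quick check, not part of the original):
assert (subtd["a"] == -1).all()
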
###############################################################################
# **View**
#
# Support for the view operation returning a ``_ViewedTensorDict``.
# Use ``to_tensordict`` to recover a regular ``TensorDict``.

assert type(tensordict.view(-1)) == _ViewedTensorDict
assert tensordict.view(-1).shape[0] == 12
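# ``to_tensordict`` materializes the lazy view into a regular ``TensorDict``
# (a quick check, not part of the original tutorial):
assert isinstance(tensordict.view(-1).to_tensordict(), TensorDict)
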
###############################################################################
# **Permute**
#
# We can permute the dims of a ``TensorDict``. Permute is a lazy operation that
# returns a ``PermutedTensorDict``. Use ``to_tensordict`` to convert it back
# to a ``TensorDict``.

assert type(tensordict.permute(1, 0)) == PermutedTensorDict
assert tensordict.permute(1, 0).batch_size == torch.Size([4, 3])
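# As with views, ``to_tensordict`` materializes the lazy permutation
# (a quick check, not part of the original tutorial):
assert tensordict.permute(1, 0).to_tensordict().batch_size == torch.Size([4, 3])
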
###############################################################################
# **Reshape**