@@ -12,10 +12,11 @@
 import torch
 from _utils_internal import get_available_devices
 from torch import multiprocessing as mp
-from torchrl.data import SavedTensorDict, TensorDict
+from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor
 from torchrl.data.tensordict.tensordict import (
     assert_allclose_td,
     LazyStackedTensorDict,
+    stack as stack_td,
 )
 from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx

@@ -67,14 +68,14 @@ def test_tensordict_set(device):
 def test_stack(device):
     torch.manual_seed(1)
     tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)]
-    tds = torch.stack(tds_list, 0)
+    tds = stack_td(tds_list, 0, contiguous=False)
     assert tds[0] is tds_list[0]

     td = TensorDict(
         source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5)
     )
     td_list = list(td)
-    td_reconstruct = torch.stack(td_list, 0)
+    td_reconstruct = stack_td(td_list, 0)
     assert td_reconstruct.batch_size == td.batch_size
     assert (td_reconstruct == td).all()

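The recurring edit in these hunks replaces `torch.stack` with torchrl's own `stack` function, imported above as `stack_td`. The `tds[0] is tds_list[0]` assertion only holds because the result is a lazy stack: with `contiguous=False`, the stacked tensordict aliases its inputs instead of copying them. A minimal sketch of that behavior, assuming only the `stack` signature and the `LazyStackedTensorDict` import that this diff itself uses:

    import torch
    from torchrl.data import TensorDict
    from torchrl.data.tensordict.tensordict import LazyStackedTensorDict, stack as stack_td

    tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)]

    # contiguous=False returns a lazy stack whose elements alias the
    # original tensordicts, so no data is copied
    lazy = stack_td(tds_list, 0, contiguous=False)
    assert lazy[0] is tds_list[0]

    # .contiguous() materializes the lazy stack into a regular TensorDict,
    # as the test_unbind hunk below relies on
    dense = lazy.contiguous()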
@@ -95,13 +96,13 @@ def test_tensordict_indexing(device):
     td_select = td[None, :2]
     td_select._check_batch_size()

-    td_reconstruct = torch.stack([_td for _td in td], 0)
+    td_reconstruct = stack_td([_td for _td in td], 0, contiguous=False)
     assert (
         td_reconstruct == td
     ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}"

-    superlist = [torch.stack([__td for __td in _td], 0) for _td in td]
-    td_reconstruct = torch.stack(superlist, 0)
+    superlist = [stack_td([__td for __td in _td], 0, contiguous=False) for _td in td]
+    td_reconstruct = stack_td(superlist, 0, contiguous=False)
     assert (
         td_reconstruct == td
     ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}"
@@ -342,8 +343,10 @@ def test_permute_with_tensordict_operations(device):
         "b": torch.randn(4, 5, 7, device=device),
         "c": torch.randn(4, 5, device=device),
     }
-    td1 = torch.stack(
-        [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)], 2
+    td1 = stack_td(
+        [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)],
+        2,
+        contiguous=False,
     ).permute(2, 1, 0)
     assert td1.shape == torch.Size((6, 5, 4))

@@ -370,7 +373,7 @@ def test_stacked_td(stack_dim, device):
     tensordicts3 = tensordicts[3]
     sub_td = LazyStackedTensorDict(*tensordicts, stack_dim=stack_dim)

-    std_bis = torch.stack(tensordicts, dim=stack_dim)
+    std_bis = stack_td(tensordicts, dim=stack_dim, contiguous=False)
     assert (sub_td == std_bis).all()

     item = tuple([*[slice(None) for _ in range(stack_dim)], 0])
@@ -426,7 +429,7 @@ def test_savedtensordict(device):
         )
         for i in range(4)
     ]
-    ss = torch.stack(ss_list, 0)
+    ss = stack_td(ss_list, 0)
     assert ss_list[1] is ss[1]
     torch.testing.assert_allclose(ss_list[1].get("a"), vals[1])
     torch.testing.assert_allclose(ss_list[1].get("a"), ss[1].get("a"))
@@ -480,6 +483,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation):
         "sub_td",
         "idx_td",
         "saved_td",
+        "memmap_td",
         "unsqueezed_td",
         "td_reset_bs",
     ],
@@ -514,7 +518,7 @@ def stacked_td(self):
             },
             batch_size=[4, 3, 1],
         )
-        return torch.stack([td1, td2], 2)
+        return stack_td([td1, td2], 2)

     @property
     def idx_td(self):
@@ -544,6 +548,10 @@ def sub_td(self):
     def saved_td(self):
         return SavedTensorDict(source=self.td)

+    @property
+    def memmap_td(self):
+        return self.td.memmap_()
+
     @property
     def unsqueezed_td(self):
         td = TensorDict(
@@ -618,10 +626,14 @@ def test_cast(self, td_name):
         td_saved = td.to(SavedTensorDict)
         assert (td == td_saved).all()

-    def test_remove(self, td_name):
+    @pytest.mark.parametrize("call_del", [True, False])
+    def test_remove(self, td_name, call_del):
         torch.manual_seed(1)
         td = getattr(self, td_name)
-        td = td.del_("a")
+        if call_del:
+            del td["a"]
+        else:
+            td = td.del_("a")
         assert td is not None
         assert "a" not in td.keys()

@@ -754,7 +766,7 @@ def test_unbind(self, td_name):
         torch.manual_seed(1)
         td = getattr(self, td_name)
         td_unbind = torch.unbind(td, dim=0)
-        assert (td == torch.stack(td_unbind, 0)).all()
+        assert (td == stack_td(td_unbind, 0).contiguous()).all()
         assert (td[0] == td_unbind[0]).all()

     @pytest.mark.parametrize("squeeze_dim", [0, 1])
@@ -834,6 +846,10 @@ def test_rename_key(self, td_name) -> None:
         assert "a" not in td.keys()

         z = td.get("z")
+        if isinstance(a, MemmapTensor):
+            a = a._tensor
+        if isinstance(z, MemmapTensor):
+            z = z._tensor
         torch.testing.assert_allclose(a, z)

         new_z = torch.randn_like(z)
@@ -914,7 +930,7 @@ def test_setitem_string(self, td_name):
     def test_getitem_string(self, td_name):
         torch.manual_seed(1)
         td = getattr(self, td_name)
-        assert isinstance(td["a"], torch.Tensor)
+        assert isinstance(td["a"], (MemmapTensor, torch.Tensor))

     def test_delitem(self, td_name):
         torch.manual_seed(1)
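The `MemmapTensor` branches and the widened `isinstance` check above exist because `memmap_()` swaps a tensordict's values for memory-mapped wrappers that are not plain `torch.Tensor`s in this version of the API; the tests unwrap them through the private `_tensor` attribute before calling `torch.testing` helpers. A hedged sketch of that unwrapping step, using only names that appear in this diff:

    import torch
    from torchrl.data import TensorDict, MemmapTensor

    td = TensorDict(source={"a": torch.randn(4, 5)}, batch_size=(4, 5))
    ref = td["a"].clone()
    td = td.memmap_()  # move values into memory-mapped storage, as memmap_td does

    val = td["a"]  # may be a MemmapTensor rather than a plain torch.Tensor
    if isinstance(val, MemmapTensor):
        val = val._tensor  # unwrap before torch.testing comparisons, as test_rename_key does
    torch.testing.assert_allclose(val, ref)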
@@ -1036,7 +1052,7 @@ def td(self):

     @property
     def stacked_td(self):
-        return torch.stack([self.td for _ in range(2)], 0)
+        return stack_td([self.td for _ in range(2)], 0)

     @property
     def idx_td(self):
@@ -1148,7 +1164,7 @@ def test_batchsize_reset():
     assert td.to_tensordict().batch_size == torch.Size([3])

     # test that lazy tds return an exception
-    td_stack = torch.stack([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
+    td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
     td_stack.to_tensordict().batch_size = [2]
     with pytest.raises(
         RuntimeError,
@@ -1222,7 +1238,7 @@ def test_create_on_device():
     # stacked TensorDict
     td1 = TensorDict({}, [5])
     td2 = TensorDict({}, [5])
-    stackedtd = torch.stack([td1, td2], 0)
+    stackedtd = stack_td([td1, td2], 0)
     with pytest.raises(RuntimeError):
         stackedtd.device
     stackedtd.set("a", torch.randn(2, 5, device=device))
@@ -1232,7 +1248,7 @@ def test_create_on_device():

     td1 = TensorDict({}, [5], device="cuda:0")
     td2 = TensorDict({}, [5], device="cuda:0")
-    stackedtd = torch.stack([td1, td2], 0)
+    stackedtd = stack_td([td1, td2], 0)
     stackedtd.set("a", torch.randn(2, 5, 1))
     assert stackedtd.get("a").device == device
     assert td1.get("a").device == device
@@ -1417,7 +1433,7 @@ def test_mp(td_type):
     if td_type == "contiguous":
         tensordict = tensordict.share_memory_()
     elif td_type == "stack":
-        tensordict = torch.stack(
+        tensordict = stack_td(
             [
                 tensordict[0].clone().share_memory_(),
                 tensordict[1].clone().share_memory_(),
@@ -1429,7 +1445,7 @@ def test_mp(td_type):
     elif td_type == "memmap":
         tensordict = tensordict.memmap_()
     elif td_type == "memmap_stack":
-        tensordict = torch.stack(
+        tensordict = stack_td(
             [tensordict[0].clone().memmap_(), tensordict[1].clone().memmap_()], 0
         )
     else:
@@ -1457,7 +1473,7 @@ def test_stack_keys():
         },
         batch_size=[],
     )
-    td = torch.stack([td1, td2], 0)
+    td = stack_td([td1, td2], 0)
     assert "a" in td.keys()
     assert "b" not in td.keys()
     assert "b" in td[1].keys()
@@ -1467,13 +1483,20 @@ def test_stack_keys():
     td.set_("b", torch.randn(2, 10))  # b has been set before

     td1.set("c", torch.randn(4))
-    assert "c" in td.keys()  # now all tds have the key c
+    td[
+        "c"
+    ]  # we must first query that key for the stacked tensordict to update the list
+    assert "c" in td.keys(), list(td.keys())  # now all tds have the key c
     td.get("c")

     td1.set("d", torch.randn(6))
     with pytest.raises(RuntimeError):
         td.get("d")

+    td["e"] = torch.randn(2, 4)
+    assert "e" in td.keys()  # now all tds have the key e
+    td.get("e")
+

 def test_getitem_batch_size():
     shape = [
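The behavior pinned down by the `test_stack_keys` hunk is the subtlest part of the change: a key added to one component tensordict only becomes visible on the lazy stack after that key is queried, whereas a key set through the stack itself propagates to every component immediately. A condensed sketch of the same steps, derived from the test above under the same assumed API:

    import torch
    from torchrl.data import TensorDict
    from torchrl.data.tensordict.tensordict import stack as stack_td

    td1 = TensorDict(source={"a": torch.randn(3)}, batch_size=[])
    td2 = TensorDict(source={"a": torch.randn(3), "b": torch.randn(10)}, batch_size=[])
    td = stack_td([td1, td2], 0)
    assert "a" in td.keys()      # present in every component
    assert "b" not in td.keys()  # only td2 has "b"

    td1.set("c", torch.randn(4))
    td["c"]  # querying forces the lazy stack to refresh its key list
    assert "c" in td.keys()

    td["e"] = torch.randn(2, 4)  # setting through the stack reaches all components
    assert "e" in td.keys()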