Skip to content

Commit dbc7ba0

Browse files
authored
Feature: MemmapTensor storage for ReplayBuffer (#224)
1 parent faebd2b commit dbc7ba0

File tree

11 files changed

+408
-112
lines changed

11 files changed

+408
-112
lines changed

test/test_memmap.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# This source code is licensed under the MIT license found in the
44
# LICENSE file in the root directory of this source tree.
5-
5+
import argparse
66
import os.path
77
import pickle
88
import tempfile
@@ -71,27 +71,28 @@ def test_memmap_del():
7171
assert os.path.isfile(filename)
7272

7373

74-
@pytest.mark.parametrize("value", [True, False])
75-
def test_memmap_ownership(value):
74+
@pytest.mark.parametrize("transfer_ownership", [True, False])
75+
def test_memmap_ownership(transfer_ownership):
7676
t = torch.tensor([1])
77-
m = MemmapTensor(t, transfer_ownership=value)
78-
assert m.file.delete
77+
m = MemmapTensor(t, transfer_ownership=transfer_ownership)
78+
assert not m.file.delete
7979
with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp:
8080
pickle.dump(m, tmp)
81+
assert m._has_ownership is not m.transfer_ownership
8182
m2 = pickle.load(open(tmp.name, "rb"))
8283
assert m2._memmap_array is None # assert data is not actually loaded
8384
assert isinstance(m2, MemmapTensor)
8485
assert m2.filename == m.filename
85-
assert m2.file.name == m2.filename
86-
assert m2.file._closer.name == m2.filename
87-
assert (
88-
m.file.delete is not m2.file.delete
89-
) # delete attributes must have changed
86+
# assert m2.file.name == m2.filename
87+
# assert m2.file._closer.name == m2.filename
9088
assert (
91-
m.file._closer.delete is not m2.file._closer.delete
89+
m._has_ownership is not m2._has_ownership
9290
) # delete attributes must have changed
91+
# assert (
92+
# m.file._closer.delete is not m2.file._closer.delete
93+
# ) # delete attributes must have changed
9394
del m
94-
if value:
95+
if transfer_ownership:
9596
assert os.path.isfile(m2.filename)
9697
else:
9798
# m2 should point to a non-existing file
@@ -136,5 +137,11 @@ def test_memmap_clone():
136137
assert m2c == m1
137138

138139

140+
def test_memmap_tensor():
141+
t = torch.tensor([[1, 2, 3], [4, 5, 6]])
142+
assert (torch.tensor(t) == t).all()
143+
144+
139145
if __name__ == "__main__":
140-
pytest.main([__file__, "--capture", "no"])
146+
args, unknown = argparse.ArgumentParser().parse_known_args()
147+
pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)

test/test_rb.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,22 @@
1616
TensorDictReplayBuffer,
1717
)
1818
from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
19-
from torchrl.data.replay_buffers.storages import ListStorage
19+
from torchrl.data.replay_buffers.storages import (
20+
ListStorage,
21+
LazyMemmapStorage,
22+
LazyTensorStorage,
23+
)
2024
from torchrl.data.tensordict.tensordict import assert_allclose_td, _TensorDict
2125

2226

27+
collate_fn_dict = {
28+
ListStorage: lambda x: torch.stack(x, 0),
29+
LazyTensorStorage: lambda x: x,
30+
LazyMemmapStorage: lambda x: x,
31+
None: lambda x: torch.stack(x, 0),
32+
}
33+
34+
2335
@pytest.mark.parametrize(
2436
"rbtype",
2537
[
@@ -39,8 +51,13 @@ class TestBuffers:
3951
_default_params_td_prb = {"alpha": 0.8, "beta": 0.9}
4052

4153
def _get_rb(self, rbtype, size, storage, prefetch):
54+
collate_fn = collate_fn_dict[storage]
4255
if storage is not None:
43-
storage = storage()
56+
storage = (
57+
storage(size)
58+
if storage in (LazyMemmapStorage, LazyTensorStorage)
59+
else storage()
60+
)
4461
if rbtype is ReplayBuffer:
4562
params = self._default_params_rb
4663
elif rbtype is PrioritizedReplayBuffer:
@@ -55,7 +72,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
5572
size=size,
5673
storage=storage,
5774
prefetch=prefetch,
58-
collate_fn=lambda x: torch.stack(x, 0),
75+
collate_fn=collate_fn,
5976
**params
6077
)
6178
return rb

test/test_tensordict.py

Lines changed: 46 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,11 @@
1212
import torch
1313
from _utils_internal import get_available_devices
1414
from torch import multiprocessing as mp
15-
from torchrl.data import SavedTensorDict, TensorDict
15+
from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor
1616
from torchrl.data.tensordict.tensordict import (
1717
assert_allclose_td,
1818
LazyStackedTensorDict,
19+
stack as stack_td,
1920
)
2021
from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx
2122

@@ -67,14 +68,14 @@ def test_tensordict_set(device):
6768
def test_stack(device):
6869
torch.manual_seed(1)
6970
tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)]
70-
tds = torch.stack(tds_list, 0)
71+
tds = stack_td(tds_list, 0, contiguous=False)
7172
assert tds[0] is tds_list[0]
7273

7374
td = TensorDict(
7475
source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5)
7576
)
7677
td_list = list(td)
77-
td_reconstruct = torch.stack(td_list, 0)
78+
td_reconstruct = stack_td(td_list, 0)
7879
assert td_reconstruct.batch_size == td.batch_size
7980
assert (td_reconstruct == td).all()
8081

@@ -95,13 +96,13 @@ def test_tensordict_indexing(device):
9596
td_select = td[None, :2]
9697
td_select._check_batch_size()
9798

98-
td_reconstruct = torch.stack([_td for _td in td], 0)
99+
td_reconstruct = stack_td([_td for _td in td], 0, contiguous=False)
99100
assert (
100101
td_reconstruct == td
101102
).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}"
102103

103-
superlist = [torch.stack([__td for __td in _td], 0) for _td in td]
104-
td_reconstruct = torch.stack(superlist, 0)
104+
superlist = [stack_td([__td for __td in _td], 0, contiguous=False) for _td in td]
105+
td_reconstruct = stack_td(superlist, 0, contiguous=False)
105106
assert (
106107
td_reconstruct == td
107108
).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}"
@@ -342,8 +343,10 @@ def test_permute_with_tensordict_operations(device):
342343
"b": torch.randn(4, 5, 7, device=device),
343344
"c": torch.randn(4, 5, device=device),
344345
}
345-
td1 = torch.stack(
346-
[TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)], 2
346+
td1 = stack_td(
347+
[TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)],
348+
2,
349+
contiguous=False,
347350
).permute(2, 1, 0)
348351
assert td1.shape == torch.Size((6, 5, 4))
349352

@@ -370,7 +373,7 @@ def test_stacked_td(stack_dim, device):
370373
tensordicts3 = tensordicts[3]
371374
sub_td = LazyStackedTensorDict(*tensordicts, stack_dim=stack_dim)
372375

373-
std_bis = torch.stack(tensordicts, dim=stack_dim)
376+
std_bis = stack_td(tensordicts, dim=stack_dim, contiguous=False)
374377
assert (sub_td == std_bis).all()
375378

376379
item = tuple([*[slice(None) for _ in range(stack_dim)], 0])
@@ -426,7 +429,7 @@ def test_savedtensordict(device):
426429
)
427430
for i in range(4)
428431
]
429-
ss = torch.stack(ss_list, 0)
432+
ss = stack_td(ss_list, 0)
430433
assert ss_list[1] is ss[1]
431434
torch.testing.assert_allclose(ss_list[1].get("a"), vals[1])
432435
torch.testing.assert_allclose(ss_list[1].get("a"), ss[1].get("a"))
@@ -480,6 +483,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation):
480483
"sub_td",
481484
"idx_td",
482485
"saved_td",
486+
"memmap_td",
483487
"unsqueezed_td",
484488
"td_reset_bs",
485489
],
@@ -514,7 +518,7 @@ def stacked_td(self):
514518
},
515519
batch_size=[4, 3, 1],
516520
)
517-
return torch.stack([td1, td2], 2)
521+
return stack_td([td1, td2], 2)
518522

519523
@property
520524
def idx_td(self):
@@ -544,6 +548,10 @@ def sub_td(self):
544548
def saved_td(self):
545549
return SavedTensorDict(source=self.td)
546550

551+
@property
552+
def memmap_td(self):
553+
return self.td.memmap_()
554+
547555
@property
548556
def unsqueezed_td(self):
549557
td = TensorDict(
@@ -618,10 +626,14 @@ def test_cast(self, td_name):
618626
td_saved = td.to(SavedTensorDict)
619627
assert (td == td_saved).all()
620628

621-
def test_remove(self, td_name):
629+
@pytest.mark.parametrize("call_del", [True, False])
630+
def test_remove(self, td_name, call_del):
622631
torch.manual_seed(1)
623632
td = getattr(self, td_name)
624-
td = td.del_("a")
633+
if call_del:
634+
del td["a"]
635+
else:
636+
td = td.del_("a")
625637
assert td is not None
626638
assert "a" not in td.keys()
627639

@@ -754,7 +766,7 @@ def test_unbind(self, td_name):
754766
torch.manual_seed(1)
755767
td = getattr(self, td_name)
756768
td_unbind = torch.unbind(td, dim=0)
757-
assert (td == torch.stack(td_unbind, 0)).all()
769+
assert (td == stack_td(td_unbind, 0).contiguous()).all()
758770
assert (td[0] == td_unbind[0]).all()
759771

760772
@pytest.mark.parametrize("squeeze_dim", [0, 1])
@@ -834,6 +846,10 @@ def test_rename_key(self, td_name) -> None:
834846
assert "a" not in td.keys()
835847

836848
z = td.get("z")
849+
if isinstance(a, MemmapTensor):
850+
a = a._tensor
851+
if isinstance(z, MemmapTensor):
852+
z = z._tensor
837853
torch.testing.assert_allclose(a, z)
838854

839855
new_z = torch.randn_like(z)
@@ -914,7 +930,7 @@ def test_setitem_string(self, td_name):
914930
def test_getitem_string(self, td_name):
915931
torch.manual_seed(1)
916932
td = getattr(self, td_name)
917-
assert isinstance(td["a"], torch.Tensor)
933+
assert isinstance(td["a"], (MemmapTensor, torch.Tensor))
918934

919935
def test_delitem(self, td_name):
920936
torch.manual_seed(1)
@@ -1036,7 +1052,7 @@ def td(self):
10361052

10371053
@property
10381054
def stacked_td(self):
1039-
return torch.stack([self.td for _ in range(2)], 0)
1055+
return stack_td([self.td for _ in range(2)], 0)
10401056

10411057
@property
10421058
def idx_td(self):
@@ -1148,7 +1164,7 @@ def test_batchsize_reset():
11481164
assert td.to_tensordict().batch_size == torch.Size([3])
11491165

11501166
# test that lazy tds return an exception
1151-
td_stack = torch.stack([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
1167+
td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
11521168
td_stack.to_tensordict().batch_size = [2]
11531169
with pytest.raises(
11541170
RuntimeError,
@@ -1222,7 +1238,7 @@ def test_create_on_device():
12221238
# stacked TensorDict
12231239
td1 = TensorDict({}, [5])
12241240
td2 = TensorDict({}, [5])
1225-
stackedtd = torch.stack([td1, td2], 0)
1241+
stackedtd = stack_td([td1, td2], 0)
12261242
with pytest.raises(RuntimeError):
12271243
stackedtd.device
12281244
stackedtd.set("a", torch.randn(2, 5, device=device))
@@ -1232,7 +1248,7 @@ def test_create_on_device():
12321248

12331249
td1 = TensorDict({}, [5], device="cuda:0")
12341250
td2 = TensorDict({}, [5], device="cuda:0")
1235-
stackedtd = torch.stack([td1, td2], 0)
1251+
stackedtd = stack_td([td1, td2], 0)
12361252
stackedtd.set("a", torch.randn(2, 5, 1))
12371253
assert stackedtd.get("a").device == device
12381254
assert td1.get("a").device == device
@@ -1417,7 +1433,7 @@ def test_mp(td_type):
14171433
if td_type == "contiguous":
14181434
tensordict = tensordict.share_memory_()
14191435
elif td_type == "stack":
1420-
tensordict = torch.stack(
1436+
tensordict = stack_td(
14211437
[
14221438
tensordict[0].clone().share_memory_(),
14231439
tensordict[1].clone().share_memory_(),
@@ -1429,7 +1445,7 @@ def test_mp(td_type):
14291445
elif td_type == "memmap":
14301446
tensordict = tensordict.memmap_()
14311447
elif td_type == "memmap_stack":
1432-
tensordict = torch.stack(
1448+
tensordict = stack_td(
14331449
[tensordict[0].clone().memmap_(), tensordict[1].clone().memmap_()], 0
14341450
)
14351451
else:
@@ -1457,7 +1473,7 @@ def test_stack_keys():
14571473
},
14581474
batch_size=[],
14591475
)
1460-
td = torch.stack([td1, td2], 0)
1476+
td = stack_td([td1, td2], 0)
14611477
assert "a" in td.keys()
14621478
assert "b" not in td.keys()
14631479
assert "b" in td[1].keys()
@@ -1467,13 +1483,20 @@ def test_stack_keys():
14671483
td.set_("b", torch.randn(2, 10)) # b has been set before
14681484

14691485
td1.set("c", torch.randn(4))
1470-
assert "c" in td.keys() # now all tds have the key c
1486+
td[
1487+
"c"
1488+
] # we must first query that key for the stacked tensordict to update the list
1489+
assert "c" in td.keys(), list(td.keys()) # now all tds have the key c
14711490
td.get("c")
14721491

14731492
td1.set("d", torch.randn(6))
14741493
with pytest.raises(RuntimeError):
14751494
td.get("d")
14761495

1496+
td["e"] = torch.randn(2, 4)
1497+
assert "e" in td.keys()  # now all tds have the key e
1498+
td.get("e")
1499+
14771500

14781501
def test_getitem_batch_size():
14791502
shape = [

torchrl/data/replay_buffers/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
# LICENSE file in the root directory of this source tree.
55

66
from .replay_buffers import *
7+
from .storages import *

0 commit comments

Comments
 (0)