Feature: MemmapTensor storage for ReplayBuffer #224

Merged: 32 commits (Jun 28, 2022)
33 changes: 20 additions & 13 deletions test/test_memmap.py
@@ -2,7 +2,7 @@
 #
 # This source code is licensed under the MIT license found in the
 # LICENSE file in the root directory of this source tree.
-
+import argparse
 import os.path
 import pickle
 import tempfile
@@ -71,27 +71,28 @@ def test_memmap_del():
     assert os.path.isfile(filename)


-@pytest.mark.parametrize("value", [True, False])
-def test_memmap_ownership(value):
+@pytest.mark.parametrize("transfer_ownership", [True, False])
+def test_memmap_ownership(transfer_ownership):
     t = torch.tensor([1])
-    m = MemmapTensor(t, transfer_ownership=value)
-    assert m.file.delete
+    m = MemmapTensor(t, transfer_ownership=transfer_ownership)
+    assert not m.file.delete
     with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp:
         pickle.dump(m, tmp)
+        assert m._has_ownership is not m.transfer_ownership
         m2 = pickle.load(open(tmp.name, "rb"))
         assert m2._memmap_array is None  # assert data is not actually loaded
         assert isinstance(m2, MemmapTensor)
         assert m2.filename == m.filename
-        assert m2.file.name == m2.filename
-        assert m2.file._closer.name == m2.filename
-        assert (
-            m.file.delete is not m2.file.delete
-        )  # delete attributes must have changed
+        # assert m2.file.name == m2.filename
+        # assert m2.file._closer.name == m2.filename
         assert (
-            m.file._closer.delete is not m2.file._closer.delete
+            m._has_ownership is not m2._has_ownership
         )  # delete attributes must have changed
+        # assert (
+        #     m.file._closer.delete is not m2.file._closer.delete
+        # )  # delete attributes must have changed
         del m
-        if value:
+        if transfer_ownership:
             assert os.path.isfile(m2.filename)
         else:
             # m2 should point to a non-existing file
@@ -136,5 +137,11 @@ def test_memmap_clone():
     assert m2c == m1


+def test_memmap_tensor():
+    t = torch.tensor([[1, 2, 3], [4, 5, 6]])
+    assert (torch.tensor(t) == t).all()
+
+
 if __name__ == "__main__":
-    pytest.main([__file__, "--capture", "no"])
+    args, unknown = argparse.ArgumentParser().parse_known_args()
+    pytest.main([__file__, "--capture", "no", "--exitfirst"] + unknown)
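Note: the renamed ownership test pins down the pickling contract for `transfer_ownership`. The following is a compact sketch distilled from the assertions above, not an independent spec; `MemmapTensor`, `filename`, and the `transfer_ownership` flag are taken from the diff:

import os.path
import pickle
import tempfile

import torch
from torchrl.data import MemmapTensor

t = torch.tensor([1])
# With transfer_ownership=True, the unpickled copy takes over the backing
# file, so deleting the original must leave the file on disk.
m = MemmapTensor(t, transfer_ownership=True)
with tempfile.NamedTemporaryFile(suffix=".pkl") as tmp:
    pickle.dump(m, tmp)
    m2 = pickle.load(open(tmp.name, "rb"))
    del m
    assert os.path.isfile(m2.filename)  # m2 now owns (and will delete) the file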
23 changes: 20 additions & 3 deletions test/test_rb.py
@@ -16,10 +16,22 @@
     TensorDictReplayBuffer,
 )
 from torchrl.data.replay_buffers import TensorDictPrioritizedReplayBuffer
-from torchrl.data.replay_buffers.storages import ListStorage
+from torchrl.data.replay_buffers.storages import (
+    ListStorage,
+    LazyMemmapStorage,
+    LazyTensorStorage,
+)
 from torchrl.data.tensordict.tensordict import assert_allclose_td, _TensorDict


+collate_fn_dict = {
+    ListStorage: lambda x: torch.stack(x, 0),
+    LazyTensorStorage: lambda x: x,
+    LazyMemmapStorage: lambda x: x,
+    None: lambda x: torch.stack(x, 0),
+}
+
+
 @pytest.mark.parametrize(
     "rbtype",
     [
@@ -39,8 +51,13 @@ class TestBuffers:
     _default_params_td_prb = {"alpha": 0.8, "beta": 0.9}

     def _get_rb(self, rbtype, size, storage, prefetch):
+        collate_fn = collate_fn_dict[storage]
         if storage is not None:
-            storage = storage()
+            storage = (
+                storage(size)
+                if storage in (LazyMemmapStorage, LazyTensorStorage)
+                else storage()
+            )
         if rbtype is ReplayBuffer:
             params = self._default_params_rb
         elif rbtype is PrioritizedReplayBuffer:
@@ -55,7 +72,7 @@ def _get_rb(self, rbtype, size, storage, prefetch):
             size=size,
             storage=storage,
             prefetch=prefetch,
-            collate_fn=lambda x: torch.stack(x, 0),
+            collate_fn=collate_fn,
             **params
         )
         return rb
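These three hunks wire the new storages into the buffer tests: the lazy memmap and tensor storages are sized at construction and already hand back stacked batches, so their collate_fn is the identity (see `collate_fn_dict` above). A minimal usage sketch along the lines of `_get_rb`; the import path relies on the `__init__.py` re-export at the bottom of this PR, and the exact constructor signature should be treated as illustrative:

from torchrl.data.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.storages import LazyMemmapStorage

size = 100
rb = ReplayBuffer(
    size=size,
    storage=LazyMemmapStorage(size),  # sized up front, allocated lazily on disk
    collate_fn=lambda x: x,           # storage already returns a stacked batch
)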
69 changes: 46 additions & 23 deletions test/test_tensordict.py
@@ -12,10 +12,11 @@
 import torch
 from _utils_internal import get_available_devices
 from torch import multiprocessing as mp
-from torchrl.data import SavedTensorDict, TensorDict
+from torchrl.data import SavedTensorDict, TensorDict, MemmapTensor
 from torchrl.data.tensordict.tensordict import (
     assert_allclose_td,
     LazyStackedTensorDict,
+    stack as stack_td,
 )
 from torchrl.data.tensordict.utils import _getitem_batch_size, convert_ellipsis_to_idx

@@ -67,14 +68,14 @@ def test_tensordict_set(device):
 def test_stack(device):
     torch.manual_seed(1)
     tds_list = [TensorDict(source={}, batch_size=(4, 5)) for _ in range(3)]
-    tds = torch.stack(tds_list, 0)
+    tds = stack_td(tds_list, 0, contiguous=False)
     assert tds[0] is tds_list[0]

     td = TensorDict(
         source={"a": torch.randn(4, 5, 3, device=device)}, batch_size=(4, 5)
     )
     td_list = list(td)
-    td_reconstruct = torch.stack(td_list, 0)
+    td_reconstruct = stack_td(td_list, 0)
     assert td_reconstruct.batch_size == td.batch_size
     assert (td_reconstruct == td).all()
@@ -95,13 +96,13 @@ def test_tensordict_indexing(device):
     td_select = td[None, :2]
     td_select._check_batch_size()

-    td_reconstruct = torch.stack([_td for _td in td], 0)
+    td_reconstruct = stack_td([_td for _td in td], 0, contiguous=False)
     assert (
         td_reconstruct == td
     ).all(), f"td and td_reconstruct differ, got {td} and {td_reconstruct}"

-    superlist = [torch.stack([__td for __td in _td], 0) for _td in td]
-    td_reconstruct = torch.stack(superlist, 0)
+    superlist = [stack_td([__td for __td in _td], 0, contiguous=False) for _td in td]
+    td_reconstruct = stack_td(superlist, 0, contiguous=False)
     assert (
         td_reconstruct == td
     ).all(), f"td and td_reconstruct differ, got {td == td_reconstruct}"
@@ -342,8 +343,10 @@ def test_permute_with_tensordict_operations(device):
         "b": torch.randn(4, 5, 7, device=device),
         "c": torch.randn(4, 5, device=device),
     }
-    td1 = torch.stack(
-        [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)], 2
+    td1 = stack_td(
+        [TensorDict(batch_size=(4, 5), source=d).clone() for _ in range(6)],
+        2,
+        contiguous=False,
     ).permute(2, 1, 0)
     assert td1.shape == torch.Size((6, 5, 4))

@@ -370,7 +373,7 @@ def test_stacked_td(stack_dim, device):
     tensordicts3 = tensordicts[3]
     sub_td = LazyStackedTensorDict(*tensordicts, stack_dim=stack_dim)

-    std_bis = torch.stack(tensordicts, dim=stack_dim)
+    std_bis = stack_td(tensordicts, dim=stack_dim, contiguous=False)
     assert (sub_td == std_bis).all()

     item = tuple([*[slice(None) for _ in range(stack_dim)], 0])
@@ -426,7 +429,7 @@ def test_savedtensordict(device):
         )
         for i in range(4)
     ]
-    ss = torch.stack(ss_list, 0)
+    ss = stack_td(ss_list, 0)
     assert ss_list[1] is ss[1]
     torch.testing.assert_allclose(ss_list[1].get("a"), vals[1])
     torch.testing.assert_allclose(ss_list[1].get("a"), ss[1].get("a"))
@@ -480,6 +483,7 @@ def test_convert_ellipsis_to_idx_invalid(ellipsis_index, expectation):
         "sub_td",
         "idx_td",
         "saved_td",
+        "memmap_td",
         "unsqueezed_td",
         "td_reset_bs",
     ],
@@ -514,7 +518,7 @@ def stacked_td(self):
             },
             batch_size=[4, 3, 1],
         )
-        return torch.stack([td1, td2], 2)
+        return stack_td([td1, td2], 2)

     @property
     def idx_td(self):
@@ -544,6 +548,10 @@ def sub_td(self):
     def saved_td(self):
         return SavedTensorDict(source=self.td)

+    @property
+    def memmap_td(self):
+        return self.td.memmap_()
+
     @property
     def unsqueezed_td(self):
         td = TensorDict(
@@ -618,10 +626,14 @@ def test_cast(self, td_name):
         td_saved = td.to(SavedTensorDict)
         assert (td == td_saved).all()

-    def test_remove(self, td_name):
+    @pytest.mark.parametrize("call_del", [True, False])
+    def test_remove(self, td_name, call_del):
         torch.manual_seed(1)
         td = getattr(self, td_name)
-        td = td.del_("a")
+        if call_del:
+            del td["a"]
+        else:
+            td = td.del_("a")
         assert td is not None
         assert "a" not in td.keys()

@@ -754,7 +766,7 @@ def test_unbind(self, td_name):
     torch.manual_seed(1)
     td = getattr(self, td_name)
     td_unbind = torch.unbind(td, dim=0)
-        assert (td == torch.stack(td_unbind, 0)).all()
+        assert (td == stack_td(td_unbind, 0).contiguous()).all()
     assert (td[0] == td_unbind[0]).all()

     @pytest.mark.parametrize("squeeze_dim", [0, 1])
@@ -834,6 +846,10 @@ def test_rename_key(self, td_name) -> None:
         assert "a" not in td.keys()

         z = td.get("z")
+        if isinstance(a, MemmapTensor):
+            a = a._tensor
+        if isinstance(z, MemmapTensor):
+            z = z._tensor
         torch.testing.assert_allclose(a, z)

         new_z = torch.randn_like(z)
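The new `memmap_td` fixture (and the `MemmapTensor` special-casing here) both follow from the same fact: once a TensorDict is memory-mapped, its leaves are `MemmapTensor` instances rather than plain tensors. A sketch of what the fixture exercises, using only names that appear in this diff (`memmap_()` returning the mapped dict, `_tensor` exposing the underlying data):

import torch
from torchrl.data import MemmapTensor, TensorDict

td = TensorDict({"a": torch.randn(3, 4)}, [3])
td.memmap_()  # values are now MemmapTensor leaves backed by files on disk
assert isinstance(td.get("a"), MemmapTensor)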
@@ -914,7 +930,7 @@ def test_setitem_string(self, td_name):
     def test_getitem_string(self, td_name):
         torch.manual_seed(1)
         td = getattr(self, td_name)
-        assert isinstance(td["a"], torch.Tensor)
+        assert isinstance(td["a"], (MemmapTensor, torch.Tensor))

     def test_delitem(self, td_name):
         torch.manual_seed(1)
@@ -1036,7 +1052,7 @@ def td(self):

     @property
     def stacked_td(self):
-        return torch.stack([self.td for _ in range(2)], 0)
+        return stack_td([self.td for _ in range(2)], 0)

     @property
     def idx_td(self):
@@ -1148,7 +1164,7 @@ def test_batchsize_reset():
     assert td.to_tensordict().batch_size == torch.Size([3])

     # test that lazy tds return an exception
-    td_stack = torch.stack([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
+    td_stack = stack_td([TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)])
     td_stack.to_tensordict().batch_size = [2]
     with pytest.raises(
         RuntimeError,
@@ -1222,7 +1238,7 @@ def test_create_on_device():
     # stacked TensorDict
     td1 = TensorDict({}, [5])
     td2 = TensorDict({}, [5])
-    stackedtd = torch.stack([td1, td2], 0)
+    stackedtd = stack_td([td1, td2], 0)
     with pytest.raises(RuntimeError):
         stackedtd.device
     stackedtd.set("a", torch.randn(2, 5, device=device))
@@ -1232,7 +1248,7 @@

     td1 = TensorDict({}, [5], device="cuda:0")
     td2 = TensorDict({}, [5], device="cuda:0")
-    stackedtd = torch.stack([td1, td2], 0)
+    stackedtd = stack_td([td1, td2], 0)
     stackedtd.set("a", torch.randn(2, 5, 1))
     assert stackedtd.get("a").device == device
     assert td1.get("a").device == device
@@ -1417,7 +1433,7 @@ def test_mp(td_type):
     if td_type == "contiguous":
         tensordict = tensordict.share_memory_()
     elif td_type == "stack":
-        tensordict = torch.stack(
+        tensordict = stack_td(
             [
                 tensordict[0].clone().share_memory_(),
                 tensordict[1].clone().share_memory_(),
@@ -1429,7 +1445,7 @@
     elif td_type == "memmap":
         tensordict = tensordict.memmap_()
     elif td_type == "memmap_stack":
-        tensordict = torch.stack(
+        tensordict = stack_td(
             [tensordict[0].clone().memmap_(), tensordict[1].clone().memmap_()], 0
         )
     else:
@@ -1457,7 +1473,7 @@ def test_stack_keys():
         },
         batch_size=[],
     )
-    td = torch.stack([td1, td2], 0)
+    td = stack_td([td1, td2], 0)
     assert "a" in td.keys()
     assert "b" not in td.keys()
     assert "b" in td[1].keys()
@@ -1467,13 +1483,20 @@
     td.set_("b", torch.randn(2, 10))  # b has been set before

     td1.set("c", torch.randn(4))
-    assert "c" in td.keys()  # now all tds have the key c
+    td[
+        "c"
+    ]  # we must first query that key for the stacked tensordict to update the list
+    assert "c" in td.keys(), list(td.keys())  # now all tds have the key c
     td.get("c")

     td1.set("d", torch.randn(6))
     with pytest.raises(RuntimeError):
         td.get("d")

+    td["e"] = torch.randn(2, 4)
+    assert "e" in td.keys()  # now all tds have the key e
+    td.get("e")
+

 def test_getitem_batch_size():
     shape = [
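Nearly every hunk in this file swaps `torch.stack` for the library's own `stack`, imported as `stack_td`. The point of `contiguous=False` is that stacking TensorDicts then yields a lazy view over the inputs rather than a copy. A small sketch of the difference, relying only on behavior the tests above assert (identity of the stacked elements, and `.contiguous()` materializing the stack):

import torch
from torchrl.data import TensorDict
from torchrl.data.tensordict.tensordict import stack as stack_td

tds = [TensorDict({"a": torch.randn(3)}, [3]) for _ in range(2)]

lazy = stack_td(tds, 0, contiguous=False)
assert lazy[0] is tds[0]   # lazy: indexing returns the original TensorDict

dense = lazy.contiguous()  # materializes a regular, copied TensorDict
assert (dense == lazy).all()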
1 change: 1 addition & 0 deletions torchrl/data/replay_buffers/__init__.py
@@ -4,3 +4,4 @@
 # LICENSE file in the root directory of this source tree.

 from .replay_buffers import *
+from .storages import *
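This one-line re-export is what lets user code import the storages from the package itself rather than the `storages` submodule. A quick check of the intended effect, assuming the wildcard picks up the storage classes used in the tests:

# Both names now resolve from the replay_buffers package root:
from torchrl.data.replay_buffers import ListStorage, LazyMemmapStorage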