Skip to content

Commit ff57c52

Browse files
author
Vincent Moens
committed
[Feature] empty_lazy for lazy tensor storages
ghstack-source-id: dfa85a0
Pull Request resolved: #2955
1 parent 36f34da commit ff57c52

File tree

3 files changed

+62
-11
lines changed

3 files changed

+62
-11
lines changed

test/test_rb.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -870,7 +870,9 @@ def extend_and_sample(data):
870870
"`TensorStorage._rand_given_ndim` can be removed."
871871
)
872872

873-
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
873+
@pytest.mark.parametrize(
874+
"storage_type", [partial(LazyTensorStorage, empty_lazy=True), LazyMemmapStorage]
875+
)
874876
def test_extend_lazystack(self, storage_type):
875877

876878
rb = ReplayBuffer(
@@ -881,9 +883,24 @@ def test_extend_lazystack(self, storage_type):
881883
td2 = TensorDict(a=torch.rand(5, 3, 8), batch_size=5)
882884
ltd = LazyStackedTensorDict(td1, td2, stack_dim=1)
883885
rb.extend(ltd)
884-
rb.sample(3)
886+
s = rb.sample(3)
887+
assert isinstance(s, LazyStackedTensorDict)
885888
assert len(rb) == 5
886889

890+
def test_extend_empty_lazy(self):
891+
892+
rb = ReplayBuffer(
893+
storage=LazyTensorStorage(6, empty_lazy=True),
894+
batch_size=2,
895+
)
896+
td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
897+
td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
898+
ltd = LazyStackedTensorDict(td1, td2, stack_dim=0)
899+
rb.extend(ltd)
900+
s = rb.sample(3)
901+
assert isinstance(s, LazyStackedTensorDict)
902+
assert len(rb) == 2
903+
887904
@pytest.mark.parametrize("device_data", get_default_devices())
888905
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
889906
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])

torchrl/data/replay_buffers/storages.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -827,7 +827,7 @@ def set(
827827
if not self.initialized:
828828
if not isinstance(cursor, INT_CLASSES):
829829
if is_tensor_collection(data):
830-
self._init(data[0])
830+
self._init(data, shape=data.shape[1:])
831831
else:
832832
self._init(tree_map(lambda x: x[0], data))
833833
else:
@@ -873,7 +873,7 @@ def set( # noqa: F811
873873
)
874874
if not self.initialized:
875875
if not isinstance(cursor, INT_CLASSES):
876-
self._init(data[0])
876+
self._init(data, shape=data.shape[1:])
877877
else:
878878
self._init(data)
879879
if not isinstance(cursor, (*INT_CLASSES, slice)):
@@ -993,6 +993,15 @@ class LazyTensorStorage(TensorStorage):
993993
Defaults to ``False``.
994994
consolidated (bool, optional): if ``True``, the storage will be consolidated after
995995
its first expansion. Defaults to ``False``.
996+
empty_lazy (bool, optional): if ``True``, any lazy tensordict in the first tensordict
997+
passed to the storage will be emptied of its content. This can be used to store
998+
ragged data or content with exclusive keys (e.g., when some but not all environments
999+
provide extra data to be stored in the buffer).
1000+
Setting `empty_lazy` to `True` requires :meth:`~.extend` to be called first (a call to `add`
1001+
will result in an exception).
1002+
Recall that data stored in lazy stacks is not stored contiguously in memory: indexing can be
1003+
slower than contiguous data and serialization is more hazardous. Use with caution!
1004+
Defaults to ``False``.
9961005
9971006
Examples:
9981007
>>> data = TensorDict({
@@ -1054,6 +1063,7 @@ def __init__(
10541063
ndim: int = 1,
10551064
compilable: bool = False,
10561065
consolidated: bool = False,
1066+
empty_lazy: bool = False,
10571067
):
10581068
super().__init__(
10591069
storage=None,
@@ -1062,11 +1072,13 @@ def __init__(
10621072
ndim=ndim,
10631073
compilable=compilable,
10641074
)
1075+
self.empty_lazy = empty_lazy
10651076
self.consolidated = consolidated
10661077

10671078
def _init(
10681079
self,
10691080
data: TensorDictBase | torch.Tensor | PyTree, # noqa: F821
1081+
shape: torch.Size | None = None,
10701082
) -> None:
10711083
if not self._compilable:
10721084
# TODO: Investigate why this seems to have a performance impact with
@@ -1087,8 +1099,21 @@ def max_size_along_dim0(data_shape):
10871099

10881100
if is_tensor_collection(data):
10891101
out = data.to(self.device)
1090-
out: TensorDictBase = torch.empty_like(
1091-
out.expand(max_size_along_dim0(data.shape))
1102+
if self.empty_lazy:
1103+
if shape is None:
1104+
# shape is None in add
1105+
raise RuntimeError(
1106+
"Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
1107+
)
1108+
out: TensorDictBase = torch.empty_like(
1109+
out.expand(max_size_along_dim0(data.shape))
1110+
)
1111+
elif shape is None:
1112+
shape = data.shape
1113+
else:
1114+
out = out[0]
1115+
out: TensorDictBase = out.new_empty(
1116+
max_size_along_dim0(shape), empty_lazy=self.empty_lazy
10921117
)
10931118
if self.consolidated:
10941119
out = out.consolidate()
@@ -1286,7 +1311,9 @@ def load_state_dict(self, state_dict):
12861311
self.initialized = state_dict["initialized"]
12871312
self._len = state_dict["_len"]
12881313

1289-
def _init(self, data: TensorDictBase | torch.Tensor) -> None:
1314+
def _init(
1315+
self, data: TensorDictBase | torch.Tensor, *, shape: torch.Size | None = None
1316+
) -> None:
12901317
torchrl_logger.debug("Creating a MemmapStorage...")
12911318
if self.device == "auto":
12921319
self.device = data.device
@@ -1304,8 +1331,14 @@ def max_size_along_dim0(data_shape):
13041331
return (self.max_size, *data_shape)
13051332

13061333
if is_tensor_collection(data):
1334+
if shape is None:
1335+
# Within add()
1336+
shape = data.shape
1337+
else:
1338+
# Get the first element - we don't care about empty_lazy in memmap storages
1339+
data = data[0]
13071340
out = data.clone().to(self.device)
1308-
out = out.expand(max_size_along_dim0(data.shape))
1341+
out = out.expand(max_size_along_dim0(shape))
13091342
out = out.memmap_like(prefix=self.scratch_dir, existsok=self.existsok)
13101343
if torchrl_logger.isEnabledFor(logging.DEBUG):
13111344
for key, tensor in sorted(

torchrl/envs/transforms/transforms.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7308,12 +7308,13 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
73087308
return reward_spec
73097309

73107310
def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
7311-
time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"]
7312-
if not time_dim:
7311+
try:
7312+
time_dim = list(tensordict.names).index("time")
7313+
except ValueError:
73137314
raise ValueError(
73147315
"At least one dimension of the tensordict must be named 'time' in offline mode"
73157316
)
7316-
time_dim = time_dim[0] - 1
7317+
time_dim = time_dim - 1
73177318
for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
73187319
reward = tensordict[in_key]
73197320
cumsum = reward.cumsum(time_dim)

0 commit comments

Comments (0)