
Commit 2f8acf7

Author: Vincent Moens

Update

[ghstack-poisoned]

1 parent: 89efa18

2 files changed: 14 additions, 6 deletions

torchrl/data/replay_buffers/storages.py (10 additions, 3 deletions)
@@ -1099,12 +1099,19 @@ def max_size_along_dim0(data_shape):
 
         if is_tensor_collection(data):
            out = data.to(self.device)
-            if self.empty_lazy and shape is None:
-                raise RuntimeError(
-                    "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+            if self.empty_lazy:
+                if shape is None:
+                    # shape is None in add
+                    raise RuntimeError(
+                        "Make sure you have called `extend` and not `add` first when setting `empty_lazy=True`."
+                    )
+                out: TensorDictBase = torch.empty_like(
+                    out.expand(max_size_along_dim0(data.shape))
                 )
             elif shape is None:
                 shape = data.shape
+            else:
+                out = out[0]
             out: TensorDictBase = out.new_empty(
                 max_size_along_dim0(shape), empty_lazy=self.empty_lazy
             )
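
The `empty_lazy` branch now distinguishes `add` (where `shape` is `None`, so it still raises) from `extend`, where the storage is pre-allocated via `torch.empty_like` over an expanded view of the incoming data. The mechanism rests on two PyTorch facts: `expand` can prepend new leading dimensions as a zero-copy view, and `empty_like` allocates fresh uninitialized storage with the view's shape. A minimal standalone sketch with plain tensors (the capacity value, the shapes, and the simplified `max_size_along_dim0` stand-in below are illustrative assumptions, not taken from the commit):

import torch

max_size = 100                          # illustrative storage capacity

def max_size_along_dim0(data_shape):
    # Simplified stand-in for the helper named in the hunk header:
    # prepend the storage capacity to the given shape.
    return (max_size, *data_shape)

item = torch.randn(3, 2)                # layout learned from the data
view = item.expand(max_size_along_dim0(item.shape))  # zero-copy view
storage = torch.empty_like(view)        # fresh, uninitialized allocation
print(storage.shape)                    # torch.Size([100, 3, 2])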

torchrl/envs/transforms/transforms.py (4 additions, 3 deletions)
@@ -7308,12 +7308,13 @@ def transform_reward_spec(self, reward_spec: TensorSpec) -> TensorSpec:
         return reward_spec
 
     def forward(self, tensordict: TensorDictBase) -> TensorDictBase:
-        time_dim = [i for i, name in enumerate(tensordict.names) if name == "time"]
-        if not time_dim:
+        try:
+            time_dim = list(tensordict.names).index("time")
+        except ValueError:
             raise ValueError(
                 "At least one dimension of the tensordict must be named 'time' in offline mode"
             )
-        time_dim = time_dim[0] - 1
+        time_dim = time_dim - 1
         for in_key, out_key in _zip_strict(self.in_keys, self.out_keys):
             reward = tensordict[in_key]
             cumsum = reward.cumsum(time_dim)
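
Both versions locate the first tensordict dimension named "time"; the rewrite swaps the comprehension-plus-emptiness-check for `list.index`, which returns the first match directly and raises `ValueError` when the name is absent, letting the transform re-raise with a clearer message. A standalone sketch of the two forms (the `names` list is an illustrative assumption):

names = ["batch", "time", None]     # e.g. a tensordict's dimension names

# Before: collect every matching index, test for emptiness, take the first.
matches = [i for i, name in enumerate(names) if name == "time"]
assert matches and matches[0] == 1

# After: .index() returns the first match and raises ValueError when
# "time" is missing, replacing the explicit `if not time_dim:` check.
try:
    time_dim = list(names).index("time")
except ValueError:
    raise ValueError(
        "At least one dimension of the tensordict must be named 'time' in offline mode"
    )
print(time_dim - 1)                 # 0: the same offset the diff applies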
