Skip to content

Commit 89efa18

Browse files
author
Vincent Moens
committed
Update
[ghstack-poisoned]
1 parent 8ab42c3 commit 89efa18

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

test/test_rb.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -887,6 +887,20 @@ def test_extend_lazystack(self, storage_type):
887887
assert isinstance(s, LazyStackedTensorDict)
888888
assert len(rb) == 5
889889

890+
def test_extend_empty_lazy(self):
891+
892+
rb = ReplayBuffer(
893+
storage=LazyTensorStorage(6, empty_lazy=True),
894+
batch_size=2,
895+
)
896+
td1 = TensorDict(a=torch.rand(4, 8), batch_size=4)
897+
td2 = TensorDict(a=torch.rand(3, 8), batch_size=3)
898+
ltd = LazyStackedTensorDict(td1, td2, stack_dim=0)
899+
rb.extend(ltd)
900+
s = rb.sample(3)
901+
assert isinstance(s, LazyStackedTensorDict)
902+
assert len(rb) == 2
903+
890904
@pytest.mark.parametrize("device_data", get_default_devices())
891905
@pytest.mark.parametrize("storage_type", [LazyMemmapStorage, LazyTensorStorage])
892906
@pytest.mark.parametrize("data_type", ["tensor", "tc", "td", "pytree"])

0 commit comments

Comments
 (0)