Skip to content

Commit eaf9ef2

Browse files
author
Vincent Moens
committed
Merge remote-tracking branch 'origin/main' into fix-set-truncated
2 parents 71a8d5f + f613eef commit eaf9ef2

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

test/test_rb.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1927,12 +1927,13 @@ def test_sampler_without_rep_state_dict(self, backend):
19271927
s = new_replay_buffer.sample(batch_size=1)
19281928
assert (s.exclude("index") == 0).all()
19291929

1930-
def test_sampler_without_replacement_cap_prefetch(self):
1930+
@pytest.mark.parametrize("drop_last", [False, True])
1931+
def test_sampler_without_replacement_cap_prefetch(self, drop_last):
19311932
torch.manual_seed(0)
1932-
data = TensorDict({"a": torch.arange(10)}, batch_size=[10])
1933+
data = TensorDict({"a": torch.arange(11)}, batch_size=[11])
19331934
rb = ReplayBuffer(
1934-
storage=LazyTensorStorage(10),
1935-
sampler=SamplerWithoutReplacement(),
1935+
storage=LazyTensorStorage(11),
1936+
sampler=SamplerWithoutReplacement(drop_last=drop_last),
19361937
batch_size=2,
19371938
prefetch=3,
19381939
)
@@ -1941,10 +1942,13 @@ def test_sampler_without_replacement_cap_prefetch(self):
19411942
for _ in range(100):
19421943
s = set()
19431944
for i, d in enumerate(rb):
1944-
assert i <= 4
1945+
assert i <= (4 + int(not drop_last)), i
19451946
s = s.union(set(d["a"].tolist()))
1946-
assert i == 4
1947-
assert s == set(range(10))
1947+
assert i == (4 + int(not drop_last)), i
1948+
if drop_last:
1949+
assert s != set(range(11))
1950+
else:
1951+
assert s == set(range(11))
19481952

19491953
@pytest.mark.parametrize(
19501954
"batch_size,num_slices,slice_len,prioritized",

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -637,7 +637,8 @@ def sample(self, batch_size: int | None = None, return_info: bool = False) -> An
637637
while (
638638
len(self._prefetch_queue)
639639
< min(self._sampler._remaining_batches, self._prefetch_cap)
640-
) and not self._sampler.ran_out:
640+
and not self._sampler.ran_out
641+
) or not len(self._prefetch_queue):
641642
fut = self._prefetch_executor.submit(self._sample, batch_size)
642643
self._prefetch_queue.append(fut)
643644
ret = self._prefetch_queue.popleft().result()

0 commit comments

Comments
 (0)