Skip to content

Commit fc2b1ab

Browse files
committed
Merge remote-tracking branch 'origin/main' into openx
# Conflicts: # test/test_rb.py
2 parents e39638b + 781a5b2 commit fc2b1ab

File tree

2 files changed

+12
-2
lines changed

2 files changed

+12
-2
lines changed

test/test_rb.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1595,6 +1595,7 @@ def test_slice_sampler(
15951595
"obs": torch.randn((3, 4, 5)).expand(100, 3, 4, 5),
15961596
"act": torch.randn((20,)).expand(100, 20),
15971597
"steps": steps,
1598+
"count": torch.arange(100),
15981599
"other": torch.randn((20, 50)).expand(100, 20, 50),
15991600
done_key: done,
16001601
},
@@ -1621,7 +1622,8 @@ def test_slice_sampler(
16211622
num_slices = batch_size // slice_len
16221623
trajs_unique_id = set()
16231624
too_short = False
1624-
for _ in range(20):
1625+
count_unique = set()
1626+
for _ in range(10):
16251627
index, info = sampler.sample(storage, batch_size=batch_size)
16261628
if _data_prefix:
16271629
samples = storage._storage["_data"][index]
@@ -1640,6 +1642,14 @@ def test_slice_sampler(
16401642
trajs_unique_id = trajs_unique_id.union(
16411643
samples["another_episode"].view(-1).tolist()
16421644
)
1645+
count_unique = count_unique.union(samples.get("count").view(-1).tolist())
1646+
if len(count_unique) == 100:
1647+
# all items have been sampled
1648+
break
1649+
else:
1650+
raise AssertionError(
1651+
f"Not all items can be sampled: {set(range(100))-count_unique} are missing"
1652+
)
16431653
if strict_length:
16441654
assert not too_short
16451655
else:

torchrl/data/replay_buffers/samplers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -776,7 +776,7 @@ def _sample_slices(
776776
relative_starts = (
777777
(
778778
torch.rand(num_slices, device=lengths.device)
779-
* (lengths[traj_idx] - seq_length)
779+
* (lengths[traj_idx] - seq_length + 1)
780780
)
781781
.floor()
782782
.to(start_idx.dtype)

0 commit comments

Comments
 (0)