
Commit 3c6c801

Author: Vincent Moens
[BugFix] Fix max value within buffer during update priority (#2242)
1 parent 1d729e8 commit 3c6c801

File tree

2 files changed: +27 -6 lines changed


test/test_rb.py

Lines changed: 12 additions & 3 deletions
@@ -2629,12 +2629,21 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
         for data in torch.arange(20):
             idx = rb.add(data)
             rb.update_priority(idx, 21 - data)
-            if data <= 10 or not max_priority_within_buffer:
+            if data <= 10:
+                # The max is always going to be the first value
                 assert rb._sampler._max_priority[0] == 21
                 assert rb._sampler._max_priority[1] == 0
-            else:
-                assert rb._sampler._max_priority[0] == 10
+            elif not max_priority_within_buffer:
+                # The max is the historical max, which was at idx 0
+                assert rb._sampler._max_priority[0] == 21
                 assert rb._sampler._max_priority[1] == 0
+            else:
+                # the max is the current max. Find it and compare
+                sumtree = torch.as_tensor(
+                    [rb._sampler._sum_tree[i] for i in range(rb._sampler._max_capacity)]
+                )
+                assert rb._sampler._max_priority[0] == sumtree.max()
+                assert rb._sampler._max_priority[1] == sumtree.argmax()
         idx = rb.extend(torch.arange(10))
         rb.update_priority(idx, 12)
         if max_priority_within_buffer:
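
For context, the hunk above does not show how the buffer under test is built. A minimal sketch of a comparable setup follows; the storage size, alpha/beta values and the max_priority_within_buffer constructor keyword are assumptions inferred from the test rather than something this diff confirms.

# Sketch only: capacity, alpha, beta and the max_priority_within_buffer keyword
# are assumptions inferred from the test, not shown in this diff.
import torch

from torchrl.data import ListStorage, ReplayBuffer
from torchrl.data.replay_buffers.samplers import PrioritizedSampler

sampler = PrioritizedSampler(
    max_capacity=10,
    alpha=0.6,
    beta=1.0,
    max_priority_within_buffer=True,  # assumed kwarg matching the _max_priority_within_buffer attribute
)
rb = ReplayBuffer(storage=ListStorage(10), sampler=sampler)

for data in torch.arange(20):
    idx = rb.add(data)
    rb.update_priority(idx, 21 - data)

# Recover the largest priority actually stored in the sum tree and compare it
# with the cached maximum, as the new "else" branch of the test does.
sumtree = torch.as_tensor(
    [rb._sampler._sum_tree[i] for i in range(rb._sampler._max_capacity)]
)
print(rb._sampler._max_priority, sumtree.max(), sumtree.argmax())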

torchrl/data/replay_buffers/samplers.py

Lines changed: 15 additions & 3 deletions
@@ -575,12 +575,24 @@ def update_priority(
            priority = priority[valid_index]
 
        max_p, max_p_idx = priority.max(dim=0)
-        max_priority = self._max_priority[0]
-        if max_priority is None or max_p > max_priority:
-            self._max_priority = (max_p, max_p_idx)
+        cur_max_priority, cur_max_priority_index = self._max_priority
+        if cur_max_priority is None or max_p > cur_max_priority:
+            cur_max_priority, cur_max_priority_index = self._max_priority = (
+                max_p,
+                index[max_p_idx] if index.ndim else index,
+            )
        priority = torch.pow(priority + self._eps, self._alpha)
        self._sum_tree[index] = priority
        self._min_tree[index] = priority
+        if (
+            self._max_priority_within_buffer
+            and cur_max_priority_index is not None
+            and (index == cur_max_priority_index).any()
+        ):
+            maxval, maxidx = torch.tensor(
+                [self._sum_tree[i] for i in range(self._max_capacity)]
+            ).max(0)
+            self._max_priority = (maxval, maxidx)
 
    def mark_update(
        self, index: Union[int, torch.Tensor], *, storage: Storage | None = None
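
The block added at the end of this hunk rescans the sum tree whenever one of the updated indices is the slot holding the cached maximum, so that with _max_priority_within_buffer enabled the cached (value, index) pair always points at an entry still present in the buffer. A standalone sketch of that bookkeeping is shown below, with a plain Python list standing in for the sum tree; the helper name and signature are illustrative, not torchrl API.

# Illustrative only: a plain list stands in for the sum tree; the helper name
# and signature are not torchrl API.
import torch


def update_max_on_overwrite(sum_tree, max_priority, updated_index, new_value):
    """Keep (max_value, max_index) consistent after overwriting one slot."""
    max_val, max_idx = max_priority
    sum_tree[updated_index] = new_value
    if max_val is None or new_value > max_val:
        # The incoming priority becomes the new running maximum.
        return (new_value, updated_index)
    if updated_index == max_idx:
        # The slot that held the maximum was overwritten with a smaller value,
        # so the cached maximum may be stale: rescan the whole tree.
        t = torch.tensor(sum_tree)
        return (t.max().item(), t.argmax().item())
    return (max_val, max_idx)


sum_tree = [21.0, 20.0, 19.0, 18.0]
max_priority = (21.0, 0)
# Overwriting index 0 (the current max) with a smaller priority triggers a rescan.
max_priority = update_max_on_overwrite(sum_tree, max_priority, 0, 11.0)
print(max_priority)  # (20.0, 1)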
