Skip to content

Commit 1062e3e

Browse files
author
Vincent Moens
authored
[BugFix] Fix and test PRB priority update across dims and rb types (#2244)
1 parent 3c6c801 commit 1062e3e

File tree

2 files changed

+158
-5
lines changed

2 files changed

+158
-5
lines changed

test/test_rb.py

Lines changed: 147 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -2653,6 +2653,153 @@ def test_prb_update_max_priority(self, max_priority_within_buffer):
26532653
assert rb._sampler._max_priority[0] == 21
26542654
assert rb._sampler._max_priority[1] == 0
26552655

2656+
def test_prb_ndim(self):
2657+
"""This test lists all the possible ways of updating the priority of a PRB with RB, TRB and TPRB.
2658+
2659+
All tests are done for 1d and 2d TDs.
2660+
2661+
"""
2662+
torch.manual_seed(0)
2663+
np.random.seed(0)
2664+
2665+
# first case: 1d, RB
2666+
rb = ReplayBuffer(
2667+
sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
2668+
storage=LazyTensorStorage(100),
2669+
batch_size=4,
2670+
)
2671+
data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])
2672+
idx = rb.extend(data)
2673+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2674+
rb.update_priority(idx, 2)
2675+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2676+
s, info = rb.sample(return_info=True)
2677+
rb.update_priority(info["index"], 3)
2678+
assert (
2679+
torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[info["index"]]
2680+
== 3
2681+
).all()
2682+
2683+
# second case: 1d, TRB
2684+
rb = TensorDictReplayBuffer(
2685+
sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
2686+
storage=LazyTensorStorage(100),
2687+
batch_size=4,
2688+
)
2689+
data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])
2690+
idx = rb.extend(data)
2691+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2692+
rb.update_priority(idx, 2)
2693+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2694+
s = rb.sample()
2695+
rb.update_priority(s["index"], 3)
2696+
assert (
2697+
torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[s["index"]] == 3
2698+
).all()
2699+
2700+
# third case: 1d TPRB
2701+
rb = TensorDictPrioritizedReplayBuffer(
2702+
alpha=1.0,
2703+
beta=1.0,
2704+
storage=LazyTensorStorage(100),
2705+
batch_size=4,
2706+
priority_key="p",
2707+
)
2708+
data = TensorDict({"a": torch.arange(10), "p": torch.ones(10) / 2}, [10])
2709+
idx = rb.extend(data)
2710+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2711+
rb.update_priority(idx, 2)
2712+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2713+
s = rb.sample()
2714+
2715+
s["p"] = torch.ones(4) * 10_000
2716+
rb.update_tensordict_priority(s)
2717+
assert (
2718+
torch.tensor([rb._sampler._sum_tree[i] for i in range(10)])[s["index"]]
2719+
== 10_000
2720+
).all()
2721+
2722+
s2 = rb.sample()
2723+
# All indices in s2 must be from s since we set a very high priority to these items
2724+
assert (s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).any(0).all()
2725+
2726+
# fourth case: 2d RB
2727+
rb = ReplayBuffer(
2728+
sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
2729+
storage=LazyTensorStorage(100, ndim=2),
2730+
batch_size=4,
2731+
)
2732+
data = TensorDict(
2733+
{"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
2734+
)
2735+
idx = rb.extend(data)
2736+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2737+
rb.update_priority(idx, 2)
2738+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2739+
2740+
s, info = rb.sample(return_info=True)
2741+
rb.update_priority(info["index"], 3)
2742+
priorities = torch.tensor(
2743+
[rb._sampler._sum_tree[i] for i in range(10)]
2744+
).reshape((5, 2))
2745+
assert (priorities[info["index"]] == 3).all()
2746+
2747+
# fifth case: 2d TRB
2748+
# 2d
2749+
rb = TensorDictReplayBuffer(
2750+
sampler=PrioritizedSampler(max_capacity=100, alpha=1.0, beta=1.0),
2751+
storage=LazyTensorStorage(100, ndim=2),
2752+
batch_size=4,
2753+
)
2754+
data = TensorDict(
2755+
{"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
2756+
)
2757+
idx = rb.extend(data)
2758+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2759+
rb.update_priority(idx, 2)
2760+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2761+
2762+
s = rb.sample()
2763+
rb.update_priority(s["index"], 10_000)
2764+
priorities = torch.tensor(
2765+
[rb._sampler._sum_tree[i] for i in range(10)]
2766+
).reshape((5, 2))
2767+
assert (priorities[s["index"].unbind(-1)] == 10_000).all()
2768+
2769+
s2 = rb.sample()
2770+
assert (
2771+
(s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all()
2772+
)
2773+
2774+
# Sixth case: 2d TDPRB
2775+
rb = TensorDictPrioritizedReplayBuffer(
2776+
alpha=1.0,
2777+
beta=1.0,
2778+
storage=LazyTensorStorage(100, ndim=2),
2779+
batch_size=4,
2780+
priority_key="p",
2781+
)
2782+
data = TensorDict(
2783+
{"a": torch.arange(5).expand(2, 5), "p": torch.ones(2, 5) / 2}, [2, 5]
2784+
)
2785+
idx = rb.extend(data)
2786+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 1).all()
2787+
rb.update_priority(idx, torch.ones(()) * 2)
2788+
assert (torch.tensor([rb._sampler._sum_tree[i] for i in range(10)]) == 2).all()
2789+
s = rb.sample()
2790+
# setting the priorities to a value that is so big that the buffer will resample them
2791+
s["p"] = torch.ones(4) * 10_000
2792+
rb.update_tensordict_priority(s)
2793+
priorities = torch.tensor(
2794+
[rb._sampler._sum_tree[i] for i in range(10)]
2795+
).reshape((5, 2))
2796+
assert (priorities[s["index"].unbind(-1)] == 10_000).all()
2797+
2798+
s2 = rb.sample()
2799+
assert (
2800+
(s2["index"].unsqueeze(0) == s["index"].unsqueeze(1)).all(-1).any(0).all()
2801+
)
2802+
26562803

26572804
def test_prioritized_slice_sampler_doc_example():
26582805
sampler = PrioritizedSliceSampler(max_capacity=9, num_slices=3, alpha=0.7, beta=0.9)

torchrl/data/replay_buffers/replay_buffers.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -585,9 +585,12 @@ def extend(self, data: Sequence) -> torch.Tensor:
585585

586586
def update_priority(
587587
self,
588-
index: Union[int, torch.Tensor],
588+
index: Union[int, torch.Tensor, Tuple[torch.Tensor]],
589589
priority: Union[int, torch.Tensor],
590590
) -> None:
591+
if isinstance(index, tuple):
592+
index = torch.stack(index, -1)
593+
priority = torch.as_tensor(priority)
591594
if self.dim_extend > 0 and priority.ndim > 1:
592595
priority = self._transpose(priority).flatten()
593596
# priority = priority.flatten()
@@ -1095,7 +1098,7 @@ def _get_priority_vector(self, tensordict: TensorDictBase) -> torch.Tensor:
10951098
dtype=torch.float,
10961099
device=tensordict.device,
10971100
).expand(tensordict.shape[0])
1098-
if self._storage.ndim > 1:
1101+
if self._storage.ndim > 1 and priority.ndim >= self._storage.ndim:
10991102
# We have to flatten the priority otherwise we'll be aggregating
11001103
# the priority across batches
11011104
priority = priority.flatten(0, self._storage.ndim - 1)
@@ -1172,9 +1175,12 @@ def update_tensordict_priority(self, data: TensorDictBase) -> None:
11721175
else:
11731176
priority = torch.as_tensor(self._get_priority_item(data))
11741177
index = data.get("index")
1175-
while index.shape != priority.shape:
1176-
# reduce index
1177-
index = index[..., 0]
1178+
if self._storage.ndim > 1 and index.ndim == 2:
1179+
index = index.unbind(-1)
1180+
else:
1181+
while index.shape != priority.shape:
1182+
# reduce index
1183+
index = index[..., 0]
11781184
return self.update_priority(index, priority)
11791185

11801186
def sample(

0 commit comments

Comments
 (0)