
Commit 5851652

LTluttmann and vmoens authored
[Feature] Add scheduler for alpha/beta parameters of PrioritizedSampler (#2452)
Co-authored-by: Vincent Moens <vincentmoens@gmail.com>
1 parent 6d1a1b3 commit 5851652
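
A minimal usage sketch of the new schedulers, assembled from the docstrings and the new test below; the buffer setup, import paths and parameter values are illustrative, not part of this diff:

    import torch
    from tensordict import TensorDict
    from torchrl.data import TensorDictPrioritizedReplayBuffer
    from torchrl.data.replay_buffers.storages import ListStorage
    from torchrl.data.replay_buffers.scheduler import (
        LinearScheduler,
        SchedulerList,
        StepScheduler,
    )

    # Prioritized buffer with initial alpha/beta (values chosen for illustration).
    rb = TensorDictPrioritizedReplayBuffer(
        alpha=0.6, beta=0.4, storage=ListStorage(max_size=1000)
    )
    rb.extend(TensorDict({"data": torch.randn(1000, 5)}, batch_size=[1000]))

    # Anneal alpha linearly to 0.0 over 100 calls to step();
    # bump beta by 0.1 every 10 calls, capped at 1.0.
    alpha_scheduler = LinearScheduler(rb, param_name="alpha", final_value=0.0, num_steps=100)
    beta_scheduler = StepScheduler(
        rb, param_name="beta", gamma=0.1, n_steps=10, mode="additive", max_value=1.0
    )
    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))

    for _ in range(200):
        batch = rb.sample(20)  # consume a prioritized batch in the training loop
        scheduler.step()       # advance both schedules once per iteration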

File tree

3 files changed: +360 -0


test/test_rb.py

Lines changed: 77 additions & 0 deletions
@@ -59,6 +59,11 @@
     SliceSampler,
     SliceSamplerWithoutReplacement,
 )
+from torchrl.data.replay_buffers.scheduler import (
+    LinearScheduler,
+    SchedulerList,
+    StepScheduler,
+)
 
 from torchrl.data.replay_buffers.storages import (
     LazyMemmapStorage,
@@ -100,6 +105,7 @@
     VecNorm,
 )
 
+
 OLD_TORCH = parse(torch.__version__) < parse("2.0.0")
 _has_tv = importlib.util.find_spec("torchvision") is not None
 _has_gym = importlib.util.find_spec("gym") is not None
@@ -3041,6 +3047,77 @@ def test_prioritized_slice_sampler_episodes(device):
     ), "after priority update, only episode 1 and 3 are expected to be sampled"
 
 
+@pytest.mark.parametrize("alpha", [0.6, torch.tensor(1.0)])
+@pytest.mark.parametrize("beta", [0.7, torch.tensor(0.1)])
+@pytest.mark.parametrize("gamma", [0.1])
+@pytest.mark.parametrize("total_steps", [200])
+@pytest.mark.parametrize("n_annealing_steps", [100])
+@pytest.mark.parametrize("anneal_every_n", [10, 159])
+@pytest.mark.parametrize("alpha_min", [0, 0.2])
+@pytest.mark.parametrize("beta_max", [1, 1.4])
+def test_prioritized_parameter_scheduler(
+    alpha,
+    beta,
+    gamma,
+    total_steps,
+    n_annealing_steps,
+    anneal_every_n,
+    alpha_min,
+    beta_max,
+):
+    rb = TensorDictPrioritizedReplayBuffer(
+        alpha=alpha, beta=beta, storage=ListStorage(max_size=1000)
+    )
+    data = TensorDict({"data": torch.randn(1000, 5)}, batch_size=1000)
+    rb.extend(data)
+    alpha_scheduler = LinearScheduler(
+        rb, param_name="alpha", final_value=alpha_min, num_steps=n_annealing_steps
+    )
+    beta_scheduler = StepScheduler(
+        rb,
+        param_name="beta",
+        gamma=gamma,
+        n_steps=anneal_every_n,
+        max_value=beta_max,
+        mode="additive",
+    )
+
+    scheduler = SchedulerList(schedulers=(alpha_scheduler, beta_scheduler))
+
+    alpha = alpha if torch.is_tensor(alpha) else torch.tensor(alpha)
+    alpha_min = torch.tensor(alpha_min)
+    expected_alpha_vals = torch.linspace(alpha, alpha_min, n_annealing_steps + 1)
+    expected_alpha_vals = torch.nn.functional.pad(
+        expected_alpha_vals, (0, total_steps - n_annealing_steps), value=alpha_min
+    )
+
+    expected_beta_vals = [beta]
+    annealing_steps = total_steps // anneal_every_n
+    gammas = torch.arange(0, annealing_steps + 1, dtype=torch.float32) * gamma
+    expected_beta_vals = (
+        (beta + gammas).repeat_interleave(anneal_every_n).clip(None, beta_max)
+    )
+    for i in range(total_steps):
+        curr_alpha = rb.sampler.alpha
+        torch.testing.assert_close(
+            curr_alpha
+            if torch.is_tensor(curr_alpha)
+            else torch.tensor(curr_alpha).float(),
+            expected_alpha_vals[i],
+            msg=f"expected {expected_alpha_vals[i]}, got {curr_alpha}",
+        )
+        curr_beta = rb.sampler.beta
+        torch.testing.assert_close(
+            curr_beta
+            if torch.is_tensor(curr_beta)
+            else torch.tensor(curr_beta).float(),
+            expected_beta_vals[i],
+            msg=f"expected {expected_beta_vals[i]}, got {curr_beta}",
+        )
+        rb.sample(20)
+        scheduler.step()
+
+
 class TestEnsemble:
     def _make_data(self, data_type):
         if data_type is torch.Tensor:
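
For one concrete parametrization of this test (alpha=0.6, alpha_min=0.2, n_annealing_steps=100; beta=0.7, gamma=0.1, anneal_every_n=10, beta_max=1.0), the expected trajectories work out as follows: the LinearScheduler lowers alpha by (0.2 - 0.6) / 100 = 0.004 per step until it reaches 0.2 at step 100 and holds it there for the remaining steps, while the StepScheduler adds 0.1 to beta every 10 steps (0.7 -> 0.8 -> 0.9 -> 1.0 at step 30) and then the 1.0 cap keeps it constant.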

torchrl/data/replay_buffers/samplers.py

Lines changed: 16 additions & 0 deletions
@@ -395,6 +395,22 @@ def __repr__(self):
     def max_size(self):
         return self._max_capacity
 
+    @property
+    def alpha(self):
+        return self._alpha
+
+    @alpha.setter
+    def alpha(self, value):
+        self._alpha = value
+
+    @property
+    def beta(self):
+        return self._beta
+
+    @beta.setter
+    def beta(self, value):
+        self._beta = value
+
     def __getstate__(self):
         if get_spawning_popen() is not None:
             raise RuntimeError(
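
These properties are what the schedulers rely on: ParameterScheduler.step() writes the clipped value back with setattr(sampler, param_name, value), so PrioritizedSampler needs public alpha/beta accessors that route to the private _alpha/_beta fields. A minimal sketch of that interaction, reusing the illustrative buffer rb from the sketch near the top of this page:

    sampler = rb.sampler             # PrioritizedSampler owned by the buffer
    setattr(sampler, "alpha", 0.5)   # goes through the new alpha.setter, i.e. sampler._alpha = 0.5
    assert sampler.alpha == 0.5      # read back through the new property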
torchrl/data/replay_buffers/scheduler.py

Lines changed: 267 additions & 0 deletions

@@ -0,0 +1,267 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

from typing import Any, Callable, Dict

import numpy as np

import torch

from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler


class ParameterScheduler(ABC):
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

    A scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        if not isinstance(obj, (ReplayBuffer, Sampler)):
            raise TypeError(
                f"ParameterScheduler only supports Sampler class. Pass either `ReplayBuffer` or `Sampler` object. Got {type(obj)} instead."
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        self._min_val = min_value or float("-inf")
        self._max_val = max_value or float("inf")
        if not hasattr(self.sampler, self.param_name):
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
            )
        initial_val = getattr(self.sampler, self.param_name)
        if isinstance(initial_val, torch.Tensor):
            initial_val = initial_val.clone()
            self.backend = torch
        else:
            self.backend = np
        self.initial_val = initial_val
        self._step_cnt = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the sampler.
        """
        sd = dict(self.__dict__)
        del sd["sampler"]
        return sd

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
        self._step_cnt += 1
        # Apply the step function
        new_value = self._step()
        # clip value to specified range
        new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
        # Set the new value of the parameter dynamically
        setattr(self.sampler, self.param_name, new_value_clipped)

    @abstractmethod
    def _step(self):
        ...


class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function.

    Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
            parameter ``step_count``.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        lambda_fn: Callable[[int], float],
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.lambda_fn = lambda_fn

    def _step(self):
        return self.initial_val * self.lambda_fn(self._step_cnt)


class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter and a final target value.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        final_value (number): The final value that the parameter will reach after the
            specified number of steps.
        num_steps (number, optional): The total number of steps over which the parameter
            will be linearly altered.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        if isinstance(self.initial_val, torch.Tensor):
            # cast to same type as initial value
            final_value = torch.tensor(final_value).to(self.initial_val)
        self.final_val = final_value
        self.num_steps = num_steps
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
        # without graph breaks
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val


class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: `new_val = curr_val * gamma`
    2. Additive changes: `new_val = curr_val + gamma`

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        gamma (int or float, optional): The value by which to adjust the parameter,
            either in a multiplicative or additive way.
        n_steps (int, optional): The number of steps after which the parameter should be altered.
            Defaults to 1.
        mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
            Defaults to `'multiplicative'`.
        min_value (int or float, optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (int or float, optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if 0 <= step < 10
        >>> # beta = 0.7 if 10 <= step < 20
        >>> # beta = 0.8 if 20 <= step < 30
        >>> # beta = 0.9 if 30 <= step < 40
        >>> # beta = 1.0 if 40 <= step
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):

        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        self.mode = mode
        if mode == "additive":
            operator = self.backend.add
        elif mode == "multiplicative":
            operator = self.backend.multiply
        else:
            raise ValueError(
                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
            )
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        # Check if the current step count is a multiple of n_steps
        current_val = getattr(self.sampler, self.param_name)
        # Nit: we should use torch.where instead of if/else here to make the scheduler compatible with compile
        # without graph breaks
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val


class SchedulerList:
    """Simple container abstracting a list of schedulers."""

    def __init__(self, schedulers: list[ParameterScheduler]) -> None:
        if isinstance(schedulers, ParameterScheduler):
            schedulers = [schedulers]
        self.schedulers = schedulers

    def append(self, scheduler: ParameterScheduler):
        self.schedulers.append(scheduler)

    def step(self):
        for scheduler in self.schedulers:
            scheduler.step()
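
LambdaScheduler is the one scheduler above without a docstring example. A plausible usage sketch, reusing the illustrative buffer rb and imports from the sketch near the top of this page, with an exponential decay chosen purely for illustration:

    # Scale the initial alpha by 0.95 ** step_count on every call to step(),
    # never letting it drop below 0.1.
    alpha_decay = LambdaScheduler(
        rb,
        param_name="alpha",
        lambda_fn=lambda step_count: 0.95 ** step_count,
        min_value=0.1,
    )
    schedulers = SchedulerList(schedulers=[alpha_decay])
    for _ in range(50):
        rb.sample(20)
        schedulers.step()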
