# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
from __future__ import annotations

from abc import ABC, abstractmethod

from typing import Any, Callable, Dict

import numpy as np

import torch

from torchrl.data.replay_buffers.replay_buffers import ReplayBuffer
from torchrl.data.replay_buffers.samplers import Sampler


class ParameterScheduler(ABC):
    """Scheduler to adjust the value of a given parameter of a replay buffer's sampler.

    The scheduler can, for example, be used to alter the alpha and beta values in the PrioritizedSampler.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust, or the sampler itself.
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the beta parameter.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

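    Example:
        >>> # xdoctest: +SKIP
        >>> # A minimal sketch of a custom scheduler, assuming ``sampler`` exposes a
        >>> # ``beta`` attribute (e.g. a PrioritizedSampler): decay beta by 1% per
        >>> # call to ``step()``, but never below 0.4.
        >>> class DecayScheduler(ParameterScheduler):
        >>>     def _step(self):
        >>>         return self.initial_val * 0.99**self._step_cnt
        >>> scheduler = DecayScheduler(sampler, param_name='beta', min_value=0.4)
        >>> scheduler.step()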
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        if not isinstance(obj, (ReplayBuffer, Sampler)):
            raise TypeError(
                f"ParameterScheduler only supports ReplayBuffer or Sampler instances. Got {type(obj)} instead."
            )
        self.sampler = obj.sampler if isinstance(obj, ReplayBuffer) else obj
        self.param_name = param_name
        # check for None explicitly so that a bound of 0 is not discarded
        self._min_val = min_value if min_value is not None else float("-inf")
        self._max_val = max_value if max_value is not None else float("inf")
        if not hasattr(self.sampler, self.param_name):
            raise ValueError(
                f"Provided class {type(obj).__name__} does not have an attribute {param_name}"
            )
        initial_val = getattr(self.sampler, self.param_name)
        if isinstance(initial_val, torch.Tensor):
            initial_val = initial_val.clone()
            self.backend = torch
        else:
            self.backend = np
        self.initial_val = initial_val
        self._step_cnt = 0

    def state_dict(self):
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in ``self.__dict__`` which
        is not the sampler.
        """
        sd = dict(self.__dict__)
        del sd["sampler"]
        return sd

    def load_state_dict(self, state_dict: Dict[str, Any]):
        """Load the scheduler's state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def step(self):
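        """Advances the step counter, computes the new value via ``_step``, clips it to ``[min_value, max_value]`` and writes it back to the sampler."""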
        self._step_cnt += 1
        # Apply the step function
        new_value = self._step()
        # clip value to specified range
        new_value_clipped = self.backend.clip(new_value, self._min_val, self._max_val)
        # Set the new value of the parameter dynamically
        setattr(self.sampler, self.param_name, new_value_clipped)

    @abstractmethod
    def _step(self):
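        """Computes and returns the new, unclipped value of the parameter; must be implemented by subclasses."""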
        ...


class LambdaScheduler(ParameterScheduler):
    """Sets a parameter to its initial value times a given function.

    Similar to :class:`~torch.optim.lr_scheduler.LambdaLR`.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        lambda_fn (Callable[[int], float]): A function which computes a multiplicative factor given an integer
            parameter ``step_count``.
        min_value (Union[int, float], optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (Union[int, float], optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

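    Example:
        >>> # xdoctest: +SKIP
        >>> # A sketch, assuming ``sampler`` exposes a ``beta`` attribute (e.g. a PrioritizedSampler):
        >>> # scale the initial beta by 0.95 per step, but keep it above 0.4.
        >>> scheduler = LambdaScheduler(sampler, param_name='beta', lambda_fn=lambda step: 0.95**step, min_value=0.4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()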
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        lambda_fn: Callable[[int], float],
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.lambda_fn = lambda_fn

    def _step(self):
        return self.initial_val * self.lambda_fn(self._step_cnt)


class LinearScheduler(ParameterScheduler):
    """A linear scheduler for gradually altering a parameter in an object over a given number of steps.

    This scheduler linearly interpolates between the initial value of the parameter and a final target value.

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        final_value (int or float): The final value that the parameter will reach after the
            specified number of steps.
        num_steps (int): The total number of steps over which the parameter
            will be linearly altered.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.7 if step == 1
        >>> # beta = 0.8 if step == 2
        >>> # beta = 0.9 if step == 3
        >>> # beta = 1.0 if step >= 4
        >>> scheduler = LinearScheduler(sampler, param_name='beta', final_value=1.0, num_steps=4)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        final_value: int | float,
        num_steps: int,
    ):
        super().__init__(obj, param_name)
        if isinstance(self.initial_val, torch.Tensor):
            # cast to the same dtype and device as the initial value
            final_value = torch.tensor(final_value).to(self.initial_val)
        self.final_val = final_value
        self.num_steps = num_steps
        self._delta = (self.final_val - self.initial_val) / self.num_steps

    def _step(self):
        # Note: using torch.where instead of if/else here would make the scheduler
        # compatible with compile without graph breaks
        if self._step_cnt < self.num_steps:
            return self.initial_val + (self._delta * self._step_cnt)
        else:
            return self.final_val


class StepScheduler(ParameterScheduler):
    """A step scheduler that alters a parameter after every n steps using either multiplicative or additive changes.

    The scheduler can apply:
    1. Multiplicative changes: `new_val = curr_val * gamma`
    2. Additive changes: `new_val = curr_val + gamma`

    Args:
        obj (ReplayBuffer or Sampler): the replay buffer whose sampler to adjust (or the sampler itself).
        param_name (str): the name of the attribute to adjust, e.g. `beta` to adjust the
            beta parameter.
        gamma (int or float, optional): The value by which to adjust the parameter,
            either in a multiplicative or additive way. Defaults to 0.9.
        n_steps (int, optional): The number of steps after which the parameter should be altered.
            Defaults to 1.
        mode (str, optional): The mode of scheduling. Can be either `'multiplicative'` or `'additive'`.
            Defaults to `'multiplicative'`.
        min_value (int or float, optional): a lower bound for the parameter to be adjusted.
            Defaults to `None`.
        max_value (int or float, optional): an upper bound for the parameter to be adjusted.
            Defaults to `None`.

    Example:
        >>> # xdoctest: +SKIP
        >>> # Assuming sampler uses initial beta = 0.6
        >>> # beta = 0.6 if 0 <= step < 10
        >>> # beta = 0.7 if 10 <= step < 20
        >>> # beta = 0.8 if 20 <= step < 30
        >>> # beta = 0.9 if 30 <= step < 40
        >>> # beta = 1.0 if 40 <= step
        >>> scheduler = StepScheduler(sampler, param_name='beta', gamma=0.1, n_steps=10, mode='additive', max_value=1.0)
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     validate(...)
        >>>     scheduler.step()
    """

    def __init__(
        self,
        obj: ReplayBuffer | Sampler,
        param_name: str,
        gamma: int | float = 0.9,
        n_steps: int = 1,
        mode: str = "multiplicative",
        min_value: int | float | None = None,
        max_value: int | float | None = None,
    ):
        super().__init__(obj, param_name, min_value, max_value)
        self.gamma = gamma
        self.n_steps = n_steps
        self.mode = mode
        if mode == "additive":
            operator = self.backend.add
        elif mode == "multiplicative":
            operator = self.backend.multiply
        else:
            raise ValueError(
                f"Invalid mode: {mode}. Choose 'multiplicative' or 'additive'."
            )
        self.operator = operator

    def _step(self):
        """Applies the scheduling logic to alter the parameter value every `n_steps`."""
        current_val = getattr(self.sampler, self.param_name)
        # Note: using torch.where instead of if/else here would make the scheduler
        # compatible with compile without graph breaks
        # only alter the parameter when the step count is a multiple of n_steps
        if self._step_cnt % self.n_steps == 0:
            return self.operator(current_val, self.gamma)
        else:
            return current_val


class SchedulerList:
    """Simple container abstracting a list of schedulers.
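
    Calling :meth:`step` steps every scheduler in the list.

    Example:
        >>> # xdoctest: +SKIP
        >>> # A sketch: step several schedulers together after each epoch
        >>> # (``alpha_scheduler`` and ``beta_scheduler`` are assumed to be ParameterScheduler instances).
        >>> schedulers = SchedulerList([alpha_scheduler, beta_scheduler])
        >>> for epoch in range(100):
        >>>     train(...)
        >>>     schedulers.step()
    """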

    def __init__(self, schedulers: ParameterScheduler | list[ParameterScheduler]) -> None:
        if isinstance(schedulers, ParameterScheduler):
            schedulers = [schedulers]
        self.schedulers = schedulers

    def append(self, scheduler: ParameterScheduler):
        self.schedulers.append(scheduler)

    def step(self):
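        """Steps every scheduler in the list."""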
        for scheduler in self.schedulers:
            scheduler.step()