Commit
Replace memory batch size with a fraction of the total batch size (#359)
wistuba authored Aug 17, 2023
1 parent 593cff3 commit 73df5c2
Showing 30 changed files with 120 additions and 106 deletions.
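
The gist of the change: configurations no longer pass an absolute memory_batch_size alongside batch_size; instead they pass the total batch_size plus a batch_memory_frac, and the learner derives the memory portion from those two values. A minimal sketch of the arithmetic (the numbers are illustrative; the derivation mirrors the ReplayLearner change further down in this diff):

# Old interface: two absolute sizes; the effective batch was their sum.
old_batch_size = 32
old_memory_batch_size = 32
old_effective_batch = old_batch_size + old_memory_batch_size  # 64

# New interface: one total batch size and the fraction reserved for memory samples.
batch_size = 64
batch_memory_frac = 0.5
memory_batch_size = int(batch_memory_frac * batch_size)  # 32 samples from the memory buffer
new_data_batch_size = batch_size - memory_batch_size     # 32 samples from the current dataset

This is why every updated config doubles batch_size while dropping memory_batch_size: the total number of samples per step stays the same.
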
@@ -1,6 +1,6 @@
{
"updater": "Offline-ER",
"batch_size": 128,
"memory_batch_size": 128,
"batch_size": 256,
"batch_memory_frac": 0.5,
"memory_size": 3300
}
@@ -1,6 +1,6 @@
{
"updater": "Offline-ER",
"batch_size": 128,
"memory_batch_size": 128,
"batch_size": 256,
"batch_memory_frac": 0.5,
"memory_size": 10000
}
@@ -1,6 +1,6 @@
{
"updater": "Offline-ER",
"batch_size": 32,
"memory_batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"memory_size": 3450
}
@@ -11,8 +11,8 @@
"learning_rate": 0.03,
"alpha": 0.2,
"beta": 0.5,
"batch_size": 32,
"memory_batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"memory_size": 500,
"max_epochs": 50,
"loss_normalization": 0,
4 changes: 2 additions & 2 deletions examples/nlp_finetuning/start.py
@@ -12,8 +12,8 @@
"weight_decay": 0.0,
"learning_rate": 0.001,
"alpha": 0.5,
"batch_size": 32,
"memory_batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"memory_size": 300,
"loss_normalization": 0,
"loss_weight": 0.5,
2 changes: 1 addition & 1 deletion examples/simple_classifier_cifar10/start_with_hpo.py
@@ -13,7 +13,7 @@
"learning_rate": loguniform(1e-4, 1e-1),
"alpha": uniform(0.0, 1.0),
"batch_size": choice([32, 64, 128, 256]),
"memory_batch_size": 32,
"batch_memory_frac": 0.5,
"memory_size": 1000,
"loss_normalization": 0,
"loss_weight": uniform(0.0, 1.0),
4 changes: 2 additions & 2 deletions examples/simple_classifier_cifar10/start_without_hpo.py
@@ -11,8 +11,8 @@
"weight_decay": 0.0,
"learning_rate": 0.1,
"alpha": 0.5,
"batch_size": 32,
"memory_batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"memory_size": 300,
"loss_normalization": 0,
"loss_weight": 0.5,
@@ -7,9 +7,9 @@
"momentum": 0.0,
"weight_decay": 1e-2,
"learning_rate": 0.05,
"batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"max_epochs": 50,
"memory_batch_size": 32,
"memory_size": 500,
}

4 changes: 2 additions & 2 deletions examples/train_mlp_locally/start_training_without_hpo.py
@@ -9,8 +9,8 @@
"weight_decay": 0.0,
"learning_rate": 0.1,
"alpha": 0.5,
"batch_size": 32,
"memory_batch_size": 32,
"batch_size": 64,
"batch_memory_frac": 0.5,
"memory_size": 500,
"loss_normalization": 0,
"loss_weight": 0.5,
18 changes: 9 additions & 9 deletions src/renate/cli/parsing_functions.py
@@ -53,7 +53,7 @@ def get_updater_and_learner_kwargs(
"loss_weight",
"ema_memory_update_gamma",
"memory_size",
"memory_batch_size",
"batch_memory_frac",
"loss_normalization",
]
updater_class = None
@@ -93,7 +93,7 @@ def get_updater_and_learner_kwargs(
]
updater_class = SuperExperienceReplayModelUpdater
elif args.updater == "Offline-ER":
- learner_args = learner_args + ["loss_weight_new_data", "memory_size", "memory_batch_size"]
+ learner_args = learner_args + ["loss_weight_new_data", "memory_size", "batch_memory_frac"]
updater_class = OfflineExperienceReplayModelUpdater
elif args.updater == "RD":
learner_args = learner_args + ["memory_size"]
@@ -108,7 +108,7 @@
learner_args = learner_args
updater_class = FineTuningModelUpdater
elif args.updater == "Avalanche-ER":
- learner_args = learner_args + ["memory_size", "memory_batch_size"]
+ learner_args = learner_args + ["memory_size", "batch_memory_frac"]
from renate.updaters.avalanche.model_updater import ExperienceReplayAvalancheModelUpdater

updater_class = ExperienceReplayAvalancheModelUpdater
@@ -123,7 +123,7 @@

updater_class = LearningWithoutForgettingModelUpdater
elif args.updater == "Avalanche-iCaRL":
- learner_args = learner_args + ["memory_size", "memory_batch_size"]
+ learner_args = learner_args + ["memory_size", "batch_memory_frac"]
from renate.updaters.avalanche.model_updater import ICaRLModelUpdater

updater_class = ICaRLModelUpdater
@@ -428,11 +428,11 @@ def _add_replay_learner_arguments(arguments: Dict[str, Dict[str, Any]]) -> None:
"help": "Memory size available for the memory buffer. Default: "
f"{defaults.MEMORY_SIZE}.",
},
"memory_batch_size": {
"type": int,
"default": defaults.BATCH_SIZE,
"help": "Batch size used during model update for the memory buffer. Default: "
f"{defaults.BATCH_SIZE}.",
"batch_memory_frac": {
"type": float,
"default": defaults.BATCH_MEMORY_FRAC,
"help": "Fraction of the batch populated with memory data. Default: "
f"{defaults.BATCH_MEMORY_FRAC}.",
},
}
)
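
For orientation, the dictionary entry added above is how Renate declares CLI arguments; a rough plain-argparse equivalent (an illustration only, not the code path Renate actually uses) would be:

# Hypothetical stand-alone equivalent of the "batch_memory_frac" argument entry above.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--batch_memory_frac",
    type=float,
    default=0.5,  # defaults.BATCH_MEMORY_FRAC
    help="Fraction of the batch populated with memory data. Default: 0.5.",
)
args = parser.parse_args(["--batch_memory_frac", "0.3"])
print(args.batch_memory_frac)  # 0.3
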
1 change: 1 addition & 0 deletions src/renate/defaults.py
@@ -18,6 +18,7 @@
WEIGHT_DECAY = 0.0
MAX_EPOCHS = 50
BATCH_SIZE = 32
+ BATCH_MEMORY_FRAC = 0.5
LOSS_WEIGHT = 1.0
SEED = 0
EMA_MEMORY_UPDATE_GAMMA = 1.0
4 changes: 2 additions & 2 deletions src/renate/updaters/avalanche/learner.py
@@ -38,7 +38,7 @@ def update_settings(
avalanche_learner._criterion = self._loss_fn
avalanche_learner.train_epochs = max_epochs
avalanche_learner.train_mb_size = self._batch_size
- avalanche_learner.eval_mb_size = self._batch_size
+ avalanche_learner.eval_mb_size = self._batch_size + getattr(self, "_memory_batch_size", 0)
avalanche_learner.device = device
avalanche_learner.eval_every = eval_every

@@ -57,7 +57,7 @@ def _create_avalanche_learner(
optimizer=optimizer,
criterion=self._loss_fn,
train_mb_size=self._batch_size,
- eval_mb_size=self._batch_size,
+ eval_mb_size=self._batch_size + getattr(self, "_memory_batch_size", 0),
train_epochs=train_epochs,
plugins=plugins,
evaluator=default_evaluator(),
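
Because self._batch_size now counts only the new-data portion of a batch (see the ReplayLearner change below), the Avalanche evaluation mini-batch is restored to the full size by adding the memory portion back; the getattr default of 0 covers learners without a memory buffer. A worked example with the defaults used throughout this commit (illustrative numbers, not taken from a specific config):

batch_size = 64          # total batch size requested by the user
batch_memory_frac = 0.5  # fraction of the batch drawn from memory

memory_batch_size = int(batch_memory_frac * batch_size)  # 32, stored as _memory_batch_size
new_data_batch_size = batch_size - memory_batch_size     # 32, stored as _batch_size

train_mb_size = new_data_batch_size                      # 32 new samples per training step
eval_mb_size = new_data_batch_size + memory_batch_size   # 64, the full batch at evaluation time
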
4 changes: 2 additions & 2 deletions src/renate/updaters/avalanche/model_updater.py
@@ -259,7 +259,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None,
learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501
batch_size: int = defaults.BATCH_SIZE,
@@ -289,7 +289,7 @@ def __init__(
learner_kwargs = {
"batch_size": batch_size,
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"seed": seed,
}
super().__init__(
20 changes: 10 additions & 10 deletions src/renate/updaters/experimental/er.py
@@ -532,7 +532,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight: float = defaults.LOSS_WEIGHT,
ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
loss_normalization: int = defaults.LOSS_NORMALIZATION,
@@ -566,7 +566,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight": loss_weight,
"ema_memory_update_gamma": ema_memory_update_gamma,
"loss_normalization": loss_normalization,
@@ -614,7 +614,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight: float = defaults.LOSS_WEIGHT,
ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
loss_normalization: int = defaults.LOSS_NORMALIZATION,
@@ -649,7 +649,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight": loss_weight,
"ema_memory_update_gamma": ema_memory_update_gamma,
"loss_normalization": loss_normalization,
@@ -698,7 +698,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight: float = defaults.LOSS_WEIGHT,
ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
loss_normalization: int = defaults.LOSS_NORMALIZATION,
@@ -734,7 +734,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight": loss_weight,
"ema_memory_update_gamma": ema_memory_update_gamma,
"loss_normalization": loss_normalization,
@@ -784,7 +784,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight: float = defaults.LOSS_WEIGHT,
ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
loss_normalization: int = defaults.LOSS_NORMALIZATION,
@@ -823,7 +823,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight": loss_weight,
"ema_memory_update_gamma": ema_memory_update_gamma,
"loss_normalization": loss_normalization,
@@ -876,7 +876,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight: float = defaults.LOSS_WEIGHT,
ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA,
loss_normalization: int = defaults.LOSS_NORMALIZATION,
@@ -921,7 +921,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight": loss_weight,
"ema_memory_update_gamma": ema_memory_update_gamma,
"loss_normalization": loss_normalization,
4 changes: 2 additions & 2 deletions src/renate/updaters/experimental/gdumb.py
@@ -106,7 +106,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
learning_rate_scheduler: Optional[partial] = None,
learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501
batch_size: int = defaults.BATCH_SIZE,
@@ -136,7 +136,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"batch_size": batch_size,
"seed": seed,
}
11 changes: 2 additions & 9 deletions src/renate/updaters/experimental/offline_er.py
@@ -29,18 +29,11 @@ class OfflineExperienceReplayLearner(ReplayLearner):
terminated.
Args:
memory_size: The maximum size of the memory.
- memory_batch_size: Size of batches sampled from the memory. The memory batch will be
- appended to the batch sampled from the current dataset, leading to an effective batch
- size of `memory_batch_size + batch_size`.
loss_weight_new_data: The training loss will be a convex combination of the loss on the new
data and the loss on the memory data. If a float (needs to be in [0, 1]) is given here,
it will be used as the weight for the new data. If `None`, the weight will be set
dynamically to `N_t / sum([N_1, ..., N_t])`, where `N_i` denotes the size of task/chunk
`i` and the current task is `t`.
buffer_transform: The transformation to be applied to the memory buffer data samples.
buffer_target_transform: The target transformation to be applied to the memory buffer target
samples.
"""

def __init__(self, loss_weight_new_data: Optional[float] = None, **kwargs) -> None:
@@ -147,7 +140,7 @@ def __init__(
loss_fn: torch.nn.Module,
optimizer: Callable[[List[Parameter]], Optimizer],
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
loss_weight_new_data: Optional[float] = None,
learning_rate_scheduler: Optional[partial] = None,
learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501
@@ -178,7 +171,7 @@ def __init__(
):
learner_kwargs = {
"memory_size": memory_size,
"memory_batch_size": memory_batch_size,
"batch_memory_frac": batch_memory_frac,
"loss_weight_new_data": loss_weight_new_data,
"batch_size": batch_size,
"seed": seed,
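
The docstring above describes the Offline-ER loss as a convex combination of the loss on new data and the loss on memory data, with a weight that is either fixed or derived from task sizes. A small sketch of that weighting scheme (variable names are illustrative and not taken from the implementation):

from typing import List, Optional

def combined_loss(
    loss_new: float,
    loss_memory: float,
    task_sizes: List[int],  # N_1, ..., N_t of the tasks seen so far, current task last
    loss_weight_new_data: Optional[float] = None,
) -> float:
    """Convex combination of new-data and memory loss as described in the docstring."""
    if loss_weight_new_data is None:
        # Dynamic weight: N_t / sum([N_1, ..., N_t]).
        weight = task_sizes[-1] / sum(task_sizes)
    else:
        assert 0.0 <= loss_weight_new_data <= 1.0
        weight = loss_weight_new_data
    return weight * loss_new + (1.0 - weight) * loss_memory

# Three tasks of sizes 1000, 1000 and 2000 give a dynamic weight of 0.5 for the current task.
print(combined_loss(0.8, 0.4, [1000, 1000, 2000]))  # 0.5 * 0.8 + 0.5 * 0.4 = 0.6
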
17 changes: 11 additions & 6 deletions src/renate/updaters/learner.py
@@ -456,9 +456,7 @@ class ReplayLearner(Learner, abc.ABC):
Args:
memory_size: The maximum size of the memory.
- memory_batch_size: Size of batches sampled from the memory. The memory batch will be
- appended to the batch sampled from the current dataset, leading to an effective batch
- size of `memory_batch_size + batch_size`.
+ batch_memory_frac: Fraction of the batch that is sampled from rehearsal memory.
buffer_transform: The transformation to be applied to the memory buffer data samples.
buffer_target_transform: The target transformation to be applied to the memory buffer target
samples.
@@ -468,14 +466,21 @@
def __init__(
self,
memory_size: int,
- memory_batch_size: int = defaults.BATCH_SIZE,
+ batch_size: int = defaults.BATCH_SIZE,
+ batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC,
buffer_transform: Optional[Callable] = None,
buffer_target_transform: Optional[Callable] = None,
seed: int = defaults.SEED,
**kwargs,
) -> None:
- super().__init__(seed=seed, **kwargs)
- self._memory_batch_size = min(memory_size, memory_batch_size)
+ if not (0 <= batch_memory_frac <= 1):
+ raise ValueError(
+ f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}."
+ )
+ memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size))
+ batch_size = batch_size - memory_batch_size
+ super().__init__(batch_size=batch_size, seed=seed, **kwargs)
+ self._memory_batch_size = memory_batch_size
self._memory_buffer = ReservoirBuffer(
max_size=memory_size,
seed=seed,
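
A worked example of the constructor logic above, showing how the total batch is split and how a small memory buffer caps the memory portion (stand-alone sketch, numbers are illustrative):

def split_batch(batch_size: int, batch_memory_frac: float, memory_size: int):
    """Mirror of the ReplayLearner split above, for illustration only."""
    if not (0 <= batch_memory_frac <= 1):
        raise ValueError(
            f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}."
        )
    memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size))
    new_data_batch_size = batch_size - memory_batch_size
    return new_data_batch_size, memory_batch_size

print(split_batch(64, 0.5, memory_size=500))  # (32, 32): buffer is large enough
print(split_batch(64, 0.5, memory_size=10))   # (54, 10): memory portion capped by the buffer size
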
4 changes: 2 additions & 2 deletions src/renate/utils/config_spaces.py
@@ -15,13 +15,13 @@ def _get_range(start, end, step):
"momentum": choice([0.0, 0.9, 0.99]),
"weight_decay": loguniform(1e-6, 1e-2),
"learning_rate": loguniform(0.001, 0.5),
"batch_size": 32,
"batch_size": 64,
"max_epochs": 50,
}
_replay_config_space = {
**_learner_config_space,
**{
"memory_batch_size": 32,
"batch_memory_frac": 0.5,
"memory_size": 1000,
},
}
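
In the default replay search space the memory fraction stays constant at 0.5. Tuning it as well is a hypothetical extension, not part of this commit, but the same sampling helpers already used above (choice, loguniform, uniform) would allow it, assuming they come from Syne Tune's config_space module as elsewhere in Renate:

# Hypothetical: make the memory fraction itself a search dimension.
from syne_tune.config_space import uniform

custom_replay_config_space = {
    **_replay_config_space,  # the default replay space defined above
    "batch_memory_frac": uniform(0.1, 0.9),
}
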
