From 73df5c2bc6427bdc78e6d766e043b64d92c26cd9 Mon Sep 17 00:00:00 2001 From: wistuba Date: Thu, 17 Aug 2023 11:09:16 +0200 Subject: [PATCH] Replace memory batch size with a fraction of the total batch size (#359) --- .../updaters/offline-er-clear10.json | 4 +- .../updaters/offline-er-clear100.json | 4 +- .../updaters/offline-er-domainnet.json | 4 +- .../class_incremental_learning_cifar10_der.py | 4 +- examples/nlp_finetuning/start.py | 4 +- .../start_with_hpo.py | 2 +- .../start_without_hpo.py | 4 +- .../start_training_with_er_without_hpo.py | 4 +- .../start_training_without_hpo.py | 4 +- src/renate/cli/parsing_functions.py | 18 ++++---- src/renate/defaults.py | 1 + src/renate/updaters/avalanche/learner.py | 4 +- .../updaters/avalanche/model_updater.py | 4 +- src/renate/updaters/experimental/er.py | 20 ++++----- src/renate/updaters/experimental/gdumb.py | 4 +- .../updaters/experimental/offline_er.py | 11 +---- src/renate/updaters/learner.py | 17 ++++--- src/renate/utils/config_spaces.py | 4 +- test/conftest.py | 6 +-- .../updaters/avalanche-er-buffer500.json | 3 +- .../configs/updaters/cls-er-buffer500.json | 4 +- .../configs/updaters/der-buffer500.json | 4 +- .../configs/updaters/er-buffer500.json | 4 +- .../configs/updaters/gdumb-buffer500.json | 4 +- .../updaters/offline-er-buffer500.json | 4 +- .../configs/updaters/pod-er-buffer500.json | 4 +- .../configs/updaters/super-er-buffer500.json | 4 +- .../avalanche/test_avalanche_learner.py | 10 ++++- .../avalanche/test_avalanche_model_updater.py | 18 ++++---- test/renate/updaters/experimental/test_er.py | 44 ++++++++++--------- 30 files changed, 120 insertions(+), 106 deletions(-) diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear10.json b/benchmarks/experiment_configs/updaters/offline-er-clear10.json index 6fe89caf..b1714bff 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-clear10.json +++ b/benchmarks/experiment_configs/updaters/offline-er-clear10.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 128, - "memory_batch_size": 128, + "batch_size": 256, + "batch_memory_frac": 0.5, "memory_size": 3300 } diff --git a/benchmarks/experiment_configs/updaters/offline-er-clear100.json b/benchmarks/experiment_configs/updaters/offline-er-clear100.json index a6d09005..e91e4fad 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-clear100.json +++ b/benchmarks/experiment_configs/updaters/offline-er-clear100.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 128, - "memory_batch_size": 128, + "batch_size": 256, + "batch_memory_frac": 0.5, "memory_size": 10000 } diff --git a/benchmarks/experiment_configs/updaters/offline-er-domainnet.json b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json index b336094f..d403b96c 100644 --- a/benchmarks/experiment_configs/updaters/offline-er-domainnet.json +++ b/benchmarks/experiment_configs/updaters/offline-er-domainnet.json @@ -1,6 +1,6 @@ { "updater": "Offline-ER", - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 3450 } diff --git a/examples/benchmarking/class_incremental_learning_cifar10_der.py b/examples/benchmarking/class_incremental_learning_cifar10_der.py index b9b8a67e..8b5c67e4 100644 --- a/examples/benchmarking/class_incremental_learning_cifar10_der.py +++ b/examples/benchmarking/class_incremental_learning_cifar10_der.py @@ -11,8 +11,8 @@ "learning_rate": 0.03, "alpha": 0.2, "beta": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + 
"batch_memory_frac": 0.5, "memory_size": 500, "max_epochs": 50, "loss_normalization": 0, diff --git a/examples/nlp_finetuning/start.py b/examples/nlp_finetuning/start.py index cb6a8951..932d9dfe 100644 --- a/examples/nlp_finetuning/start.py +++ b/examples/nlp_finetuning/start.py @@ -12,8 +12,8 @@ "weight_decay": 0.0, "learning_rate": 0.001, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 300, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/examples/simple_classifier_cifar10/start_with_hpo.py b/examples/simple_classifier_cifar10/start_with_hpo.py index 16fb9e2e..eac54342 100644 --- a/examples/simple_classifier_cifar10/start_with_hpo.py +++ b/examples/simple_classifier_cifar10/start_with_hpo.py @@ -13,7 +13,7 @@ "learning_rate": loguniform(1e-4, 1e-1), "alpha": uniform(0.0, 1.0), "batch_size": choice([32, 64, 128, 256]), - "memory_batch_size": 32, + "batch_memory_frac": 0.5, "memory_size": 1000, "loss_normalization": 0, "loss_weight": uniform(0.0, 1.0), diff --git a/examples/simple_classifier_cifar10/start_without_hpo.py b/examples/simple_classifier_cifar10/start_without_hpo.py index f7f408f3..177002e6 100644 --- a/examples/simple_classifier_cifar10/start_without_hpo.py +++ b/examples/simple_classifier_cifar10/start_without_hpo.py @@ -11,8 +11,8 @@ "weight_decay": 0.0, "learning_rate": 0.1, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 300, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/examples/train_mlp_locally/start_training_with_er_without_hpo.py b/examples/train_mlp_locally/start_training_with_er_without_hpo.py index 6031f107..b9e5f4f6 100644 --- a/examples/train_mlp_locally/start_training_with_er_without_hpo.py +++ b/examples/train_mlp_locally/start_training_with_er_without_hpo.py @@ -7,9 +7,9 @@ "momentum": 0.0, "weight_decay": 1e-2, "learning_rate": 0.05, - "batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "max_epochs": 50, - "memory_batch_size": 32, "memory_size": 500, } diff --git a/examples/train_mlp_locally/start_training_without_hpo.py b/examples/train_mlp_locally/start_training_without_hpo.py index f8d20dd3..e58c264f 100644 --- a/examples/train_mlp_locally/start_training_without_hpo.py +++ b/examples/train_mlp_locally/start_training_without_hpo.py @@ -9,8 +9,8 @@ "weight_decay": 0.0, "learning_rate": 0.1, "alpha": 0.5, - "batch_size": 32, - "memory_batch_size": 32, + "batch_size": 64, + "batch_memory_frac": 0.5, "memory_size": 500, "loss_normalization": 0, "loss_weight": 0.5, diff --git a/src/renate/cli/parsing_functions.py b/src/renate/cli/parsing_functions.py index 4dae8f78..fe4f33cd 100644 --- a/src/renate/cli/parsing_functions.py +++ b/src/renate/cli/parsing_functions.py @@ -53,7 +53,7 @@ def get_updater_and_learner_kwargs( "loss_weight", "ema_memory_update_gamma", "memory_size", - "memory_batch_size", + "batch_memory_frac", "loss_normalization", ] updater_class = None @@ -93,7 +93,7 @@ def get_updater_and_learner_kwargs( ] updater_class = SuperExperienceReplayModelUpdater elif args.updater == "Offline-ER": - learner_args = learner_args + ["loss_weight_new_data", "memory_size", "memory_batch_size"] + learner_args = learner_args + ["loss_weight_new_data", "memory_size", "batch_memory_frac"] updater_class = OfflineExperienceReplayModelUpdater elif args.updater == "RD": learner_args = learner_args + ["memory_size"] @@ -108,7 +108,7 @@ def get_updater_and_learner_kwargs( learner_args = 
learner_args updater_class = FineTuningModelUpdater elif args.updater == "Avalanche-ER": - learner_args = learner_args + ["memory_size", "memory_batch_size"] + learner_args = learner_args + ["memory_size", "batch_memory_frac"] from renate.updaters.avalanche.model_updater import ExperienceReplayAvalancheModelUpdater updater_class = ExperienceReplayAvalancheModelUpdater @@ -123,7 +123,7 @@ def get_updater_and_learner_kwargs( updater_class = LearningWithoutForgettingModelUpdater elif args.updater == "Avalanche-iCaRL": - learner_args = learner_args + ["memory_size", "memory_batch_size"] + learner_args = learner_args + ["memory_size", "batch_memory_frac"] from renate.updaters.avalanche.model_updater import ICaRLModelUpdater updater_class = ICaRLModelUpdater @@ -428,11 +428,11 @@ def _add_replay_learner_arguments(arguments: Dict[str, Dict[str, Any]]) -> None: "help": "Memory size available for the memory buffer. Default: " f"{defaults.MEMORY_SIZE}.", }, - "memory_batch_size": { - "type": int, - "default": defaults.BATCH_SIZE, - "help": "Batch size used during model update for the memory buffer. Default: " f"{defaults.BATCH_SIZE}.", + "batch_memory_frac": { + "type": float, + "default": defaults.BATCH_MEMORY_FRAC, + "help": "Fraction of the batch populated with memory data. Default: " f"{defaults.BATCH_MEMORY_FRAC}.", }, } ) diff --git a/src/renate/defaults.py b/src/renate/defaults.py index 9025140d..12ed2983 100644 --- a/src/renate/defaults.py +++ b/src/renate/defaults.py @@ -18,6 +18,7 @@ WEIGHT_DECAY = 0.0 MAX_EPOCHS = 50 BATCH_SIZE = 32 +BATCH_MEMORY_FRAC = 0.5 LOSS_WEIGHT = 1.0 SEED = 0 EMA_MEMORY_UPDATE_GAMMA = 1.0 diff --git a/src/renate/updaters/avalanche/learner.py b/src/renate/updaters/avalanche/learner.py index ff258cae..d2e04a34 100644 --- a/src/renate/updaters/avalanche/learner.py +++ b/src/renate/updaters/avalanche/learner.py @@ -38,7 +38,7 @@ def update_settings( avalanche_learner._criterion = self._loss_fn avalanche_learner.train_epochs = max_epochs avalanche_learner.train_mb_size = self._batch_size - avalanche_learner.eval_mb_size = self._batch_size + avalanche_learner.eval_mb_size = self._batch_size + getattr(self, "_memory_batch_size", 0) avalanche_learner.device = device avalanche_learner.eval_every = eval_every @@ -57,7 +57,7 @@ def _create_avalanche_learner( optimizer=optimizer, criterion=self._loss_fn, train_mb_size=self._batch_size, - eval_mb_size=self._batch_size, + eval_mb_size=self._batch_size + getattr(self, "_memory_batch_size", 0), train_epochs=train_epochs, plugins=plugins, evaluator=default_evaluator(), diff --git a/src/renate/updaters/avalanche/model_updater.py b/src/renate/updaters/avalanche/model_updater.py index e5fb0933..10e1d2f4 100644 --- a/src/renate/updaters/avalanche/model_updater.py +++ b/src/renate/updaters/avalanche/model_updater.py @@ -259,7 +259,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, learning_rate_scheduler: Optional[Callable[[Optimizer], _LRScheduler]] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, @@ -289,7 +289,7 @@ def __init__( learner_kwargs = { "batch_size": batch_size, "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "seed": seed, } super().__init__( diff --git 
a/src/renate/updaters/experimental/er.py b/src/renate/updaters/experimental/er.py index 82b9a1da..e9fe4018 100644 --- a/src/renate/updaters/experimental/er.py +++ b/src/renate/updaters/experimental/er.py @@ -532,7 +532,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -566,7 +566,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -614,7 +614,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -649,7 +649,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -698,7 +698,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -734,7 +734,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -784,7 +784,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -823,7 +823,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, "ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, @@ -876,7 +876,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, loss_weight: float = defaults.LOSS_WEIGHT, ema_memory_update_gamma: float = defaults.EMA_MEMORY_UPDATE_GAMMA, loss_normalization: int = defaults.LOSS_NORMALIZATION, @@ -921,7 +921,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight": loss_weight, 
"ema_memory_update_gamma": ema_memory_update_gamma, "loss_normalization": loss_normalization, diff --git a/src/renate/updaters/experimental/gdumb.py b/src/renate/updaters/experimental/gdumb.py index 623046a8..03e12f6f 100644 --- a/src/renate/updaters/experimental/gdumb.py +++ b/src/renate/updaters/experimental/gdumb.py @@ -106,7 +106,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, learning_rate_scheduler: Optional[partial] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 batch_size: int = defaults.BATCH_SIZE, @@ -136,7 +136,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, "seed": seed, } diff --git a/src/renate/updaters/experimental/offline_er.py b/src/renate/updaters/experimental/offline_er.py index 78ac2dd0..7809af1f 100644 --- a/src/renate/updaters/experimental/offline_er.py +++ b/src/renate/updaters/experimental/offline_er.py @@ -29,18 +29,11 @@ class OfflineExperienceReplayLearner(ReplayLearner): terminated. Args: - memory_size: The maximum size of the memory. - memory_batch_size: Size of batches sampled from the memory. The memory batch will be - appended to the batch sampled from the current dataset, leading to an effective batch - size of `memory_batch_size + batch_size`. loss_weight_new_data: The training loss will be a convex combination of the loss on the new data and the loss on the memory data. If a float (needs to be in [0, 1]) is given here, it will be used as the weight for the new data. If `None`, the weight will be set dynamically to `N_t / sum([N_1, ..., N_t])`, where `N_i` denotes the size of task/chunk `i` and the current task is `t`. - buffer_transform: The transformation to be applied to the memory buffer data samples. - buffer_target_transform: The target transformation to be applied to the memory buffer target - samples. """ def __init__(self, loss_weight_new_data: Optional[float] = None, **kwargs) -> None: @@ -147,7 +140,7 @@ def __init__( loss_fn: torch.nn.Module, optimizer: Callable[[List[Parameter]], Optimizer], memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: int = defaults.BATCH_MEMORY_FRAC, loss_weight_new_data: Optional[float] = None, learning_rate_scheduler: Optional[partial] = None, learning_rate_scheduler_interval: defaults.SUPPORTED_LR_SCHEDULER_INTERVAL_TYPE = defaults.LR_SCHEDULER_INTERVAL, # noqa: E501 @@ -178,7 +171,7 @@ def __init__( ): learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "loss_weight_new_data": loss_weight_new_data, "batch_size": batch_size, "seed": seed, diff --git a/src/renate/updaters/learner.py b/src/renate/updaters/learner.py index fd9d4123..d7e8777f 100644 --- a/src/renate/updaters/learner.py +++ b/src/renate/updaters/learner.py @@ -456,9 +456,7 @@ class ReplayLearner(Learner, abc.ABC): Args: memory_size: The maximum size of the memory. - memory_batch_size: Size of batches sampled from the memory. The memory batch will be - appended to the batch sampled from the current dataset, leading to an effective batch - size of `memory_batch_size + batch_size`. + batch_memory_frac: Fraction of the batch that is sampled from rehearsal memory. 
buffer_transform: The transformation to be applied to the memory buffer data samples. buffer_target_transform: The target transformation to be applied to the memory buffer target samples. @@ -468,14 +466,21 @@ class ReplayLearner(Learner, abc.ABC): def __init__( self, memory_size: int, - memory_batch_size: int = defaults.BATCH_SIZE, + batch_size: int = defaults.BATCH_SIZE, + batch_memory_frac: float = defaults.BATCH_MEMORY_FRAC, buffer_transform: Optional[Callable] = None, buffer_target_transform: Optional[Callable] = None, seed: int = defaults.SEED, **kwargs, ) -> None: - super().__init__(seed=seed, **kwargs) - self._memory_batch_size = min(memory_size, memory_batch_size) + if not (0 <= batch_memory_frac <= 1): + raise ValueError( + f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}." + ) + memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size)) + batch_size = batch_size - memory_batch_size + super().__init__(batch_size=batch_size, seed=seed, **kwargs) + self._memory_batch_size = memory_batch_size self._memory_buffer = ReservoirBuffer( max_size=memory_size, seed=seed, diff --git a/src/renate/utils/config_spaces.py b/src/renate/utils/config_spaces.py index 2c1f6bc6..d0842049 100644 --- a/src/renate/utils/config_spaces.py +++ b/src/renate/utils/config_spaces.py @@ -15,13 +15,13 @@ def _get_range(start, end, step): "momentum": choice([0.0, 0.9, 0.99]), "weight_decay": loguniform(1e-6, 1e-2), "learning_rate": loguniform(0.001, 0.5), - "batch_size": 32, + "batch_size": 64, "max_epochs": 50, } _replay_config_space = { **_learner_config_space, **{ - "memory_batch_size": 32, + "batch_memory_frac": 0.5, "memory_size": 1000, }, } diff --git a/test/conftest.py b/test/conftest.py index 921f6540..9d1fd5ce 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -71,7 +71,7 @@ def pytest_collection_modifyitems(config, items): LEARNER_KWARGS = { ExperienceReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, }, @@ -81,7 +81,7 @@ def pytest_collection_modifyitems(config, items): RepeatedDistillationLearner: {"batch_size": 10, "seed": 42, "memory_size": 30}, OfflineExperienceReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "loss_weight_new_data": 0.5, "batch_size": 50, "seed": 1, @@ -90,7 +90,7 @@ def pytest_collection_modifyitems(config, items): AVALANCHE_LEARNER_KWARGS = { AvalancheReplayLearner: { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, }, diff --git a/test/integration_tests/configs/updaters/avalanche-er-buffer500.json b/test/integration_tests/configs/updaters/avalanche-er-buffer500.json index a0cc0afe..d50ca947 100644 --- a/test/integration_tests/configs/updaters/avalanche-er-buffer500.json +++ b/test/integration_tests/configs/updaters/avalanche-er-buffer500.json @@ -4,6 +4,7 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, + "batch_size": 288, + "batch_memory_frac": 0.112, "memory_size": 500 } diff --git a/test/integration_tests/configs/updaters/cls-er-buffer500.json b/test/integration_tests/configs/updaters/cls-er-buffer500.json index a62b367a..a54b08b4 100644 --- a/test/integration_tests/configs/updaters/cls-er-buffer500.json +++ b/test/integration_tests/configs/updaters/cls-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + 
"batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.5, "beta": 0.1, diff --git a/test/integration_tests/configs/updaters/der-buffer500.json b/test/integration_tests/configs/updaters/der-buffer500.json index 13dea96b..6bf12918 100644 --- a/test/integration_tests/configs/updaters/der-buffer500.json +++ b/test/integration_tests/configs/updaters/der-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.2, "beta": 0.5, diff --git a/test/integration_tests/configs/updaters/er-buffer500.json b/test/integration_tests/configs/updaters/er-buffer500.json index d23103ee..b0645044 100644 --- a/test/integration_tests/configs/updaters/er-buffer500.json +++ b/test/integration_tests/configs/updaters/er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 0.5, "loss_normalization": 0, diff --git a/test/integration_tests/configs/updaters/gdumb-buffer500.json b/test/integration_tests/configs/updaters/gdumb-buffer500.json index efa07b0c..9d0d3ef0 100644 --- a/test/integration_tests/configs/updaters/gdumb-buffer500.json +++ b/test/integration_tests/configs/updaters/gdumb-buffer500.json @@ -4,7 +4,7 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500 } diff --git a/test/integration_tests/configs/updaters/offline-er-buffer500.json b/test/integration_tests/configs/updaters/offline-er-buffer500.json index eaa22099..f080a295 100644 --- a/test/integration_tests/configs/updaters/offline-er-buffer500.json +++ b/test/integration_tests/configs/updaters/offline-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "loss_weight_new_data": 0.5 } diff --git a/test/integration_tests/configs/updaters/pod-er-buffer500.json b/test/integration_tests/configs/updaters/pod-er-buffer500.json index c46f3f15..76ae14af 100644 --- a/test/integration_tests/configs/updaters/pod-er-buffer500.json +++ b/test/integration_tests/configs/updaters/pod-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "alpha": 1.0, "distillation_type": "spatial", diff --git a/test/integration_tests/configs/updaters/super-er-buffer500.json b/test/integration_tests/configs/updaters/super-er-buffer500.json index d3dafffc..371d92ee 100644 --- a/test/integration_tests/configs/updaters/super-er-buffer500.json +++ b/test/integration_tests/configs/updaters/super-er-buffer500.json @@ -4,8 +4,8 @@ "learning_rate": 0.01, "momentum": 0.0, "weight_decay": 0.0, - "batch_size": 256, - "memory_batch_size": 256, + "batch_size": 512, + "batch_memory_frac": 0.5, "memory_size": 500, "der_alpha": 1.0, "der_beta": 1.0, diff --git a/test/renate/updaters/avalanche/test_avalanche_learner.py b/test/renate/updaters/avalanche/test_avalanche_learner.py index 6e0e3eda..e5657e2c 100644 --- a/test/renate/updaters/avalanche/test_avalanche_learner.py +++ b/test/renate/updaters/avalanche/test_avalanche_learner.py @@ -30,6 
+30,7 @@ def check_learner_settings( expected_max_epochs, expected_device, expected_eval_every, + expected_batch_size, expected_loss_fn=None, ): if isinstance(learner, AvalancheICaRLLearner): @@ -47,7 +48,7 @@ def check_learner_settings( assert avalanche_learner._criterion == expected_loss_fn assert avalanche_learner.optimizer == expected_optimizer assert avalanche_learner.train_epochs == expected_max_epochs - assert avalanche_learner.train_mb_size == learner_kwargs["batch_size"] + assert avalanche_learner.train_mb_size == expected_batch_size assert avalanche_learner.eval_mb_size == learner_kwargs["batch_size"] assert avalanche_learner.device == expected_device @@ -84,6 +85,11 @@ def test_update_settings(learner_class): expected_optimizer = SGD(expected_model.parameters(), lr=0.1) expected_device = torch.device("cpu") expected_eval_every = -1 + expected_batch_size = learner_kwargs["batch_size"] + if "batch_memory_frac" in learner_kwargs: + expected_batch_size = expected_batch_size - int( + learner_kwargs["batch_memory_frac"] * expected_batch_size + ) learner = learner_class( model=expected_model, optimizer=None, @@ -107,6 +113,7 @@ def test_update_settings(learner_class): expected_max_epochs=expected_max_epochs, expected_device=expected_device, expected_eval_every=expected_eval_every, + expected_batch_size=expected_batch_size, ) # Update @@ -135,4 +142,5 @@ def test_update_settings(learner_class): expected_max_epochs=expected_max_epochs, expected_device=expected_device, expected_eval_every=expected_eval_every, + expected_batch_size=expected_batch_size, ) diff --git a/test/renate/updaters/avalanche/test_avalanche_model_updater.py b/test/renate/updaters/avalanche/test_avalanche_model_updater.py index 2738c733..af2dad9c 100644 --- a/test/renate/updaters/avalanche/test_avalanche_model_updater.py +++ b/test/renate/updaters/avalanche/test_avalanche_model_updater.py @@ -77,15 +77,17 @@ def test_continuation_of_training_with_avalanche_model_updater(tmpdir, learner_c @pytest.mark.parametrize( - "batch_size,memory_size,memory_batch_size", - [[10, 10, 10], [20, 10, 10], [10, 100, 10], [10, 30, 1], [100, 10, 3]], + "batch_size,memory_size,batch_memory_frac", + [[20, 10, 0.5], [30, 10, 0.34], [20, 100, 0.5], [10, 30, 0.1], [100, 10, 0.03]], ) -def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_batch_size): +def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, batch_memory_frac): + expected_memory_batch_size = int(batch_memory_frac * batch_size) + expected_batch_size = batch_size - expected_memory_batch_size dataset_size = 100 model, dataset = get_model_and_dataset(dataset_size) learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, } model_updater = ExperienceReplayAvalancheModelUpdater( @@ -99,9 +101,9 @@ def test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b ) model_updater.update(train_dataset=dataset) replay_plugin = plugin_by_class(ReplayPlugin, model_updater._learner.plugins) - assert replay_plugin.batch_size == batch_size + assert replay_plugin.batch_size == expected_batch_size assert replay_plugin.mem_size == memory_size - assert replay_plugin.batch_size_mem == memory_batch_size + assert replay_plugin.batch_size_mem == expected_memory_batch_size assert len(replay_plugin.storage_policy.buffer) == min( memory_size, dataset_size, len(replay_plugin.storage_policy.buffer) ) @@ -121,9 +123,9 @@ def 
test_experience_replay_buffer_size(tmpdir, batch_size, memory_size, memory_b ) replay_plugin = plugin_by_class(ReplayPlugin, model_updater._learner.plugins) - assert replay_plugin.batch_size == batch_size + assert replay_plugin.batch_size == expected_batch_size assert replay_plugin.mem_size == memory_size - assert replay_plugin.batch_size_mem == memory_batch_size + assert replay_plugin.batch_size_mem == expected_memory_batch_size model_updater.update(train_dataset=dataset) assert len(model_updater._learner.dataloader.data) == dataset_size assert len(model_updater._learner.dataloader.memory) == min( diff --git a/test/renate/updaters/experimental/test_er.py b/test/renate/updaters/experimental/test_er.py index 72545b72..de5d6ad9 100644 --- a/test/renate/updaters/experimental/test_er.py +++ b/test/renate/updaters/experimental/test_er.py @@ -26,14 +26,15 @@ def get_model_and_dataset(): @pytest.mark.parametrize( - "batch_size,memory_size,memory_batch_size", - [[10, 10, 10], [20, 10, 10], [10, 100, 10], [10, 30, 1], [100, 10, 3]], + "batch_size,memory_size,batch_memory_frac", + [[20, 10, 0.5], [30, 10, 0.34], [20, 100, 0.5], [10, 30, 0.1], [100, 10, 0.03]], ) -def test_er_overall_memory_size_after_update(batch_size, memory_size, memory_batch_size): +def test_er_overall_memory_size_after_update(batch_size, memory_size, batch_memory_frac): + memory_batch_size = int(batch_memory_frac * batch_size) model, dataset = get_model_and_dataset() learner_kwargs = { "memory_size": memory_size, - "memory_batch_size": memory_batch_size, + "batch_memory_frac": batch_memory_frac, "batch_size": batch_size, } model_updater = pytest.helpers.get_simple_updater( @@ -88,9 +89,15 @@ def test_er_validation_buffer(tmpdir): ) +def validate_common_args(model_updater, learner_kwargs): + memory_batch_size = int(learner_kwargs["batch_memory_frac"] * learner_kwargs["batch_size"]) + batch_size = learner_kwargs["batch_size"] - memory_batch_size + assert model_updater._learner._batch_size == batch_size + assert model_updater._learner._memory_batch_size == memory_batch_size + + def validate_cls_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["alpha"] assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["beta"] assert ( @@ -112,15 +119,13 @@ def validate_cls_er(model_updater, learner_kwargs): def validate_dark_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["beta"] assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["alpha"] def validate_pod_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["pod_loss"].weight == learner_kwargs["alpha"] assert ( model_updater._learner._components["pod_loss"]._distillation_type @@ -130,8 +135,7 @@ def validate_pod_er(model_updater, 
learner_kwargs): def validate_super_er(model_updater, learner_kwargs): - assert model_updater._learner._batch_size == learner_kwargs["batch_size"] - assert model_updater._learner._memory_batch_size == learner_kwargs["memory_batch_size"] + validate_common_args(model_updater, learner_kwargs) assert model_updater._learner._components["memory_loss"].weight == learner_kwargs["der_beta"] assert model_updater._learner._components["mse_loss"].weight == learner_kwargs["der_alpha"] assert model_updater._learner._components["cls_loss"].weight == learner_kwargs["cls_alpha"] @@ -174,7 +178,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_cls_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -186,7 +190,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 2.3, @@ -202,7 +206,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_dark_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -210,7 +214,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 2.3, @@ -222,7 +226,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_pod_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "alpha": 0.123, @@ -231,7 +235,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "alpha": 0.123, @@ -244,7 +248,7 @@ def validate_super_er(model_updater, learner_kwargs): validate_super_er, { "memory_size": 30, - "memory_batch_size": 20, + "batch_memory_frac": 0.4, "batch_size": 50, "seed": 1, "der_alpha": 0.123, @@ -262,7 +266,7 @@ def validate_super_er(model_updater, learner_kwargs): }, { "memory_size": 30, - "memory_batch_size": 10, + "batch_memory_frac": 0.1, "batch_size": 100, "seed": 1, "der_alpha": 2.3,
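
Note on the new semantics: `batch_size` now denotes the *total* batch size, of which `min(memory_size, int(batch_memory_frac * batch_size))` samples come from the rehearsal buffer and the remainder from the new data. A minimal sketch of this rule (the `split_batch` helper below is illustrative only, not part of Renate; the arithmetic mirrors `ReplayLearner.__init__` in this patch):

    def split_batch(batch_size: int, batch_memory_frac: float, memory_size: int):
        """Return (new_data_batch_size, memory_batch_size) for a total batch size."""
        if not (0 <= batch_memory_frac <= 1):
            raise ValueError(
                f"Expecting batch_memory_frac to be in [0, 1], received {batch_memory_frac}."
            )
        # The memory part is a fraction of the total batch, capped by the buffer size;
        # whatever remains is filled with samples from the current dataset.
        memory_batch_size = min(memory_size, int(batch_memory_frac * batch_size))
        return batch_size - memory_batch_size, memory_batch_size

    # Migrating an old config: batch_size=32 plus memory_batch_size=32 becomes a
    # total batch_size=64 with batch_memory_frac=0.5, as the example configs above do.
    assert split_batch(64, 0.5, memory_size=500) == (32, 32)
    # A small buffer caps the memory share, so the new-data part grows accordingly.
    assert split_batch(64, 0.5, memory_size=10) == (54, 10)
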