Masking of logits of irrelevant classes (#364)
prabhuteja12 authored Aug 15, 2023
1 parent 41bd5b9 commit c4bcba3
Showing 17 changed files with 269 additions and 6 deletions.
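
The diff wires a new `mask_unused_classes` option through the CLI, the defaults, and every model updater, and applies a helper `maybe_populate_mask_and_ignore_logits` inside the replay training steps. The helper's implementation is not part of this diff; the following is a minimal sketch of what it plausibly does, inferred from its call sites (the signature and the mask-to-`-inf` behaviour are assumptions, not the actual `renate.utils.misc` code):

from typing import Optional, Set, Tuple

import torch


def maybe_populate_mask_and_ignore_logits(
    mask_unused_classes: bool,
    class_mask: Optional[torch.Tensor],
    classes_in_current_task: Optional[Set[int]],
    outputs: torch.Tensor,
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Hypothetical sketch: suppress logits of classes outside the current task."""
    if mask_unused_classes:
        if class_mask is None:
            # Build the index of allowed classes once and cache it for later steps.
            class_mask = torch.tensor(
                sorted(classes_in_current_task), dtype=torch.long, device=outputs.device
            )
        # Push logits of all other classes to -inf so they carry zero probability
        # mass under softmax / cross-entropy.
        masked = torch.full_like(outputs, float("-inf"))
        masked[:, class_mask] = outputs[:, class_mask]
        outputs = masked
    return outputs, class_mask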
4 changes: 4 additions & 0 deletions src/renate/benchmark/experimentation.py
@@ -183,6 +183,10 @@ def execute_experiment_job(
devices: Number of devices to use.
deterministic_trainer: When true, the Trainer adopts a deterministic behaviour also on GPU.
In this function this parameter is set to True by default.
gradient_clip_val: The value at which to clip gradients. Passing None disables it.
`More details <https://lightning.ai/docs/pytorch/stable/common/trainer.html#init>`__
gradient_clip_algorithm: The gradient clipping algorithm to use. Can be norm or value.
`More details <https://lightning.ai/docs/pytorch/stable/common/trainer.html#init>`__
job_name: Name of the experiment job.
strategy: Name of the distributed training strategy to use.
`More details <https://lightning.ai/docs/pytorch/stable/extensions/strategy.html>`__
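
The two docstring entries added above describe options that are forwarded to the Lightning Trainer. A minimal illustration of the underlying Lightning API (not Renate code, assuming a standard pytorch_lightning install):

from pytorch_lightning import Trainer

# Clip gradients to an L2 norm of 1.0; "value" would instead clip each entry.
trainer = Trainer(gradient_clip_val=1.0, gradient_clip_algorithm="norm")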
11 changes: 10 additions & 1 deletion src/renate/cli/parsing_functions.py
@@ -48,7 +48,7 @@ def get_updater_and_learner_kwargs(
"""Returns the model updater class and the keyword arguments for the learner."""
if args.updater.startswith("Avalanche-") and find_spec("avalanche", None) is None:
raise ImportError("Avalanche is not installed. Please run `pip install Renate[avalanche]`.")
learner_args = ["batch_size", "seed"]
learner_args = ["batch_size", "seed", "mask_unused_classes"]
base_er_args = learner_args + [
"loss_weight",
"ema_memory_update_gamma",
@@ -324,6 +324,15 @@ def _standard_arguments() -> Dict[str, Dict[str, Any]]:
"choices": ["norm", "value", None],
"argument_group": OPTIONAL_ARGS_GROUP,
},
"mask_unused_classes": {
"default": str(defaults.MASK_UNUSED_CLASSES),
"type": str,
"choices": ["True", "False"],
"help": "Whether to use a class mask to kill the unused logits. Useful possibly for "
"class incremental learning methods. ",
"argument_group": OPTIONAL_ARGS_GROUP,
"true_type": bool,
},
"prepare_data": {
"type": str,
"default": "True",
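
The argument spec above declares `mask_unused_classes` as a string with choices "True"/"False" and a `true_type` of `bool`. The registration and conversion code is not shown in this diff; a hypothetical sketch of the string-to-bool pattern the spec suggests:

import argparse

# Hypothetical illustration only; the actual logic lives in
# renate.cli.parsing_functions and may differ.
arg_spec = {
    "default": "False",
    "type": str,
    "choices": ["True", "False"],
    "help": "Whether to mask the logits of unused classes.",
    "true_type": bool,
}

parser = argparse.ArgumentParser()
parser.add_argument(
    "--mask_unused_classes",
    type=arg_spec["type"],
    default=arg_spec["default"],
    choices=arg_spec["choices"],
    help=arg_spec["help"],
)
args = parser.parse_args(["--mask_unused_classes", "True"])

# Convert the string choice back to the declared ``true_type`` before use.
mask_unused_classes = args.mask_unused_classes == "True"
assert isinstance(mask_unused_classes, arg_spec["true_type"])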
1 change: 1 addition & 0 deletions src/renate/defaults.py
@@ -45,6 +45,7 @@
FRAMEWORK_VERSION = "1.13.1"

TASK_ID = "default_task"
MASK_UNUSED_CLASSES = False
WORKING_DIRECTORY = "renate_working_dir"
LOGGER = TensorBoardLogger
LOGGER_KWARGS = {
4 changes: 4 additions & 0 deletions src/renate/training/training.py
@@ -144,6 +144,10 @@ def run_training_job(
precision: Type of bit precision to use.
`More details <https://lightning.ai/docs/pytorch/stable/common/precision_basic.html>`__
deterministic_trainer: When true, the Trainer adopts a deterministic behaviour also on GPU.
gradient_clip_val: The value at which to clip gradients. Passing None disables it.
`More details <https://lightning.ai/docs/pytorch/stable/common/trainer.html#init>`__
gradient_clip_algorithm: The gradient clipping algorithm to use. Can be norm or value.
`More details <https://lightning.ai/docs/pytorch/stable/common/trainer.html#init>`__
job_name: Prefix for the name of the SageMaker training job.
"""
assert (
16 changes: 16 additions & 0 deletions src/renate/updaters/avalanche/model_updater.py
@@ -1,6 +1,7 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0
import logging
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Type

@@ -47,6 +48,13 @@
class AvalancheModelUpdater(SingleTrainingLoopUpdater):
_report = Reporter()

def __init__(self, *args, **kwargs):
if kwargs.get("mask_unused_classes", False) is True:
warnings.warn(
"Avalanche model updaters do not support mask_unused_classes. Ignoring it."
)
super().__init__(*args, **kwargs)

def _load_learner(
self,
learner_class: Type[Learner],
@@ -276,6 +284,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"batch_size": batch_size,
@@ -310,6 +319,7 @@ def __init__(
precision=precision,
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
mask_unused_classes=mask_unused_classes,
)


@@ -344,6 +354,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"batch_size": batch_size,
@@ -377,6 +388,7 @@ def __init__(
precision=precision,
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
mask_unused_classes=mask_unused_classes,
)


@@ -412,6 +424,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"batch_size": batch_size,
@@ -446,6 +459,7 @@ def __init__(
precision=precision,
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
mask_unused_classes=mask_unused_classes,
)


@@ -480,6 +494,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -513,4 +528,5 @@ def __init__(
precision=precision,
gradient_clip_val=gradient_clip_val,
gradient_clip_algorithm=gradient_clip_algorithm,
mask_unused_classes=mask_unused_classes,
)
18 changes: 18 additions & 0 deletions src/renate/updaters/experimental/er.py
@@ -29,6 +29,7 @@
ShrinkAndPerturbReinitializationComponent,
)
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.misc import maybe_populate_mask_and_ignore_logits
from renate.utils.pytorch import move_tensors_to_device


@@ -140,6 +141,13 @@ def training_step(
batch_memory = self._sample_from_buffer(device=step_output["loss"].device)
(inputs_memory, _), metadata_memory = batch_memory
outputs_memory = self(inputs_memory)

outputs_memory, self._class_mask = maybe_populate_mask_and_ignore_logits(
self._mask_unused_classes,
self._class_mask,
self._classes_in_current_task,
outputs_memory,
)
intermediate_representation_memory = (
self._model.get_intermediate_representation()
)
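
A toy illustration (not Renate code) of what masking the memory-batch logits achieves: with logits of classes outside the current task pushed to -inf, those classes carry no probability mass and therefore cannot influence the cross-entropy computed on replayed samples.

import torch
import torch.nn.functional as F

logits = torch.tensor([[2.0, 0.5, 3.0, -1.0]])  # 4-class output head
target = torch.tensor([0])                      # label within the allowed classes
allowed = torch.tensor([0, 1])                  # classes in the current task

# Confine the softmax to the allowed classes.
masked = torch.full_like(logits, float("-inf"))
masked[:, allowed] = logits[:, allowed]

print(F.cross_entropy(logits, target))  # inflated by the large logit of class 2
print(F.cross_entropy(masked, target))  # only classes 0 and 1 compete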
@@ -554,6 +562,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -594,6 +603,7 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)


@@ -635,6 +645,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -676,6 +687,7 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)


@@ -718,6 +730,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -760,6 +773,7 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)


@@ -805,6 +819,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -850,6 +865,7 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)


@@ -901,6 +917,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -952,4 +969,5 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)
2 changes: 2 additions & 0 deletions src/renate/updaters/experimental/fine_tuning.py
@@ -44,6 +44,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"batch_size": batch_size,
@@ -77,4 +78,5 @@ def __init__(
precision=precision,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)
2 changes: 2 additions & 0 deletions src/renate/updaters/experimental/gdumb.py
@@ -132,6 +132,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -168,4 +169,5 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)
2 changes: 2 additions & 0 deletions src/renate/updaters/experimental/joint.py
@@ -121,6 +121,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"batch_size": batch_size,
@@ -153,4 +154,5 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)
7 changes: 7 additions & 0 deletions src/renate/updaters/experimental/offline_er.py
@@ -17,6 +17,7 @@
from renate.types import NestedTensors
from renate.updaters.learner import ReplayLearner
from renate.updaters.model_updater import SingleTrainingLoopUpdater
from renate.utils.misc import maybe_populate_mask_and_ignore_logits
from renate.utils.pytorch import cat_nested_tensors, get_length_nested_tensors


@@ -113,6 +114,10 @@ def training_step(self, batch: Dict[str, Tuple[NestedTensors]], batch_idx: int)
inputs = cat_nested_tensors((inputs, inputs_mem), 0)
targets = torch.cat((targets, targets_mem), 0)
outputs = self(inputs)

outputs, self._class_mask = maybe_populate_mask_and_ignore_logits(
self._mask_unused_classes, self._class_mask, self._classes_in_current_task, outputs
)
loss = self._loss_fn(outputs, targets)
if "memory" in batch:
loss_current = loss[:batch_size_current].mean()
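
A hypothetical sketch of the pattern in this hunk, assuming the loss function is configured with reduction="none" so the per-sample losses of the concatenated current and memory batches can be split apart after masking (class indices and sizes are illustrative, not Renate code):

import torch
import torch.nn.functional as F

batch_size_current = 2
outputs = torch.randn(5, 10)          # 2 current + 3 memory samples, 10 classes
targets = torch.randint(0, 4, (5,))   # labels drawn from the classes seen so far
class_mask = torch.arange(4)          # allowed class indices (assumed)

# Mask the combined logits once, then split the per-sample loss.
masked = torch.full_like(outputs, float("-inf"))
masked[:, class_mask] = outputs[:, class_mask]

loss = F.cross_entropy(masked, targets, reduction="none")
loss_current = loss[:batch_size_current].mean()
loss_memory = loss[batch_size_current:].mean()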
@@ -169,6 +174,7 @@ def __init__(
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {
"memory_size": memory_size,
@@ -206,4 +212,5 @@ def __init__(
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)
6 changes: 6 additions & 0 deletions src/renate/updaters/experimental/repeated_distill.py
@@ -123,6 +123,9 @@ def __init__(
seed: Optional[int] = None,
early_stopping_enabled=False,
deterministic_trainer: bool = defaults.DETERMINISTIC_TRAINER,
gradient_clip_val: Optional[float] = defaults.GRADIENT_CLIP_VAL,
gradient_clip_algorithm: Optional[str] = defaults.GRADIENT_CLIP_ALGORITHM,
mask_unused_classes: bool = defaults.MASK_UNUSED_CLASSES,
):
learner_kwargs = {"memory_size": memory_size, "batch_size": batch_size, "seed": seed}
super().__init__(
@@ -152,6 +155,9 @@ def __init__(
early_stopping_enabled=early_stopping_enabled,
logged_metrics=logged_metrics,
deterministic_trainer=deterministic_trainer,
gradient_clip_algorithm=gradient_clip_algorithm,
gradient_clip_val=gradient_clip_val,
mask_unused_classes=mask_unused_classes,
)

def update(