Skip to content

Commit

Permalink
Add support for returning callback from `LightningModule.configure_callbacks` (#11060)
Browse files Browse the repository at this point in the history
  • Loading branch information
rohitgr7 authored Dec 18, 2021
1 parent 2a5d05b commit 3461af0
Show file tree
Hide file tree
Showing 4 changed files with 16 additions and 9 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Added a warning that shows when `max_epochs` in the `Trainer` is not set ([#10700](https://github.com/PyTorchLightning/pytorch-lightning/issues/10700))


- Added support for returning a single Callback from `LightningModule.configure_callbacks` without wrapping it into a list ([#11060](https://github.com/PyTorchLightning/pytorch-lightning/issues/11060))


- Added `console_kwargs` for `RichProgressBar` to initialize inner Console ([#10875](https://github.com/PyTorchLightning/pytorch-lightning/pull/10875))


Expand Down
16 changes: 9 additions & 7 deletions pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import tempfile
from contextlib import contextmanager
from pathlib import Path
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Tuple, Union
from typing import Any, Callable, Dict, List, Mapping, Optional, overload, Sequence, Tuple, Union

import torch
from torch import ScriptModule, Tensor
Expand All @@ -31,6 +31,7 @@
from typing_extensions import Literal

import pytorch_lightning as pl
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.callbacks.progress import base as progress_base
from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks, ModelHooks
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin, HyperparametersMixin
Expand Down Expand Up @@ -1119,15 +1120,16 @@ def predicts_step(self, batch, batch_idx, dataloader_idx=0):
"""
return self(batch)

def configure_callbacks(self):
def configure_callbacks(self) -> Union[Sequence[Callback], Callback]:
"""Configure model-specific callbacks. When the model gets attached, e.g., when ``.fit()`` or ``.test()``
gets called, the list returned here will be merged with the list of callbacks passed to the Trainer's
``callbacks`` argument. If a callback returned here has the same type as one or several callbacks already
present in the Trainer's callbacks list, it will take priority and replace them. In addition, Lightning
will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks run last.
gets called, the list or a callback returned here will be merged with the list of callbacks passed to the
Trainer's ``callbacks`` argument. If a callback returned here has the same type as one or several callbacks
already present in the Trainer's callbacks list, it will take priority and replace them. In addition,
Lightning will make sure :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint` callbacks
run last.
Return:
A list of callbacks which will extend the list of callbacks in the Trainer.
A callback or a list of callbacks which will extend the list of callbacks in the Trainer.
Example::
Expand Down
4 changes: 3 additions & 1 deletion pytorch_lightning/trainer/connectors/callback_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
import os
from datetime import timedelta
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional, Sequence, Union

from pytorch_lightning.callbacks import (
Callback,
Expand Down Expand Up @@ -272,6 +272,8 @@ def _attach_model_callbacks(self) -> None:
model_callbacks = self.trainer._call_lightning_module_hook("configure_callbacks")
if not model_callbacks:
return

model_callbacks = [model_callbacks] if not isinstance(model_callbacks, Sequence) else model_callbacks
model_callback_types = {type(c) for c in model_callbacks}
trainer_callback_types = {type(c) for c in self.trainer.callbacks}
override_types = model_callback_types.intersection(trainer_callback_types)
Expand Down
2 changes: 1 addition & 1 deletion tests/callbacks/test_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def test_configure_callbacks_hook_multiple_calls(tmpdir):

class TestModel(BoringModel):
def configure_callbacks(self):
return [model_callback_mock]
return model_callback_mock

model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, enable_checkpointing=False)
Expand Down

0 comments on commit 3461af0

Please sign in to comment.