Description & Motivation
I have a task with multiple test datasets, and I want to compute metrics individually for each dataset. This requires taking a custom action when each dataset ends. It would be great to have built-in hooks in Lightning to enable this, and it fits naturally with the overall philosophy of the LightningModule processing hooks, which enable actions at specific points in the loop.
Pitch
I propose adding two new hooks for validation/test/prediction (training is probably not applicable here, but we can discuss if it is). These would be the on_dataloader_start and on_dataloader_end hooks. They would be called as follows:
on_X_epoch_start
    on_X_dataloader_start ---------------|
        on_X_batch_start -| batch loop   | dataloader loop
        on_X_batch_end   -|              |
    on_X_dataloader_end -----------------|
on_X_epoch_end
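For illustration, here is a sketch of how a LightningModule might use the proposed hooks to compute per-dataset metrics. The hook names follow the diagram above but do not exist yet; the metric, key names, and num_classes are placeholders:

from lightning.pytorch import LightningModule
from torchmetrics.classification import MulticlassAccuracy


class MyModule(LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.accuracy = MulticlassAccuracy(num_classes=10)

    def test_step(self, batch, batch_idx, dataloader_idx=0):
        x, y = batch
        self.accuracy.update(self(x), y)

    def on_test_dataloader_start(self, dataloader_idx: int) -> None:
        # start each test dataset with fresh metric state
        self.accuracy.reset()

    def on_test_dataloader_end(self, dataloader_idx: int) -> None:
        # same key logged once per test dataset (see the logging note below)
        self.log("test/accuracy", self.accuracy.compute())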
There would also need to be an update to allow metric logging in this manner, because currently Lightning throws an exception if the same key is logged "with different metadata" (I forget the exact text of the exception).
Alternatives
I currently provide a subset of this functionality (the *_dataloader_end hooks) using the following callback:
from typing import Any

from lightning.pytorch import Callback, LightningModule, Trainer


class MultiloaderNotifier(Callback):
    def __init__(self) -> None:
        # last seen dataloader index per stage
        self.dataloader_idxs = {"validation": 0, "test": 0, "predict": 0}

    def on_batch_start(self, stage: str, pl_module: LightningModule, dataloader_idx: int) -> None:
        # a change in dataloader_idx means the previous dataloader has finished
        if dataloader_idx != self.dataloader_idxs[stage]:
            if hook := getattr(pl_module, f"on_{stage}_dataloader_end", None):
                hook(self.dataloader_idxs[stage])
            self.dataloader_idxs[stage] = dataloader_idx

    def on_epoch_end(self, stage: str, pl_module: LightningModule) -> None:
        # the last dataloader ends when the epoch ends
        if hook := getattr(pl_module, f"on_{stage}_dataloader_end", None):
            hook(self.dataloader_idxs[stage])
        self.dataloader_idxs[stage] = 0

    def on_validation_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("validation", pl_module, dataloader_idx)

    def on_validation_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("validation", pl_module)

    def on_test_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("test", pl_module, dataloader_idx)

    def on_test_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("test", pl_module)

    def on_predict_batch_start(self, trainer: Trainer, pl_module: LightningModule, batch: Any, batch_idx: int, dataloader_idx: int = 0) -> None:
        self.on_batch_start("predict", pl_module, dataloader_idx)

    def on_predict_epoch_end(self, trainer: Trainer, pl_module: LightningModule) -> None:
        self.on_epoch_end("predict", pl_module)
which allows me to implement the hook on my LightningModule:
def on_test_dataloader_end(self, dataloader_idx: int) -> None:
    self.logger.log_metric(self.metric.compute())
    self.metric.reset()
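For completeness, the callback is registered like any other Trainer callback (the model and dataloader names below are placeholders):

from lightning.pytorch import Trainer

trainer = Trainer(callbacks=[MultiloaderNotifier()])
# on_test_dataloader_end fires once per test dataset via the callback above
trainer.test(model, dataloaders=[dataset_a_loader, dataset_b_loader])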
However, this has a number of drawbacks. First, the hooks are not necessarily called in the right order, since I have to hook batch_start from the callback in order to detect the change in dataloader (so batch_start -> dataloader_end is not conceptually correct). Second, I need to hook epoch_end to make the last dataloader_end call (epoch_end -> dataloader_end is also not conceptually correct). Lastly, there is no obvious way to provide both dataloader_start and dataloader_end; I don't need them both, but someone else might.
There is also the issue of the logging itself. Currently, to avoid the exception, I need to bypass the LightningModule log function and call log_metric on the logger directly.
Additional context
No response