Skip to content

Commit

Permalink
Fix filtration logic for eval results with multiple dataloaders (#10810)
Browse files Browse the repository at this point in the history
Co-authored-by: Carlos Mocholi <carlossmocholi@gmail.com>
  • Loading branch information
rohitgr7 and carmocca committed Dec 3, 2021
1 parent 84bdcd4 commit 3e689b5
Show file tree
Hide file tree
Showing 7 changed files with 72 additions and 47 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Disabled batch_size extraction for torchmetric instances because they accumulate the metrics internally ([#10815](https://github.com/PyTorchLightning/pytorch-lightning/pull/10815))
- Fixed an issue with `SignalConnector` not restoring the default signal handlers on teardown when running on SLURM or with fault-tolerant training enabled ([#10611](https://github.com/PyTorchLightning/pytorch-lightning/pull/10611))
- Fixed `SignalConnector._has_already_handler` check for callable type ([#10483](https://github.com/PyTorchLightning/pytorch-lightning/pull/10483))
- Fixed an issue to return the results for each dataloader separately instead of duplicating them for each ([#10810](https://github.com/PyTorchLightning/pytorch-lightning/pull/10810))
- Improved exception message if `rich` version is less than `10.2.2` ([#10839](https://github.com/PyTorchLightning/pytorch-lightning/pull/10839))
- Fixed uploading best model checkpoint in NeptuneLogger ([#10369](https://github.com/PyTorchLightning/pytorch-lightning/pull/10369))
- Fixed early schedule reset logic in PyTorch profiler that was causing data leak ([#10837](https://github.com/PyTorchLightning/pytorch-lightning/pull/10837))
Expand All @@ -27,6 +28,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed a consolidation error in Lite when attempting to save the state dict of a sharded optimizer ([#10746](https://github.com/PyTorchLightning/pytorch-lightning/pull/10746))
- Fixed the default logging level for batch hooks associated with training from `on_step=False, on_epoch=True` to `on_step=True, on_epoch=False` ([#10756](https://github.com/PyTorchLightning/pytorch-lightning/pull/10756))


### Removed

- Removed PyTorch 1.6 support ([#10367](https://github.com/PyTorchLightning/pytorch-lightning/pull/10367), [#10738](https://github.com/PyTorchLightning/pytorch-lightning/pull/10738))
Expand Down
3 changes: 2 additions & 1 deletion pytorch_lightning/core/lightning.py
Original file line number Diff line number Diff line change
Expand Up @@ -486,7 +486,8 @@ def log(
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=self._current_dataloader_idx,
batch_size=batch_size,
sync_dist=sync_dist and distributed_available(),
sync_dist_fn=self.trainer.training_type_plugin.reduce or sync_ddp,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,20 @@ def update_eval_step_metrics(self) -> None:
# increment the step even if nothing was logged
self._increment_eval_log_step()

@staticmethod
def _filter_metrics_for_dataloader(
dl_idx: int, metrics: _OUT_DICT, metric_prefix: str = "dataloader_idx"
) -> _OUT_DICT:
return {k: v for k, v in metrics.items() if metric_prefix not in k or k.endswith(f"{metric_prefix}_{dl_idx}")}

def _prepare_eval_loop_results(self, metrics: _OUT_DICT) -> None:
def _prepare_eval_loop_results(self) -> None:
if self.trainer.sanity_checking:
return

on_step = not self._epoch_end_reached
num_dataloaders = self.trainer._evaluation_loop.num_dataloaders
has_been_initialized = len(self.eval_loop_results) == num_dataloaders
for dl_idx in range(self.trainer._evaluation_loop.num_dataloaders):
# remove callback metrics that don't belong to this dataloader
callback_metrics = self._filter_metrics_for_dataloader(dl_idx, metrics)
assert self.trainer._evaluation_loop._results is not None
for dl_idx in range(num_dataloaders):
metrics = self.trainer._evaluation_loop._results.metrics(
on_step, dataloader_idx=dl_idx if num_dataloaders > 1 else None
)
callback_metrics = metrics["callback"]

if has_been_initialized:
self.eval_loop_results[dl_idx].update(callback_metrics)
else:
Expand All @@ -182,7 +181,7 @@ def update_eval_epoch_metrics(self) -> List[_OUT_DICT]:
# log all the metrics as a single dict
self.log_metrics(metrics["log"])

self._prepare_eval_loop_results(metrics["callback"])
self._prepare_eval_loop_results()

# log results of evaluation
if (
Expand Down
20 changes: 14 additions & 6 deletions pytorch_lightning/trainer/connectors/logger_connector/result.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ class _Metadata:
on_epoch: bool = True
reduce_fx: Callable = torch.mean
enable_graph: bool = False
add_dataloader_idx: bool = True
dataloader_idx: Optional[int] = None
metric_attribute: Optional[str] = None
_sync: Optional[_Sync] = None
Expand Down Expand Up @@ -434,6 +435,7 @@ def log(
sync_dist: bool = False,
sync_dist_fn: Callable = _Sync.no_op,
sync_dist_group: Optional[Any] = None,
add_dataloader_idx: bool = True,
dataloader_idx: Optional[int] = None,
batch_size: Optional[int] = None,
metric_attribute: Optional[str] = None,
Expand All @@ -451,7 +453,7 @@ def log(
# storage key
key = f"{fx}.{name}"
# add dataloader_suffix to both key and fx
if dataloader_idx is not None:
if add_dataloader_idx and dataloader_idx is not None:
key += f".{dataloader_idx}"
fx += f".{dataloader_idx}"

Expand All @@ -464,6 +466,7 @@ def log(
on_epoch=on_epoch,
reduce_fx=reduce_fx,
enable_graph=enable_graph,
add_dataloader_idx=add_dataloader_idx,
dataloader_idx=dataloader_idx,
metric_attribute=metric_attribute,
)
Expand Down Expand Up @@ -522,24 +525,29 @@ def _get_cache(result_metric: ResultMetric, on_step: bool) -> Optional[torch.Ten
return cache.detach()
return cache

def valid_items(self) -> Generator:
def valid_items(self, dataloader_idx: Optional[int] = None) -> Generator:
"""This function is used to iterate over current valid metrics."""
return ((k, v) for k, v in self.items() if not (isinstance(v, ResultMetric) and v.has_reset))
return (
(k, v)
for k, v in self.items()
if not (isinstance(v, ResultMetric) and v.has_reset) and (dataloader_idx in (None, v.meta.dataloader_idx))
)

def _forked_name(self, result_metric: ResultMetric, on_step: bool) -> Tuple[str, str]:
name = result_metric.meta.name
forked_name = result_metric.meta.forked_name(on_step)
add_dataloader_idx = result_metric.meta.add_dataloader_idx
dl_idx = result_metric.meta.dataloader_idx
if dl_idx is not None:
if add_dataloader_idx and dl_idx is not None:
dataloader_suffix = self.DATALOADER_SUFFIX.format(dl_idx)
name += dataloader_suffix
forked_name += dataloader_suffix
return name, forked_name

def metrics(self, on_step: bool) -> _METRICS:
def metrics(self, on_step: bool, dataloader_idx: Optional[int] = None) -> _METRICS:
metrics = _METRICS(callback={}, log={}, pbar={})

for _, result_metric in self.valid_items():
for _, result_metric in self.valid_items(dataloader_idx):

# extract forward_cache or computed from the ResultMetric. ignore when the output is None
value = apply_to_collection(result_metric, ResultMetric, self._get_cache, on_step, include_none=False)
Expand Down
2 changes: 1 addition & 1 deletion pytorch_lightning/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1396,7 +1396,7 @@ def __set_ckpt_path(self, ckpt_path: Optional[str], model_provided: bool, model_
" The best model of the previous `fit` call will be used."
f" You can pass `{fn}(ckpt_path='best')` to use and best model"
" checkpoint and avoid this warning or"
" `ckpt_path=trainer.model_checkpoint.last_model_path` to use the last model."
" `ckpt_path=trainer.checkpoint_callback.last_model_path` to use the last model."
)
ckpt_path = "best"

Expand Down
2 changes: 1 addition & 1 deletion tests/plugins/test_ddp_spawn_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def on_predict_start(self) -> None:
assert isinstance(self.trainer.model, LightningModule)


@RunIf(skip_windows=True, skip_49370=True)
@RunIf(skip_windows=True, skip_49370=True, skip_hanging_spawn=True)
def test_ddp_spawn_configure_ddp(tmpdir):
"""Tests with ddp spawn plugin."""
trainer = Trainer(default_root_dir=tmpdir, num_processes=2, strategy="ddp_spawn", fast_dev_run=True)
Expand Down
69 changes: 42 additions & 27 deletions tests/trainer/logging_/test_eval_loop_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

from pytorch_lightning import callbacks, Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel, RandomDataset
from tests.helpers.runif import RunIf
Expand Down Expand Up @@ -676,32 +675,6 @@ def val_dataloader(self):
trainer.fit(model)


@pytest.mark.parametrize(
["kwargs", "expected"],
[
({"dl_idx": 0, "metrics": {"acc": 123}}, {"acc": 123}),
(
{"dl_idx": 0, "metrics": {"acc/dataloader_idx_0": 123, "acc/dataloader_idx_1": 321}},
{"acc/dataloader_idx_0": 123},
),
(
{"dl_idx": 10, "metrics": {"acc/dataloader_idx_1": 123, "acc/dataloader_idx_10": 321}},
{"acc/dataloader_idx_10": 321},
),
(
{"dl_idx": 3, "metrics": {"top_3_acc/dataloader_idx_0": 123, "top_3_acc/dataloader_idx_3": 321}},
{"top_3_acc/dataloader_idx_3": 321},
),
# theoretical case, as `/dataloader_idx_3` would have been added
({"dl_idx": 3, "metrics": {"top_3_acc": 123}}, {"top_3_acc": 123}),
],
)
def test_filter_metrics_for_dataloader(kwargs, expected):
"""Logged metrics should only include metrics from the concerned dataloader."""
actual = LoggerConnector._filter_metrics_for_dataloader(**kwargs)
assert actual == expected


@RunIf(min_gpus=1)
def test_evaluation_move_metrics_to_cpu_and_outputs(tmpdir):
class TestModel(BoringModel):
Expand All @@ -723,3 +696,45 @@ def validation_epoch_end(self, outputs):
model = TestModel()
trainer = Trainer(default_root_dir=tmpdir, limit_val_batches=2, move_metrics_to_cpu=True, gpus=1)
trainer.validate(model, verbose=False)


def test_logging_results_with_no_dataloader_idx(tmpdir):
num_dataloaders = 2
log_common_same_val = {"test_log_common": 789}
log_common_diff_val = "test_log_common_diff_value"
log_key_no_dl_idx = "test_log_no_dl_idx_{}"
log_key_dl0 = {"test_log_a_class": 123}
log_key_dl1 = {"test_log_b_class": 456}

class CustomBoringModel(BoringModel):
def test_step(self, batch, batch_idx, dataloader_idx):
self.log_dict(log_common_same_val)
self.log(log_common_diff_val, dataloader_idx + 1)
self.log(
log_key_no_dl_idx.format(dataloader_idx),
321 * (dataloader_idx + 1),
add_dataloader_idx=False,
)
self.log_dict(log_key_dl0 if dataloader_idx == 0 else log_key_dl1, add_dataloader_idx=False)

def test_dataloader(self):
return [torch.utils.data.DataLoader(RandomDataset(32, 64)) for _ in range(num_dataloaders)]

model = CustomBoringModel()
model.test_epoch_end = None
trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=1)
results = trainer.test(model)

assert len(results) == num_dataloaders
assert results[0] == {
"test_log_common/dataloader_idx_0": 789.0,
"test_log_common_diff_value/dataloader_idx_0": 1.0,
"test_log_no_dl_idx_0": 321,
"test_log_a_class": 123.0,
}
assert results[1] == {
"test_log_common/dataloader_idx_1": 789.0,
"test_log_common_diff_value/dataloader_idx_1": 2.0,
"test_log_no_dl_idx_1": 321 * 2,
"test_log_b_class": 456.0,
}

0 comments on commit 3e689b5

Please sign in to comment.