From fb557b9c7f100a117ccfdafe73da5a4d55e24494 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Mon, 20 Mar 2023 18:12:58 +0100 Subject: [PATCH] Add a migration for the dataloader loops (#17125) Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/utilities/migration/migration.py | 30 +++++++++++- .../utilities/migration/test_migration.py | 49 +++++++++++++++++++ 2 files changed, 77 insertions(+), 2 deletions(-) diff --git a/src/lightning/pytorch/utilities/migration/migration.py b/src/lightning/pytorch/utilities/migration/migration.py index 229b62f85c744..5e0a772b5f95e 100644 --- a/src/lightning/pytorch/utilities/migration/migration.py +++ b/src/lightning/pytorch/utilities/migration/migration.py @@ -50,6 +50,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]: _drop_apex_amp_state, _migrate_loop_structure_after_tbptt_removal, _migrate_loop_structure_after_optimizer_loop_removal, + _migrate_loop_structure_after_dataloader_loop_removal, ], } @@ -236,7 +237,8 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE """ if "loops" not in checkpoint: return checkpoint - + if "fit_loop" not in checkpoint["loops"]: + return checkpoint fit_loop = checkpoint["loops"]["fit_loop"] # remap `x.batch_loop.y` to `x.y` @@ -273,8 +275,10 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT """ if "loops" not in checkpoint: return checkpoint - + if "fit_loop" not in checkpoint["loops"]: + return checkpoint fit_loop = checkpoint["loops"]["fit_loop"] + # optimizer_position is no longer used if "epoch_loop.optimizer_loop.optim_progress" in fit_loop: fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None) @@ -291,3 +295,25 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT "epoch_loop.manual_loop.optim_step_progress" ) return checkpoint + + +def 
_migrate_loop_structure_after_dataloader_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT: + """The dataloader loops (``_DataLoaderLoop``, ``_PredictionLoop``, and ``_EvaluationLoop``) were flattened into + the ``_EvaluationEpochLoop`` (now ``_EvaluationLoop``) and ``_PredictionEpochLoop`` (now ``_PredictionLoop``). + + Version: 2.0.0 + Commit: ec4f592ecfe238edd83185f6c6905fb1e2406d61 + PR: #16726 + """ + if "loops" not in checkpoint: + return checkpoint + loops = checkpoint["loops"] + for loop_key in ("predict_loop", "validate_loop", "test_loop"): + if loop_key not in loops: + continue + loop = loops[loop_key] + loop.pop("dataloader_progress", None) # no longer used + epoch_loop_key = "epoch_loop." + epoch_loop_dict = {k[len(epoch_loop_key) :]: loop.pop(k) for k in list(loop) if k.startswith(epoch_loop_key)} + loop.update(epoch_loop_dict) + return checkpoint diff --git a/tests/tests_pytorch/utilities/migration/test_migration.py b/tests/tests_pytorch/utilities/migration/test_migration.py index 5678b022ba9dd..9c373ad1b6259 100644 --- a/tests/tests_pytorch/utilities/migration/test_migration.py +++ b/tests/tests_pytorch/utilities/migration/test_migration.py @@ -227,3 +227,52 @@ def test_migrate_loop_structure_after_optimizer_loop_removal(): "epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual, } } + + +def test_migrate_loop_structure_after_dataloader_loop_removal(): + """Test the loop state migration after the dataloader loops were removed in 2.0.0.""" + old_dataloader_loop_state_dict = { + "state_dict": {}, + "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, + "epoch_loop.state_dict": {}, + "epoch_loop.batch_progress": { + "total": {"ready": 123, "started": 0, "processed": 0, "completed": 0}, + "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, + "is_last_batch": False, + }, + } + old_checkpoint = { + "loops": { + "predict_loop": old_dataloader_loop_state_dict, + 
"validate_loop": dict(old_dataloader_loop_state_dict), # copy + "test_loop": dict(old_dataloader_loop_state_dict), # copy + } + } + _set_version(old_checkpoint, "1.9.0") # pretend a checkpoint prior to 2.0.0 + updated_checkpoint, _ = migrate_checkpoint(old_checkpoint.copy(), target_version="2.0.0") + assert updated_checkpoint["loops"] == { + "predict_loop": { + "batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0}, + }, + "state_dict": {}, + }, + "test_loop": { + "batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0}, + }, + "state_dict": {}, + }, + "validate_loop": { + "batch_progress": { + "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0}, + "is_last_batch": False, + "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0}, + }, + "state_dict": {}, + }, + }