Skip to content

Commit

Permalink
Add a migration for the dataloader loops (#17125)
Browse files Browse the repository at this point in the history
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
2 people authored and lantiga committed Mar 30, 2023
1 parent bb8d420 commit fb557b9
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 2 deletions.
30 changes: 28 additions & 2 deletions src/lightning/pytorch/utilities/migration/migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def _migration_index() -> Dict[str, List[Callable[[_CHECKPOINT], _CHECKPOINT]]]:
_drop_apex_amp_state,
_migrate_loop_structure_after_tbptt_removal,
_migrate_loop_structure_after_optimizer_loop_removal,
_migrate_loop_structure_after_dataloader_loop_removal,
],
}

Expand Down Expand Up @@ -236,7 +237,8 @@ def _migrate_loop_structure_after_tbptt_removal(checkpoint: _CHECKPOINT) -> _CHE
"""
if "loops" not in checkpoint:
return checkpoint

if "fit_loop" not in checkpoint["loops"]:
return checkpoint
fit_loop = checkpoint["loops"]["fit_loop"]

# remap `x.batch_loop.y` to `x.y`
Expand Down Expand Up @@ -273,8 +275,10 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
"""
if "loops" not in checkpoint:
return checkpoint

if "fit_loop" not in checkpoint["loops"]:
return checkpoint
fit_loop = checkpoint["loops"]["fit_loop"]

# optimizer_position is no longer used
if "epoch_loop.optimizer_loop.optim_progress" in fit_loop:
fit_loop["epoch_loop.optimizer_loop.optim_progress"].pop("optimizer_position", None)
Expand All @@ -291,3 +295,25 @@ def _migrate_loop_structure_after_optimizer_loop_removal(checkpoint: _CHECKPOINT
"epoch_loop.manual_loop.optim_step_progress"
)
return checkpoint


def _migrate_loop_structure_after_dataloader_loop_removal(checkpoint: _CHECKPOINT) -> _CHECKPOINT:
"""The dataloader loops (``_DataLoaderLoop``, ``_PredictionLoop`, and ``_EvaluationLoop``) were flattened into
the ``_EvaluationEpochLoop`` (now ``_EvaluationLoop``) and ``_PredictionEpochLoop`` (now ``_PredictionLoop``).
Version: 2.0.0
Commit: ec4f592ecfe238edd83185f6c6905fb1e2406d61
PR: #16726
"""
if "loops" not in checkpoint:
return checkpoint
loops = checkpoint["loops"]
for loop_key in ("predict_loop", "validate_loop", "test_loop"):
if loop_key not in loops:
continue
loop = loops[loop_key]
loop.pop("dataloader_progress", None) # no longer used
epoch_loop_key = "epoch_loop."
epoch_loop_dict = {k[len(epoch_loop_key) :]: loop.pop(k) for k in list(loop) if k.startswith(epoch_loop_key)}
loop.update(epoch_loop_dict)
return checkpoint
49 changes: 49 additions & 0 deletions tests/tests_pytorch/utilities/migration/test_migration.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,3 +227,52 @@ def test_migrate_loop_structure_after_optimizer_loop_removal():
"epoch_loop.manual_optimization.optim_step_progress": optim_progress_manual,
}
}


def test_migrate_loop_structure_after_dataloader_loop_removal():
    """Check that pre-2.0.0 predict/validate/test loop states are flattened by the migration."""
    loop_keys = ("predict_loop", "validate_loop", "test_loop")

    # state layout as saved by the old dataloader loop wrapping an epoch loop
    zero_progress = {"ready": 0, "completed": 0}
    legacy_state = {
        "state_dict": {},
        "dataloader_progress": {"total": zero_progress, "current": dict(zero_progress)},
        "epoch_loop.state_dict": {},
        "epoch_loop.batch_progress": {
            "total": {"ready": 123, "started": 0, "processed": 0, "completed": 0},
            "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0},
            "is_last_batch": False,
        },
    }
    # the first loop holds the dict itself, the remaining ones hold shallow copies of it
    old_loops = {loop_keys[0]: legacy_state}
    for key in loop_keys[1:]:
        old_loops[key] = {**legacy_state}
    old_checkpoint = {"loops": old_loops}

    _set_version(old_checkpoint, "1.9.0")  # pretend a checkpoint prior to 2.0.0
    updated_checkpoint, _ = migrate_checkpoint(dict(old_checkpoint), target_version="2.0.0")

    # `dataloader_progress` is gone and the `epoch_loop.` prefix has been stripped
    flattened_state = {
        "batch_progress": {
            "current": {"completed": 0, "processed": 0, "ready": 0, "started": 0},
            "is_last_batch": False,
            "total": {"completed": 0, "processed": 0, "ready": 123, "started": 0},
        },
        "state_dict": {},
    }
    expected_loops = {key: flattened_state for key in loop_keys}
    assert updated_checkpoint["loops"] == expected_loops

0 comments on commit fb557b9

Please sign in to comment.