Add utility to restore state from snapshot (#274)
Summary:
Pull Request resolved: #274

Since the callback adds logic for saving the progress and dataloader states, it also becomes responsible for restoring those states.

This diff adds a staticmethod on the callback to restore that state.
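
For illustration, a minimal end-to-end sketch of the new API (non-authoritative: `DummyTrainUnit` and `generate_random_dataloader` are test helpers from this diff, not library API, and the snapshot paths are assumed placeholders):

# Hedged sketch: save a snapshot every 2 train steps, then restore from the first one.
# DummyTrainUnit and generate_random_dataloader are test helpers from this diff;
# the /tmp/snapshots paths are illustrative.
my_unit = DummyTrainUnit(input_dim=2)
dataloader = generate_random_dataloader(10, 2, 2)  # dataset_len, input_dim, batch_size
state = init_train_state(dataloader=dataloader, max_epochs=2)
snapshot_cb = TorchSnapshotSaver(
    "/tmp/snapshots",
    save_every_n_train_steps=2,
    replicated=["**"],
)
train(state, my_unit, callbacks=[snapshot_cb])

# Snapshots are written to <dirpath>/epoch_<e>_step_<s>; restore from the first one.
TorchSnapshotSaver.restore("/tmp/snapshots/epoch_0_step_2", state, my_unit)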

Reviewed By: daniellepintz

Differential Revision: D40524092

fbshipit-source-id: b7c67a25a633cd8d677ab7b6da0b3868d27f2256
ananthsub authored and facebook-github-bot committed Nov 29, 2022
1 parent 05174a4 commit 535dc56
Showing 2 changed files with 116 additions and 22 deletions.
42 changes: 42 additions & 0 deletions tests/framework/callbacks/test_torchsnapshot_saver.py
@@ -76,6 +76,48 @@ def test_save_every_n_train_epochs(self) -> None:
             os.path.exists(expected_path) and os.path.isdir(expected_path)
         )

+    def test_save_restore(self) -> None:
+        input_dim = 2
+        dataset_len = 10
+        batch_size = 2
+        max_epochs = 2
+        expected_steps_per_epoch = math.ceil(dataset_len / batch_size)
+        save_every_n_train_steps = 2
+
+        my_unit = DummyTrainUnit(input_dim=input_dim)
+        dataloader = generate_random_dataloader(dataset_len, input_dim, batch_size)
+        state = init_train_state(dataloader=dataloader, max_epochs=max_epochs)
+        expected_paths: List[str] = []
+        with tempfile.TemporaryDirectory() as temp_dir:
+            cumulative_steps = 0
+            for epoch in range(max_epochs):
+                for _ in range(
+                    save_every_n_train_steps,
+                    expected_steps_per_epoch + 1,
+                    save_every_n_train_steps,
+                ):
+                    cumulative_steps += save_every_n_train_steps
+                    expected_paths.append(
+                        os.path.join(temp_dir, f"epoch_{epoch}_step_{cumulative_steps}")
+                    )
+            snapshot_cb = TorchSnapshotSaver(
+                temp_dir,
+                save_every_n_train_steps=save_every_n_train_steps,
+                replicated=["**"],
+            )
+            train(state, my_unit, callbacks=[snapshot_cb])
+
+            end_num_steps_completed = state.train_state.progress.num_steps_completed
+            self.assertGreater(len(expected_paths), 0)
+            snapshot_cb.restore(expected_paths[0], state, my_unit)
+            restored_num_steps_completed = (
+                state.train_state.progress.num_steps_completed
+            )
+            # A snapshot is saved every n steps
+            # so the first snapshot's progress will be equal to save_every_n_train_steps
+            self.assertNotEqual(restored_num_steps_completed, end_num_steps_completed)
+            self.assertEqual(restored_num_steps_completed, save_every_n_train_steps)
+
     def test_saver_invalid_args(self) -> None:
         with tempfile.TemporaryDirectory() as temp_dir:
             with self.assertRaisesRegex(
96 changes: 74 additions & 22 deletions torchtnt/framework/callbacks/torchsnapshot_saver.py
@@ -23,10 +23,10 @@
 _TStateful = StatefulProtocol
 _TORCHSNAPSHOT_AVAILABLE = False

-EVAL_PROGRESS_STATE_KEY = "eval_progress"
-RNG_STATE_KEY = "rng_state"
-TRAIN_PROGRESS_STATE_KEY = "train_progress"
-TRAIN_DL_STATE_KEY = "train_dataloader"
+_EVAL_PROGRESS_STATE_KEY = "eval_progress"
+_RNG_STATE_KEY = "rng_state"
+_TRAIN_PROGRESS_STATE_KEY = "train_progress"
+_TRAIN_DL_STATE_KEY = "train_dataloader"


 class TorchSnapshotSaver(Callback):
@@ -54,23 +54,17 @@ def __init__(
         save_every_n_epochs: Optional[int] = None,
         replicated: Optional[List[str]] = None,
     ) -> None:
-
-        if not _TORCHSNAPSHOT_AVAILABLE:
-            raise RuntimeError(
-                "TorchSnapshotSaver support requires torchsnapshot. "
-                "Please make sure ``torchsnapshot`` is installed. "
-                "Installation: https://github.com/pytorch/torchsnapshot#install"
-            )
+        _validate_snapshot_available()
         if save_every_n_train_steps is not None and save_every_n_train_steps < 0:
             raise ValueError(
                 f"Invalid value passed for save_every_n_train_steps. Expected to receive either None or non-negative number, but received {save_every_n_train_steps}"
             )
-        self._save_every_n_train_steps = save_every_n_train_steps
         if save_every_n_epochs is not None and save_every_n_epochs < 0:
             raise ValueError(
                 f"Invalid value passed for save_every_n_epochs. Expected to receive either None or non-negative number, but received {save_every_n_epochs}"
             )
         self._save_every_n_epochs = save_every_n_epochs
+        self._save_every_n_train_steps = save_every_n_train_steps
         self._dirpath: str = dirpath
         self._replicated: Set[str] = set(replicated or [])

@@ -124,6 +118,64 @@ def on_train_epoch_end(self, state: State, unit: TTrainUnit) -> None:
             )
         rank_zero_info(f"Saved snapshot to path: {snapshot_path}")

+    @staticmethod
+    def restore(
+        path: str,
+        state: State,
+        unit: TTrainUnit,
+        *,
+        restore_train_progress: bool = True,
+        restore_train_dataloader: bool = True,
+        restore_eval_progress: bool = True,
+    ) -> None:
+        """Utility method to restore snapshot state from a path.
+
+        Since the class also manages saving the progress and dataloader states,
+        this method handles their restoration. Additional flags are offered in case
+        the user wants to skip loading these states. By default, the train progress,
+        train dataloader, and eval progress are restored, if applicable.
+        """
+        _validate_snapshot_available()
+        app_state = unit.app_state()
+        _check_app_state_collision(app_state)
+
+        snapshot = torchsnapshot.Snapshot(path)
+
+        train_state = none_throws(state.train_state)
+
+        rng_state = torchsnapshot.RNGState()
+        app_state[_RNG_STATE_KEY] = rng_state
+
+        if restore_train_progress:
+            train_progress = train_state.progress
+            app_state[_TRAIN_PROGRESS_STATE_KEY] = train_progress
+
+        if restore_train_dataloader:
+            # request to restore the dataloader state only if
+            # the persisted snapshot state includes the dataloader entry
+            manifest = snapshot.get_manifest()
+            for key in manifest:
+                if _TRAIN_DL_STATE_KEY in key:
+                    app_state[_TRAIN_DL_STATE_KEY] = train_state.dataloader
+                    break
+
+        if state.entry_point == EntryPoint.FIT and restore_eval_progress:
+            # include evaluation states if fitting
+            eval_state = none_throws(state.eval_state)
+            app_state[_EVAL_PROGRESS_STATE_KEY] = eval_state.progress
+
+        snapshot.restore(app_state)
+
+
+def _validate_snapshot_available() -> None:
+    if not _TORCHSNAPSHOT_AVAILABLE:
+        raise RuntimeError(
+            "TorchSnapshotSaver support requires torchsnapshot. "
+            "Please make sure ``torchsnapshot`` is installed. "
+            "Installation: https://github.com/pytorch/torchsnapshot#install"
+        )


 def _get_snapshot_save_path(dirpath: str, epoch: int, step: int) -> str:
     return os.path.join(dirpath, f"epoch_{epoch}_step_{step}")
@@ -138,33 +190,33 @@ def _get_app_state(
     app_state = unit.app_state()

     rng_state = torchsnapshot.RNGState()
-    app_state[RNG_STATE_KEY] = rng_state
-    app_state[TRAIN_PROGRESS_STATE_KEY] = train_progress
-    train_prog_glob = f"{TRAIN_PROGRESS_STATE_KEY}/*"
+    app_state[_RNG_STATE_KEY] = rng_state
+    app_state[_TRAIN_PROGRESS_STATE_KEY] = train_progress
+    train_prog_glob = f"{_TRAIN_PROGRESS_STATE_KEY}/*"
     replicated.add(train_prog_glob)

     # for intra-epoch checkpointing, include dataloader states
     train_dl = train_state.dataloader
     if intra_epoch and isinstance(train_dl, _TStateful):
-        app_state[TRAIN_DL_STATE_KEY] = train_dl
+        app_state[_TRAIN_DL_STATE_KEY] = train_dl

     if state.entry_point == EntryPoint.FIT:
         # include evaluation states if fitting
         eval_state = none_throws(state.eval_state)

-        app_state[EVAL_PROGRESS_STATE_KEY] = eval_state.progress
-        eval_prog_glob = f"{EVAL_PROGRESS_STATE_KEY}/*"
+        app_state[_EVAL_PROGRESS_STATE_KEY] = eval_state.progress
+        eval_prog_glob = f"{_EVAL_PROGRESS_STATE_KEY}/*"
         replicated.add(eval_prog_glob)

     return app_state


 def _check_app_state_collision(app_state: Dict[str, _TStateful]) -> None:
     keys_to_check = (
-        TRAIN_PROGRESS_STATE_KEY,
-        TRAIN_DL_STATE_KEY,
-        RNG_STATE_KEY,
-        EVAL_PROGRESS_STATE_KEY,
+        _TRAIN_PROGRESS_STATE_KEY,
+        _TRAIN_DL_STATE_KEY,
+        _RNG_STATE_KEY,
+        _EVAL_PROGRESS_STATE_KEY,
     )
     for key in keys_to_check:
         if key in app_state:
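For completeness, the restore-time flags above can also be used to skip individual states; a hedged sketch (the path, `state`, and `my_unit` are the same illustrative placeholders as in the earlier sketch):

# Restore progress and RNG state but skip the dataloader state,
# e.g. when resuming against a different dataset; the path is illustrative.
TorchSnapshotSaver.restore(
    "/tmp/snapshots/epoch_0_step_2",
    state,
    my_unit,
    restore_train_dataloader=False,
)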
