Backup and restore callback #701

Merged · 7 commits · Sep 21, 2023
Changes from 1 commit
Comments fixed, tests written
adi-kmt committed Aug 23, 2023
commit 70abdc7ed8c52172f1870f512fbff26d71c6bdc4
138 changes: 78 additions & 60 deletions keras_core/callbacks/backup_and_restore_callback.py
@@ -1,59 +1,83 @@
import os

import keras_core.saving

from keras_core.api_export import keras_core_export
from keras_core.callbacks.callback import Callback
from keras_core.utils import file_utils
from keras_core.utils import io_utils


@keras_core_export("keras_core.callbacks.BackupAndRestoreCallback")
class BackupAndRestoreCallback(Callback):
Contributor:
It's just BackupAndRestore
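A minimal sketch of that suggestion — the export path and class name below simply drop the `Callback` suffix, which is an assumption about the intended final form:

@keras_core_export("keras_core.callbacks.BackupAndRestore")
class BackupAndRestore(Callback):
    """Callback to back up and restore the training state."""
    ...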

"""
Callback to back up and restore the training state.

BackupAndRestore callback is intended to recover training from an
interruption that has happened in the middle of a Model.fit execution,
by backing up the training states in a temporary checkpoint file at the
end of each epoch. Each backup overwrites the previously written
checkpoint file, so at any given time there is at most one such
checkpoint file for backup/restoring purpose.

If training restarts before completion, the training state (which
includes the Model weights and epoch number) is restored to the most
recently saved state at the beginning of a new Model.fit run. At the
completion of a Model.fit run, the temporary checkpoint file is deleted.

Note that the user is responsible to bring jobs back after the
interruption. This callback is important for the backup and restore
mechanism for fault tolerance purpose, and the model to be restored from
a previous checkpoint is expected to be the same as the one used to back
up. If user changes arguments passed to compile or fit, the checkpoint
saved for fault tolerance can become invalid.
"""Callback to back up and restore the training state.

`BackupAndRestore` callback is intended to recover training from an
interruption that has happened in the middle of a `Model.fit` execution, by
backing up the training states in a temporary checkpoint file (with the help
of a `tf.train.CheckpointManager`), at the end of each epoch. Each backup
overwrites the previously written checkpoint file, so at any given time
there is at most one such checkpoint file for backup/restoring purpose.

If training restarts before completion, the training state (which includes
the `Model` weights and epoch number) is restored to the most recently saved
state at the beginning of a new `Model.fit` run. At the completion of a
`Model.fit` run, the temporary checkpoint file is deleted.

    Note that the user is responsible for bringing jobs back after the
    interruption. This callback is important for the backup and restore
    mechanism for fault tolerance purposes, and the model restored from a
    previous checkpoint is expected to be the same as the one used to back up.
    If the user changes the arguments passed to `compile` or `fit`, the
    checkpoint saved for fault tolerance can become invalid.

Example:

>>> class InterruptingCallback(tf.keras.callbacks.Callback):
... def on_epoch_begin(self, epoch, logs=None):
... if epoch == 4:
... raise RuntimeError('Interrupting!')
>>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
Contributor:
(This is what is being shown in this code example)

>>> model = keras.models.Sequential([tf.keras.layers.Dense(10)])
>>> model.compile(keras.optimizers.SGD(), loss='mse')
>>> try:
... model.fit(np.arange(100).reshape(5, 20), np.zeros(5), epochs=10,
... batch_size=1, callbacks=[callback, InterruptingCallback()],
... verbose=0)
... except:
... pass
>>> history = model.fit(np.arange(100).reshape(5, 20), np.zeros(5),
... epochs=10, batch_size=1, callbacks=[callback],
... verbose=0)
>>> # Only 6 more epochs are run, since first training got interrupted at
>>> # zero-indexed epoch 4, second training will continue from 4 to 9.
>>> len(history.history['loss'])
6

Args:
backup_dir: String, path to store the checkpoint. e.g. backup_dir =
os.path.join(working_dir, 'backup'). This is the directory in which
the system stores temporary files to recover the model from jobs
terminated unexpectedly.
save_freq: 'epoch', integer, or False. When set to 'epoch'
the callback saves the checkpoint at the end of each epoch. When set
to an integer, the callback saves the checkpoint every save_freq
batches. Set save_freq to False if only using preemption
checkpointing (with save_before_preemption=True).
delete_checkpoint: Boolean, default to True. This BackupAndRestore
callback works by saving a checkpoint to back up the training state.
If delete_checkpoint=True, the checkpoint will be deleted after
training is finished. Use False if you'd like to keep the checkpoint
for future usage.
file_path: String, path to store the checkpoint.
e.g. `backup_dir = os.path.join(working_dir, 'backup')`.
This is the directory in which the system stores temporary files to
recover the model from jobs terminated unexpectedly. The directory
cannot be reused elsewhere to store other files, e.g. by the
`BackupAndRestore` callback of another training run,
or by another callback
(e.g. `ModelCheckpoint`) of the same training.
save_freq: `'epoch'`, integer, or `False`. When set to `'epoch'`
Contributor:
Use double quotes (") as the quote character

the callback saves the checkpoint at the end of each epoch.
When set to an integer, the callback saves the checkpoint every
`save_freq` batches. Set `save_freq` to `False` if only using
preemption checkpointing (with `save_before_preemption=True`).
delete_checkpoint: Boolean, default to True. This `BackupAndRestore`
callback works by saving a checkpoint to back up the training state.
If `delete_checkpoint=True`, the checkpoint will be deleted after
training is finished. Use `False` if you'd like to keep the checkpoint
for future usage.
save_before_preemption: A boolean value instructing whether to turn on
        the automatic checkpoint saving for preemption/maintenance events.
"""

def __init__(
self,
backup_dir,
file_path,
save_freq="epoch",
delete_checkpoint=True,
save_before_preemption=False,
@@ -63,11 +87,12 @@ def __init__(
self.save_freq = save_freq
self.delete_checkpoint = delete_checkpoint
self.save_before_preemption = save_before_preemption
self._batches_seen_since_last_saving = 0
self._last_batch_seen = 0

if not backup_dir:
if not file_path:
raise ValueError("Empty `backup_dir` argument passed")
self.backup_dir = backup_dir
self.file_path = file_path

if (not save_freq) and (not save_before_preemption):
Contributor:
No need for parens here
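For illustration, the same check without the extra parentheses (black leaves this form as-is):

if not save_freq and not save_before_preemption:
    raise ValueError(...)  # same error message as in the diff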

raise ValueError(
@@ -84,28 +109,25 @@ def on_train_begin(self, logs=None):
"""
Get training state from temporary file and restore it
"""
super().on_train_begin()
self.set_model(keras_core.saving.load_model(filepath=self.backup_dir))
if self._check_checkpoints_exists(self.file_path):
self._model = keras_core.saving.load_model(filepath=self.file_path)

def on_train_end(self, logs=None):
"""
Delete training state stored
"""
if self._check_checkpoints_exists(self.backup_dir):
if self.delete_checkpoint and \
Contributor:
Remove \ line breaks and let black handle the code formatting
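For example, letting black handle it yields implicit continuation inside the call's parentheses, roughly:

if self.delete_checkpoint and self._check_checkpoints_exists(
    self.file_path
):
    self._cleanup_checkpoint()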

Contributor Author:
I think this is from an old file, don't see it in my code 😅

self._check_checkpoints_exists(self.file_path):
self._cleanup_checkpoint()

def on_epoch_begin(self, epoch, logs=None):
if self.delete_checkpoint:
if self.delete_checkpoint and \
self._check_checkpoints_exists(self.file_path):
self._cleanup_checkpoint()
self._current_epoch = epoch

def on_epoch_end(self, epoch, logs=None):
super().on_epoch_end(epoch)
if self.save_freq == "epoch":
self._save_model(epoch=epoch, batch=None, logs=logs)

def on_train_batch_end(self, batch, logs=None):
super().on_epoch_end(batch)
if self._should_save_on_batch(batch):
self._save_model(epoch=self._current_epoch, batch=batch, logs=logs)
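`_should_save_on_batch` is not part of this hunk; a plausible sketch, assuming it mirrors `ModelCheckpoint`'s batch-counting logic and uses the `_batches_seen_since_last_saving` / `_last_batch_seen` counters initialised in `__init__`:

    def _should_save_on_batch(self, batch):
        """Return True once `save_freq` batches have passed since the last save."""
        if self.save_freq == "epoch":
            return False
        if batch <= self._last_batch_seen:  # a new epoch has started
            add_batches = batch + 1
        else:
            add_batches = batch - self._last_batch_seen
        self._batches_seen_since_last_saving += add_batches
        self._last_batch_seen = batch
        if self._batches_seen_since_last_saving >= self.save_freq:
            self._batches_seen_since_last_saving = 0
            return True
        return False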

@@ -128,10 +150,6 @@ def _save_model(self, epoch, batch, logs):

try:
self._model.save(filepath, overwrite=True)
if self.verbose > 0:
io_utils.print_msg(
f"\nEpoch {epoch + 1}: saving model to {filepath}"
)
except IsADirectoryError: # h5py 3.x
raise IOError(
"Please specify a non-directory filepath for "
@@ -159,14 +177,14 @@ def _get_file_path(self, epoch, batch, logs):
# logged metrics and the path's placeholders can cause formatting to
# fail.
if batch is None or "batch" in logs:
file_path = self.filepath.format(epoch=epoch + 1, **logs)
file_path = self.file_path.format(epoch=epoch + 1, **logs)
else:
file_path = self.filepath.format(
file_path = self.file_path.format(
epoch=epoch + 1, batch=batch + 1, **logs
)
except KeyError as e:
raise KeyError(
f'Failed to format this callback filepath: "{self.filepath}". '
f'Failed to format this callback filepath: "{self.file_path}". '
f"Reason: {e}"
)
return file_path
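As the `format` calls above imply, `file_path` may contain `{epoch}` and `{batch}` placeholders as well as logged metric names; a small illustration (the template string itself is an assumption, not from this PR):

template = "/tmp/backup/ckpt-{epoch:02d}-{batch:04d}.keras"
# epoch and batch are zero-indexed internally, hence the `+ 1` in the callback.
print(template.format(epoch=3 + 1, batch=49 + 1, loss=0.12))
# /tmp/backup/ckpt-04-0050.keras  (unreferenced keys such as `loss` are ignored)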
@@ -191,8 +209,8 @@ def _cleanup_checkpoint(self):
"""
Delete other checkpoint files (if present) in the directory
"""
if self._check_checkpoints_exists(filepath=self.backup_dir):
file_utils.rmtree(self.backup_dir)
if self._check_checkpoints_exists(filepath=self.file_path):
file_utils.rmtree(self.file_path)

def _check_checkpoints_exists(self, filepath):
return file_utils.exists(filepath)
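The commit message mentions tests; below is a hedged sketch of the kind of interrupt-and-resume test the docstring example implies. The interrupting callback, the pytest `tmp_path` fixture, and the 6-epoch expectation come from the docstring rather than this PR's actual test file:

import numpy as np
import keras_core as keras


class InterruptingCallback(keras.callbacks.Callback):
    """Raise partway through training to simulate a preempted job."""

    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 4:
            raise RuntimeError("Interrupting!")


def test_resumes_from_backup(tmp_path):  # `tmp_path` is pytest's built-in fixture
    callback = BackupAndRestoreCallback(file_path=str(tmp_path / "backup"))
    model = keras.Sequential([keras.layers.Dense(10)])
    model.compile(optimizer="sgd", loss="mse")
    x, y = np.arange(100).reshape(5, 20), np.zeros(5)
    try:
        model.fit(x, y, epochs=10, batch_size=1,
                  callbacks=[callback, InterruptingCallback()], verbose=0)
    except RuntimeError:
        pass
    history = model.fit(x, y, epochs=10, batch_size=1,
                        callbacks=[callback], verbose=0)
    # Per the docstring, training resumes at epoch 4, so 6 more epochs run.
    assert len(history.history["loss"]) == 6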