
Backup and restore callback #701

Merged · 7 commits · Sep 21, 2023

Conversation

adi-kmt (Contributor) commented Aug 11, 2023

Based on the callback issue

@adi-kmt adi-kmt marked this pull request as draft August 11, 2023 13:50
@fchollet (Contributor) left a comment

Thanks for the PR!

[3 outdated review threads on keras_core/callbacks/backup_and_restore_callback.py, resolved]
        Get training state from temporary file and restore it
        """
        super().on_train_begin()
        self.set_model(keras_core.saving.load_model(filepath=self.backup_dir))
Contributor

Is this the right approach? Should we rather take self.model and load its state?

Contributor Author

I'm not able to use self.model directly, but self._model seems to work.
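
A minimal sketch of the workaround being described (hedged; it mirrors the load call quoted above, and assumes self._model is the private attribute the Callback base class sets via set_model):

    def on_train_begin(self, logs=None):
        """Get training state from the backup file and restore it."""
        # self.model could not be assigned directly here, so the PR
        # falls back to the private self._model attribute instead.
        self._model = keras_core.saving.load_model(filepath=self.backup_dir)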

@adi-kmt adi-kmt marked this pull request as ready for review August 23, 2023 13:44
@adi-kmt adi-kmt requested a review from fchollet August 23, 2023 13:46
@adi-kmt adi-kmt changed the title from "[WIP}Backup and restore callback" to "Backup and restore callback" Aug 23, 2023
@sampathweb sampathweb self-requested a review August 23, 2023 20:52
[5 outdated review threads on keras_core/callbacks/backup_and_restore_callback.py, resolved]
@adi-kmt adi-kmt requested a review from sampathweb August 29, 2023 16:57
@sampathweb (Collaborator)

Thanks for making the changes. I will review it today.



@keras_core_export("keras_core.callbacks.BackupAndRestoreCallback")
class BackupAndRestoreCallback(Callback):
Contributor

It's just `BackupAndRestore` (drop the `Callback` suffix).

        `BackupAndRestore` callback of another training run,
        or by another callback
        (e.g. `ModelCheckpoint`) of the same training.
    save_freq: `'epoch'`, integer, or `False`. When set to `'epoch'`
Contributor

Use double quotes (") as the quote character

            raise ValueError("Empty `backup_dir` argument passed")
        self.file_path = file_path

        if (not save_freq) and (not save_before_preemption):
Contributor

No need for parens here
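
For illustration, the simplified condition the reviewer is asking for (names taken from the diff above):

    # `not` binds tighter than `and`, so the parentheses are redundant.
    if not save_freq and not save_before_preemption: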

        self._model = keras_core.saving.load_model(filepath=self.file_path)

    def on_train_end(self, logs=None):
        if self.delete_checkpoint and \
Contributor

Remove the `\` line breaks and let black handle the code formatting.

Contributor Author

I think this is from an old file; I don't see it in my code 😅

@adi-kmt adi-kmt requested a review from fchollet September 1, 2023 03:14
        self._cleanup_checkpoint()

    def on_epoch_begin(self, epoch, logs=None):
        if self.delete_checkpoint and self._check_checkpoints_exists(
Collaborator

Why do we delete the checkpoint on epoch begin? If an interrupt happens during training, it will not be able to recover. Deletion of the checkpoint should only happen in on_train_end. I checked to confirm that this delete-in-on_epoch_begin logic is not in the tf.keras implementation.

Contributor Author

The issue I was trying to address is that, per the documentation, there should be only a single checkpoint at a time. But I see your point: if the saving frequency is epoch and training gets interrupted mid-epoch, it will not be able to recover.
So I'm thinking of moving the deletion here, i.e. just before saving the weights, and then also deleting at train end.

@sampathweb (Collaborator), Sep 1, 2023

With overwrite=True, why do we need to delete before writing again? Deleting at on_train_end makes sense and would be consistent with tf.keras.

Collaborator

After removing the delete checkpoint logic in on_epoch_begin, I am good with this PR.

Contributor Author

done 🥸
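
The agreed-upon behavior, as a hedged sketch (the helpers _check_checkpoints_exists and _cleanup_checkpoint appear in the diff; the rest is illustrative, not the PR's exact code):

    def on_epoch_end(self, epoch, logs=None):
        # overwrite=True keeps a single checkpoint at a time, with no
        # delete-before-write step needed.
        if self.save_freq == "epoch":
            self._model.save_weights(self.file_path, overwrite=True)

    def on_train_end(self, logs=None):
        # The backup is removed only once training completes normally,
        # matching the tf.keras behavior.
        if self.delete_checkpoint and self._check_checkpoints_exists(
            self.file_path
        ):
            self._cleanup_checkpoint()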

@adi-kmt adi-kmt requested a review from sampathweb September 1, 2023 06:47

@adi-kmt adi-kmt requested a review from sampathweb September 1, 2023 16:41

    Args:
        file_path: String, path to store the checkpoint.
            e.g. `backup_dir = os.path.join(working_dir, "backup")`.
Contributor

  • Please use a 4-space indent
  • We should stick to the original argument, `backup_dir`. Then we can use it to generate the filepath internally. Note that the filepath needs to end in `.weights.h5`.

Collaborator

@kamathis4 - Can you address this comment
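
A hedged sketch of what that could look like (the helper name _backup_filepath and the latest.weights.h5 filename are hypothetical; only the `.weights.h5` requirement comes from the review comment):

    import os

    def _backup_filepath(backup_dir):
        # Keep `backup_dir` as the public argument and derive the actual
        # checkpoint path internally; the weights-only format requires
        # the filename to end in `.weights.h5`.
        if not backup_dir:
            raise ValueError("Empty `backup_dir` argument passed")
        return os.path.join(backup_dir, "latest.weights.h5")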

    ...     def on_epoch_begin(self, epoch, logs=None):
    ...         if epoch == 4:
    ...             raise RuntimeError('Interrupting!')
    >>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
Contributor

(This is what is being shown in this code example)

"""Returns the file path for checkpoint."""

try:
# `filepath` may contain placeholders such as
Contributor

Since the filename will be autogenerated you can just include epoch in the filename you pick -- e.g. backup_at_epoch_{epoch}.weights.h5
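
Illustrating the suggestion (a hypothetical method; the exact filename is the implementer's choice):

    def _get_file_path(self, epoch):
        # Since the filename is autogenerated, the epoch placeholder is
        # baked into it directly and formatted per epoch.
        filename = f"backup_at_epoch_{epoch}.weights.h5"
        return os.path.join(self.backup_dir, filename)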

        Get training state from temporary file and restore it
        """
        if self._check_checkpoints_exists(self.file_path):
            self._model.load_weights(filepath=self.file_path)
Collaborator

We are not restoring the current epoch number, so the model.fit that happens after the restore re-trains for all the epochs. The tf.keras implementation restores the current epoch, and model.fit only trains for the remaining epochs. See the example at https://keras.io/api/callbacks/backup_and_restore/

Collaborator

Marking it resolved here. We will address it in subsequent PR when we add the hook in model to set initial_epoch.
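
For context, the usage-level behavior being referenced (a sketch adapted from the keras.io example linked above; InterruptingCallback is the doctest helper shown earlier, and x, y stand in for real training data):

    callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
    try:
        # Interrupted partway through; the callback has backed up the
        # training state at the end of each completed epoch.
        model.fit(x, y, epochs=10, callbacks=[callback, InterruptingCallback()])
    except RuntimeError:
        pass
    # The next fit restores the backed-up weights. Once the initial_epoch
    # hook lands, it would also resume from the interrupted epoch instead
    # of starting over at epoch 0.
    history = model.fit(x, y, epochs=10, callbacks=[callback])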

            x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
        )

        self.assertEqual(cbk._current_epoch, 4)
Collaborator

Can you also assert that len(hist["loss"]) == 3 to confirm that it only ran for 3 additional epochs and not all 5 epochs again? I would expect that assertion to fail, since we don't seem to restore the epoch in on_train_begin.
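
The requested check would look roughly like this (hedged; it assumes fit()'s return value is captured and reads the History object's .history dict):

    hist = model.fit(
        x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
    )
    # If the restore worked, only the 3 remaining epochs run, so the
    # per-epoch loss list has length 3 rather than 5.
    self.assertEqual(len(hist.history["loss"]), 3)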

Collaborator

@fchollet - What do you think we should do? Is there a good way we can set the initial_epoch of the model? Also, we are not restoring the optimizer state that tf.keras restored when it saved the training checkpoint. Any suggestions?

Collaborator

Checked with Francois and addressed both questions:

  1. Model saving also includes the optimizer, so optimizer state is restored when the backup is restored.
  2. For setting initial_epoch, we may need to add that hook on the model before we can set it in the on_train_begin of BackupAndRestore. We will address that in a subsequent PR.

We can mark this conversation resolved.
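
Point 1 follows from the whole-model .keras format, which stores optimizer variables alongside the weights; a minimal illustration (not from the PR):

    model.save("/tmp/backup/model.keras")  # includes optimizer state
    restored = keras.saving.load_model("/tmp/backup/model.keras")
    # restored.optimizer carries the saved iteration count and slot
    # variables, so training resumes with the same optimizer state.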

@fchollet (Contributor)

Hi @kamathis4 -- are you still working on this PR? Should we take it over?

@adi-kmt (Contributor Author) commented Sep 21, 2023

> Hi @kamathis4 -- are you still working on this PR? Should we take it over?

Hey @fchollet, please do take over. I've been really busy with work and can't seem to make time for OSS.

@adi-kmt adi-kmt requested a review from fchollet September 21, 2023 03:32
@fchollet (Contributor)

Ok -- merging and patching up on top. Thanks for the contribution!

@fchollet fchollet merged commit 5c11fe6 into keras-team:main Sep 21, 2023