Backup and restore callback #701
Conversation
Thanks for the PR!
```python
        Get training state from temporary file and restore it
        """
        super().on_train_begin()
        self.set_model(keras_core.saving.load_model(filepath=self.backup_dir))
```
Is this the right approach? Should we rather take `self.model` and load its state?
Not able to directly use `self.model`, but `self._model` seems to be working.
Thanks for making the changes. I will review it today.
```python
@keras_core_export("keras_core.callbacks.BackupAndRestoreCallback")
class BackupAndRestoreCallback(Callback):
```
It's just `BackupAndRestore`.
```python
    `BackupAndRestore` callback of another training run,
    or by another callback
    (e.g. `ModelCheckpoint`) of the same training.
    save_freq: `'epoch'`, integer, or `False`. When set to `'epoch'`
```
Use double quotes (`"`) as the quote character.
```python
            raise ValueError("Empty `backup_dir` argument passed")
        self.file_path = file_path

        if (not save_freq) and (not save_before_preemption):
```
No need for parens here
```python
        self._model = keras_core.saving.load_model(filepath=self.file_path)

    def on_train_end(self, logs=None):
        if self.delete_checkpoint and \
```
Remove the `\` line breaks and let `black` handle the code formatting.
I think this is from an old file, don't see it in my code 😅
```python
            self._cleanup_checkpoint()

    def on_epoch_begin(self, epoch, logs=None):
        if self.delete_checkpoint and self._check_checkpoints_exists(
```
Why do we delete the checkpoint on epoch begin? If an interrupt happens during training, it will not be able to recover. Deletion of the checkpoint should only happen in `on_train_end`. I checked to confirm that this logic of deleting in `on_epoch_begin` is not in the `tf.keras` implementation.
The issue I was trying to address is that, according to the documentation, there should be only a single checkpoint at a time. But I understand your point: if the saving frequency is epoch and the model gets interrupted between epochs, it will not be able to recover. So I'm thinking of deleting it here, i.e. just before saving the weights, and then also deleting it at train end.
With `overwrite=True`, why do we need to delete before writing again? Deleting at `on_train_end` makes sense and would be consistent with `tf.keras`.
After removing the delete-checkpoint logic in `on_epoch_begin`, I am good with this PR.
done 🥸
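For reference, a minimal sketch of the agreed behaviour (illustrative only, not the exact PR code; the helper names `_check_checkpoints_exists` and `_cleanup_checkpoint` are borrowed from the quoted diff, and it assumes `save_weights` accepts `overwrite=True` as in current Keras): the checkpoint is overwritten while training and deleted only in `on_train_end`.

```python
import os

from keras_core.callbacks import Callback  # assuming keras_core's public Callback class


class BackupAndRestoreSketch(Callback):
    """Illustrative sketch: overwrite a single checkpoint during training,
    delete it only once training finishes normally."""

    def __init__(self, file_path, save_freq="epoch", delete_checkpoint=True):
        super().__init__()
        self.file_path = file_path
        self.save_freq = save_freq
        self.delete_checkpoint = delete_checkpoint

    def _check_checkpoints_exists(self, path):
        return os.path.exists(path)

    def _cleanup_checkpoint(self):
        os.remove(self.file_path)

    def on_epoch_end(self, epoch, logs=None):
        # `overwrite=True` keeps a single checkpoint without deleting mid-run,
        # so an interrupt between epochs can still recover from the last save.
        if self.save_freq == "epoch":
            self.model.save_weights(self.file_path, overwrite=True)

    def on_train_end(self, logs=None):
        # Deletion happens only here, matching the tf.keras behaviour.
        if self.delete_checkpoint and self._check_checkpoints_exists(self.file_path):
            self._cleanup_checkpoint()
```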
```python
    Args:
        file_path: String, path to store the checkpoint.
        e.g. `backup_dir = os.path.join(working_dir, "backup")`.
```
- Please use 4-space indents.
- We should stick to the original argument, `backup_dir`. Then we can use it to generate the filepath internally. Note that the filepath needs to end in `.weights.h5`.
@kamathis4 - Can you address this comment
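A quick sketch of that suggestion (illustrative; the `latest.weights.h5` filename and `working_dir` value are assumptions, not the PR's final choice): keep `backup_dir` as the public argument and derive the internal checkpoint path, which must end in `.weights.h5` for weights-only saving.

```python
import os

working_dir = "/tmp/training_run"                   # hypothetical working directory
backup_dir = os.path.join(working_dir, "backup")    # user-facing argument
os.makedirs(backup_dir, exist_ok=True)

# Internal path generated from `backup_dir`; weights-only files need the
# `.weights.h5` suffix.
file_path = os.path.join(backup_dir, "latest.weights.h5")
print(file_path)  # /tmp/training_run/backup/latest.weights.h5
```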
```python
...     def on_epoch_begin(self, epoch, logs=None):
...         if epoch == 4:
...             raise RuntimeError('Interrupting!')
>>> callback = keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
```
(This is what is being shown in this code example.)

```python
        """Returns the file path for checkpoint."""

        try:
            # `filepath` may contain placeholders such as
```
Since the filename will be autogenerated, you can just include `epoch` in the filename you pick -- e.g. `backup_at_epoch_{epoch}.weights.h5`.
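For example (a sketch of the reviewer's suggestion; how the placeholder gets filled at save time is an assumption, not the PR's actual code):

```python
import os

backup_dir = "/tmp/backup"
# Include the `{epoch}` placeholder directly in the autogenerated filename.
file_path_template = os.path.join(backup_dir, "backup_at_epoch_{epoch}.weights.h5")

# At save time the placeholder is filled with the current epoch number.
current_epoch = 4
print(file_path_template.format(epoch=current_epoch))
# /tmp/backup/backup_at_epoch_4.weights.h5
```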
```python
        Get training state from temporary file and restore it
        """
        if self._check_checkpoints_exists(self.file_path):
            self._model.load_weights(filepath=self.file_path)
```
We are not restoring the current epoch number, so the `model.fit` that happens after the restore re-trains for all the epochs. In the `tf.keras` implementation, it restores the current epoch and `model.fit` only trains for the remaining epochs. See the example at https://keras.io/api/callbacks/backup_and_restore/
Marking it resolved here. We will address it in a subsequent PR when we add the hook in `model` to set `initial_epoch`.
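For reference, the behaviour being discussed is the one shown on the linked keras.io page; roughly (a paraphrased sketch written against `tf.keras`, since that is the implementation that already restores the epoch; the tiny model and data are placeholders):

```python
import numpy as np
import tensorflow as tf


class InterruptingCallback(tf.keras.callbacks.Callback):
    def on_epoch_begin(self, epoch, logs=None):
        if epoch == 4:
            raise RuntimeError("Interrupting!")


model = tf.keras.Sequential([tf.keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
x, y = np.random.rand(32, 4), np.random.rand(32, 1)

backup = tf.keras.callbacks.BackupAndRestore(backup_dir="/tmp/backup")
try:
    model.fit(x, y, epochs=10, callbacks=[backup, InterruptingCallback()], verbose=0)
except RuntimeError:
    pass

# Because tf.keras restores the current epoch from the backup, this second
# `fit` resumes at epoch 4 and only runs the remaining 6 epochs.
history = model.fit(x, y, epochs=10, callbacks=[backup], verbose=0)
print(len(history.history["loss"]))  # 6
```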
```python
            x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
        )

        self.assertEqual(cbk._current_epoch, 4)
```
Can you also assert that `len(hist["loss"]) == 3` to confirm that it only ran for 3 additional epochs and not the entire 5 epochs again? I would expect that assertion to fail, since we don't seem to restore the epoch in `on_train_begin`.
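Something along these lines (a sketch that assumes the surrounding test from the quoted diff, i.e. `model`, `x_train`, `y_train`, and `cbk`; note that `fit` returns a `History` object, so the per-epoch losses live under `hist.history["loss"]`):

```python
        # Resumed run after the simulated interruption.
        hist = model.fit(
            x_train, y_train, batch_size=4, callbacks=[cbk], epochs=5
        )

        self.assertEqual(cbk._current_epoch, 4)
        # If the epoch were restored, only the remaining 3 epochs would be
        # recorded here instead of all 5.
        self.assertEqual(len(hist.history["loss"]), 3)
```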
@fchollet - What do you think we should do? Is there a good way we can set the `initial_epoch` of the model? Also, we are not restoring the optimizer state that `tf.keras` had restored when it saved it in the training checkpoint. Any suggestions?
Checked with Francois and addressed both questions:
- Model saving also includes the optimizer, so optimizer state is also restored when the backup is restored.
- On setting `initial_epoch`, we may need to add that hook on `model` before we can set it in the `on_train_begin` of `BackupAndRestore`. We will address that in a subsequent PR.

We can make this conversation resolved.
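A small sketch of why the first point holds (assuming the `keras_core` whole-model `.keras` format, which bundles optimizer variables for compiled models; the model, data, and path below are placeholders):

```python
import numpy as np
import keras_core as keras  # assuming the keras_core package this PR targets

model = keras.Sequential([keras.layers.Dense(1)])
model.compile(optimizer="adam", loss="mse")
model.fit(np.random.rand(8, 4), np.random.rand(8, 1), epochs=1, verbose=0)

# Whole-model saving writes optimizer variables (Adam moments, iteration
# counter), so restoring the backup also restores the optimizer state.
model.save("/tmp/backup_demo.keras")
restored = keras.saving.load_model("/tmp/backup_demo.keras")
print(restored.optimizer.iterations)  # optimizer state came back with the model
```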
Hi @kamathis4 -- are you still working on this PR? Should we take it over?
Hey @fchollet, please do take over. Have been really busy with work and don't seem to be able to make time for OSS.
Ok -- merging and patching up on top. Thanks for the contribution!
Based on the callback issue