Skip to content

Commit

Permalink
Update stateful_callbacks state before saving checkpoint (huggingface…
Browse files Browse the repository at this point in the history
…#32115)

* update ExportableState callbacks state before saving trainer_state on save_checkpoint

* run make fixup and fix format

* manage multiple stateful callbacks of same class
  • Loading branch information
pedrobrs authored and zucchini-nlp committed Aug 30, 2024
1 parent 714286e commit 3033dd0
Showing 1 changed file with 10 additions and 2 deletions.
12 changes: 10 additions & 2 deletions src/transformers/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2979,8 +2979,16 @@ def _save_checkpoint(self, model, trial, metrics=None):

# Save the Trainer state
if self.args.should_save:
# Update the `TrainerControl` state to where we are currently
self.state.stateful_callbacks["TrainerControl"] = self.control.state()
# Update `ExportableState` callbacks and `TrainerControl` state to where we are currently
for cb in [
cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState)
]:
cb_name = cb.__class__.__name__
cb_state = cb.state()
if isinstance(self.state.stateful_callbacks[cb_name], list):
self.state.stateful_callbacks[cb_name].append(cb_state)
else:
self.state.stateful_callbacks[cb_name] = cb_state
self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME))

if self.args.push_to_hub:
Expand Down

0 comments on commit 3033dd0

Please sign in to comment.