From 3033dd05d19fa91ea7c79e12042a8f98d545ed52 Mon Sep 17 00:00:00 2001 From: pedrobrs Date: Tue, 27 Aug 2024 09:33:35 -0300 Subject: [PATCH] Update stateful_callbacks state before saving checkpoint (#32115) * Update the state of `ExportableState` callbacks before saving trainer_state in `_save_checkpoint` * Run `make fixup` and fix formatting * Handle multiple stateful callbacks of the same class --- src/transformers/trainer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py index 2684b3f0ec0db1..5b34cf695783dd 100755 --- a/src/transformers/trainer.py +++ b/src/transformers/trainer.py @@ -2979,8 +2979,16 @@ def _save_checkpoint(self, model, trial, metrics=None): # Save the Trainer state if self.args.should_save: - # Update the `TrainerControl` state to where we are currently - self.state.stateful_callbacks["TrainerControl"] = self.control.state() + # Update `ExportableState` callbacks and `TrainerControl` state to where we are currently + for cb in [ + cb for cb in self.callback_handler.callbacks + [self.control] if isinstance(cb, ExportableState) + ]: + cb_name = cb.__class__.__name__ + cb_state = cb.state() + if isinstance(self.state.stateful_callbacks[cb_name], list): + self.state.stateful_callbacks[cb_name].append(cb_state) + else: + self.state.stateful_callbacks[cb_name] = cb_state self.state.save_to_json(os.path.join(output_dir, TRAINER_STATE_NAME)) if self.args.push_to_hub: