Skip to content

Commit effaef3

Browse files
fix: prevent second save in the end of training if last step was saved already (#36219)
* fix: prevent second save in the end of training * fix: prevent second save in the end of training * test: added test for no duplicate save on epoch save strategy * fix: removed TrainerControl * chore: style formatting --------- Co-authored-by: JaktensTid <jaktenstid1@gmail.com>
1 parent 5412ff1 commit effaef3

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

src/transformers/trainer_callback.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -600,7 +600,7 @@ def on_step_end(self, args: TrainingArguments, state: TrainerState, control: Tra
600600
if state.global_step >= state.max_steps:
601601
control.should_training_stop = True
602602
# Save the model at the end if we have a save strategy
603-
if args.save_strategy not in [SaveStrategy.NO, SaveStrategy.BEST]:
603+
if args.save_strategy == SaveStrategy.STEPS:
604604
control.should_save = True
605605

606606
return control

tests/trainer/test_trainer_callback.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -425,3 +425,21 @@ def test_stateful_control(self):
425425
trainer.state = TrainerState.load_from_json(os.path.join(checkpoint, TRAINER_STATE_NAME))
426426
trainer._load_callback_state()
427427
assert trainer.control.should_training_stop
428+
429+
def test_no_duplicate_save_on_epoch_save_strategy(self):
430+
times_saved = 0
431+
432+
class OnEndCallback(TrainerCallback):
433+
def on_step_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
434+
nonlocal times_saved
435+
if control.should_save:
436+
times_saved += 1
437+
438+
def on_epoch_end(self, args: TrainingArguments, state: TrainerState, control, **kwargs):
439+
nonlocal times_saved
440+
if control.should_save:
441+
times_saved += 1
442+
443+
trainer = self.get_trainer(max_steps=2, save_strategy="epoch", callbacks=[OnEndCallback])
444+
trainer.train()
445+
assert times_saved == 1

0 commit comments

Comments
 (0)