
Commit 1c7e5e2

fix fsdp checkpointing issues (#24926)
* fix fsdp load
* Update trainer.py
* remove saving duplicate state_dict
1 parent 9ef5256 commit 1c7e5e2

File tree

1 file changed: +11 −7 lines changed


src/transformers/trainer.py

Lines changed: 11 additions & 7 deletions
@@ -2115,7 +2115,7 @@ def _load_best_model(self):
                 state_dict["_smp_is_partial"] = False
                 load_result = model.load_state_dict(state_dict, strict=True)
             elif self.is_fsdp_enabled:
-                load_fsdp_model(
+                load_result = load_fsdp_model(
                     self.accelerator.state.fsdp_plugin, self.accelerator, model, self.state.best_model_checkpoint
                 )
             else:
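
The change in `_load_best_model` captures the return value of `load_fsdp_model`, matching the other branches, which already keep `load_result` so load problems can be reported. A minimal standalone sketch (not Trainer code) of why that value is worth keeping; it assumes the loader returns the same NamedTuple as `torch.nn.Module.load_state_dict`, with `missing_keys` and `unexpected_keys`:

import logging

logger = logging.getLogger(__name__)

def report_load_result(load_result) -> None:
    # Assumes `load_result` mirrors torch.nn.Module.load_state_dict's
    # return type: a NamedTuple with `missing_keys` and `unexpected_keys`.
    if load_result is None:
        logger.warning("Loader returned no result; cannot verify the load.")
        return
    if load_result.missing_keys:
        logger.warning("Missing keys while loading best model: %s", load_result.missing_keys)
    if load_result.unexpected_keys:
        logger.warning("Unexpected keys while loading best model: %s", load_result.unexpected_keys)

Silently discarding the result, as the old line did, hides partially loaded best checkpoints.
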
@@ -2298,6 +2298,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
                 # Needs to be called on all ranks to gather all states.
                 # full_optim_state_dict will be deprecated after Pytorch 2.2!
                 full_osd = self.model.__class__.full_optim_state_dict(self.model, self.optimizer)
+                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))

         if is_torch_tpu_available():
             xm.rendezvous("saving_optimizer_states")
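
This hunk writes the gathered optimizer state to disk right where `full_optim_state_dict` produces it, instead of relying on the generic save branch further down (which the next hunk stops using for FSDP). A hedged sketch of the underlying PyTorch pattern; the rank-0 guard and the `optimizer.pt` filename (the assumed value of `OPTIMIZER_NAME`) are my additions:

import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def save_full_optimizer_state(model: FSDP, optimizer, output_dir: str) -> None:
    # Must be called on every rank: it all-gathers the sharded optimizer
    # state. Per the comment in the diff, this API is slated for
    # deprecation after PyTorch 2.2 in favor of the state_dict_type APIs.
    full_osd = FSDP.full_optim_state_dict(model, optimizer)
    # By default only rank 0 receives the full dict, so only rank 0 writes.
    if not dist.is_initialized() or dist.get_rank() == 0:
        torch.save(full_osd, os.path.join(output_dir, "optimizer.pt"))
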
@@ -2321,12 +2322,9 @@ def _save_checkpoint(self, model, trial, metrics=None):
             reissue_pt_warnings(caught_warnings)
         if self.do_grad_scaling:
             torch.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
-        elif self.args.should_save and not self.is_deepspeed_enabled:
+        elif self.args.should_save and not self.is_deepspeed_enabled and not (self.fsdp or self.is_fsdp_enabled):
             # deepspeed.save_checkpoint above saves model/optim/sched
-            if self.fsdp and not self.is_fsdp_enabled:
-                torch.save(full_osd, os.path.join(output_dir, OPTIMIZER_NAME))
-            else:
-                torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            torch.save(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

         with warnings.catch_warnings(record=True) as caught_warnings:
             torch.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
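
With the optimizer state now saved inside the FSDP-specific branch above, the fallback `torch.save` would write a second, redundant copy; the new condition skips it for both the legacy `self.fsdp` path and the accelerate-backed `self.is_fsdp_enabled` path, and the old `full_osd` special case becomes dead code. The resulting decision, restated as a standalone sketch (the function and its parameters are mine, mirroring the attributes in the diff):

def saves_plain_optimizer_state(should_save: bool, is_deepspeed_enabled: bool,
                                fsdp: object, is_fsdp_enabled: bool) -> bool:
    # DeepSpeed checkpoints model/optim/sched itself, and the FSDP branches
    # above already wrote the gathered optimizer state, so the plain
    # torch.save only runs for ordinary (non-sharded) training.
    return should_save and not is_deepspeed_enabled and not (fsdp or is_fsdp_enabled)
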
@@ -2731,10 +2729,16 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
             or self.fsdp is not None
             or self.is_fsdp_enabled
         ):
-            state_dict = self.model.state_dict()
+            state_dict = self.model.state_dict() if not self.is_fsdp_enabled else {}
             if self.args.should_save:
                 self._save(output_dir, state_dict=state_dict)
             if self.is_fsdp_enabled:
+                # remove the dummy state_dict saved above
+                if self.args.should_save:
+                    for filename in [WEIGHTS_NAME, SAFE_WEIGHTS_NAME]:
+                        file = os.path.join(output_dir, filename)
+                        if os.path.isfile(file):
+                            os.remove(file)
                 save_fsdp_model(self.accelerator.state.fsdp_plugin, self.accelerator, self.model, output_dir)

         elif self.is_deepspeed_enabled:
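
When accelerate-backed FSDP is enabled, materializing `self.model.state_dict()` would give `_save` a state dict to write out, duplicating the checkpoint that `save_fsdp_model` writes immediately afterwards (the "duplicate state_dict" in the commit message). Passing an empty dict lets `_save` still emit config and tokenizer files, at the cost of also writing an essentially empty weights file, which the new block deletes. A hedged standalone sketch of that cleanup; the constant values are assumed to match `transformers.utils` ("pytorch_model.bin" and "model.safetensors"):

import os

# Assumed values of the constants referenced in the diff.
WEIGHTS_NAME = "pytorch_model.bin"
SAFE_WEIGHTS_NAME = "model.safetensors"

def remove_dummy_weight_files(output_dir: str) -> None:
    # _save(output_dir, state_dict={}) writes config/tokenizer files but
    # also a dummy weights file; drop it so only the FSDP checkpoint
    # written afterwards remains in the output directory.
    for filename in (WEIGHTS_NAME, SAFE_WEIGHTS_NAME):
        path = os.path.join(output_dir, filename)
        if os.path.isfile(path):
            os.remove(path)
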
