@@ -2115,7 +2115,7 @@ def _load_best_model(self):
21152115 state_dict ["_smp_is_partial" ] = False
21162116 load_result = model .load_state_dict (state_dict , strict = True )
21172117 elif self .is_fsdp_enabled :
2118- load_fsdp_model (
2118+ load_result = load_fsdp_model (
21192119 self .accelerator .state .fsdp_plugin , self .accelerator , model , self .state .best_model_checkpoint
21202120 )
21212121 else :
@@ -2298,6 +2298,7 @@ def _save_checkpoint(self, model, trial, metrics=None):
22982298 # Needs to be called on all ranks to gather all states.
22992299 # full_optim_state_dict will be deprecated after Pytorch 2.2!
23002300 full_osd = self .model .__class__ .full_optim_state_dict (self .model , self .optimizer )
2301+ torch .save (full_osd , os .path .join (output_dir , OPTIMIZER_NAME ))
23012302
23022303 if is_torch_tpu_available ():
23032304 xm .rendezvous ("saving_optimizer_states" )
@@ -2321,12 +2322,9 @@ def _save_checkpoint(self, model, trial, metrics=None):
23212322 reissue_pt_warnings (caught_warnings )
23222323 if self .do_grad_scaling :
23232324 torch .save (self .scaler .state_dict (), os .path .join (output_dir , SCALER_NAME ))
2324- elif self .args .should_save and not self .is_deepspeed_enabled :
2325+ elif self .args .should_save and not self .is_deepspeed_enabled and not ( self . fsdp or self . is_fsdp_enabled ) :
23252326 # deepspeed.save_checkpoint above saves model/optim/sched
2326- if self .fsdp and not self .is_fsdp_enabled :
2327- torch .save (full_osd , os .path .join (output_dir , OPTIMIZER_NAME ))
2328- else :
2329- torch .save (self .optimizer .state_dict (), os .path .join (output_dir , OPTIMIZER_NAME ))
2327+ torch .save (self .optimizer .state_dict (), os .path .join (output_dir , OPTIMIZER_NAME ))
23302328
23312329 with warnings .catch_warnings (record = True ) as caught_warnings :
23322330 torch .save (self .lr_scheduler .state_dict (), os .path .join (output_dir , SCHEDULER_NAME ))
@@ -2731,10 +2729,16 @@ def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = Fa
27312729 or self .fsdp is not None
27322730 or self .is_fsdp_enabled
27332731 ):
2734- state_dict = self .model .state_dict ()
2732+ state_dict = self .model .state_dict () if not self . is_fsdp_enabled else {}
27352733 if self .args .should_save :
27362734 self ._save (output_dir , state_dict = state_dict )
27372735 if self .is_fsdp_enabled :
2736+ # remove the dummy state_dict saved above
2737+ if self .args .should_save :
2738+ for filename in [WEIGHTS_NAME , SAFE_WEIGHTS_NAME ]:
2739+ file = os .path .join (output_dir , filename )
2740+ if os .path .isfile (file ):
2741+ os .remove (file )
27382742 save_fsdp_model (self .accelerator .state .fsdp_plugin , self .accelerator , self .model , output_dir )
27392743
27402744 elif self .is_deepspeed_enabled :
0 commit comments