update async_save_info
DesmonDay committed Sep 23, 2024
1 parent 7deb33c commit c44d385
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions paddlenlp/trainer/trainer.py
@@ -2236,16 +2236,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
self.model_wrapped.get_all_parameters(convert2cpu=True)

if self.args.should_save_model_state:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config for not train stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self._save(output_dir=output_dir, merge_tensor_parallel=merge_tensor_parallel)

# recover unified_checkpoint_config for not train stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
os.makedirs(output_dir, exist_ok=True)
@@ -2523,10 +2514,9 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
# Save a trained model and configuration using `save_pretrained()`.
# They can then be reloaded using `from_pretrained()`

local_rank = int(os.getenv("PADDLE_RANK_IN_NODE", 0))
if (
strtobool(os.getenv("FLAG_LLM_PDC", "False"))
and local_rank == 0
and paddle.distributed.get_rank() == 0
and self.args.unified_checkpoint
and "async_save" in self.args.unified_checkpoint_config
):
@@ -2537,9 +2527,10 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
"ignore_save_lr_and_optim": self.args.ignore_save_lr_and_optim,
"skip_save_model_weight": "skip_save_model_weight" in self.args.unified_checkpoint_config,
}
if not os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")):
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)
if os.path.exists(os.path.join(self.args.logging_dir, "async_save_info.json")): # afs cannot overwrite
os.remove(os.path.join(self.args.logging_dir, "async_save_info.json"))
with open(os.path.join(self.args.logging_dir, "async_save_info.json"), "w") as f:
json.dump(save_info, f)

if self.args.should_save:
if self.tokenizer is not None:
@@ -2548,7 +2539,17 @@ def _save(self, output_dir: Optional[str] = None, state_dict=None, merge_tensor_
paddle.save(self.args, os.path.join(output_dir, TRAINING_ARGS_NAME))

if self.args.unified_checkpoint:
unified_checkpoint_config_backup = self.args.unified_checkpoint_config
# backup and remove unified_checkpoint_config for not train stage
if not self.is_in_train:
self.args.unified_checkpoint_config = []

self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)

# recover unified_checkpoint_config for not train stage
if not self.is_in_train:
self.args.unified_checkpoint_config = unified_checkpoint_config_backup

return

merge_tensor_parallel = merge_tensor_parallel and self.args.use_hybrid_parallel
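For context, a minimal sketch of what the second and third hunks change about how `async_save_info.json` is written: the old code skipped the write when the file already existed, while the new code deletes any stale copy first and always rewrites it (the inline comment notes that AFS cannot overwrite a file in place). The helper names `write_async_save_info` / `write_async_save_info_old` and the `logging_dir` / `save_info` parameters below are illustrative placeholders, not the trainer's actual API.

```python
import json
import os


def write_async_save_info(logging_dir: str, save_info: dict) -> None:
    """Sketch of the new behavior: always write a fresh async_save_info.json.

    AFS-style filesystems may not support overwriting an existing file,
    so any stale copy is removed before the new one is written.
    """
    path = os.path.join(logging_dir, "async_save_info.json")
    if os.path.exists(path):  # afs cannot overwrite
        os.remove(path)
    with open(path, "w") as f:
        json.dump(save_info, f)


def write_async_save_info_old(logging_dir: str, save_info: dict) -> None:
    """The previous behavior, for comparison: written once, never refreshed."""
    path = os.path.join(logging_dir, "async_save_info.json")
    if not os.path.exists(path):
        with open(path, "w") as f:
            json.dump(save_info, f)
```

The practical effect, presumably, is that each save refreshes the async-save metadata instead of silently reusing whatever file a previous run left behind.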

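The first and last hunks move the temporary clearing of `unified_checkpoint_config` out of `save_model` and into `_save`, so the backup/restore now wraps only the `save_unified_checkpoint` call. The same pattern could be expressed as a context manager; the sketch below is purely illustrative, under the assumption of an `args` object with a `unified_checkpoint_config` attribute, and is not part of the PaddleNLP API.

```python
from contextlib import contextmanager


@contextmanager
def cleared_unified_checkpoint_config(args, in_train: bool):
    """Temporarily clear args.unified_checkpoint_config when not in the train stage."""
    backup = args.unified_checkpoint_config
    if not in_train:
        args.unified_checkpoint_config = []
    try:
        yield
    finally:
        # restore the original config even if the save raises
        if not in_train:
            args.unified_checkpoint_config = backup


# Hypothetical usage mirroring the lines added in _save:
# with cleared_unified_checkpoint_config(self.args, self.is_in_train):
#     self.unified_checkpoint_handler.save_unified_checkpoint(self.model, self.optimizer, output_dir)
```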