From 93a9b2cd8db0dfec13acc4b436d67ae236c53e9c Mon Sep 17 00:00:00 2001
From: Siming Dai <908660116@qq.com>
Date: Thu, 12 Sep 2024 17:49:21 +0800
Subject: [PATCH] [Unified Checkpoint] Fix uc lora config, fix release_grads
 (#9082)

* [Unified checkpoint] update optimizer async save signal

---------

Co-authored-by: gongenlei
---
 paddlenlp/trainer/plugins/unified_checkpoint.py | 11 +++++++++--
 paddlenlp/trainer/trainer.py                    | 11 +++++++++++
 paddlenlp/trainer/training_args.py              |  2 +-
 3 files changed, 21 insertions(+), 3 deletions(-)

diff --git a/paddlenlp/trainer/plugins/unified_checkpoint.py b/paddlenlp/trainer/plugins/unified_checkpoint.py
index 56183485cca8..f35b23f95050 100644
--- a/paddlenlp/trainer/plugins/unified_checkpoint.py
+++ b/paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -349,7 +349,10 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
         # save the config
         config_to_save = save_config(model_to_save)
         # Attach architecture to the config
-        config_to_save.architectures = [model_to_save.__class__.__name__]
+        if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
+            config_to_save.architectures = [model_to_save.model.__class__.__name__]
+        else:
+            config_to_save.architectures = [model_to_save.__class__.__name__]
         if self.args.should_save:
             config_to_save.save_pretrained(save_directory)
         paddle.device.cuda.empty_cache()
@@ -560,7 +563,11 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
             optim_state_dict = load_single_card_optimizer(self.args, model, optimizer, resume_from_checkpoint)
             return optim_state_dict
 
-        if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
+        has_merge_optimizer_safetensors = distributed_isfile(
+            os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME)
+        )
+        # If not having merge optimizer, then load non-merge optimizer.
+        if not has_merge_optimizer_safetensors:
             if self.args.data_parallel_rank == 0:
                 returned_optim_state_dict = self.load_non_merge_optimizer(
                     model,
diff --git a/paddlenlp/trainer/trainer.py b/paddlenlp/trainer/trainer.py
index 738e73c53892..0ddad3cbf985 100644
--- a/paddlenlp/trainer/trainer.py
+++ b/paddlenlp/trainer/trainer.py
@@ -2241,6 +2241,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
                 self.args.unified_checkpoint_config = unified_checkpoint_config_backup
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
+                    os.makedirs(output_dir, exist_ok=True)
                     if self.is_in_train:
                         global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
                         paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
@@ -2259,6 +2260,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             and "async_save" in self.args.unified_checkpoint_config
             and not self.is_in_train
         ):
+            os.makedirs(output_dir, exist_ok=True)
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
 
@@ -2336,6 +2338,7 @@ def _save_checkpoint(self, model, metrics=None):
             else:
                 if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+                    os.makedirs(output_dir, exist_ok=True)
                     paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
                     if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
                         paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
@@ -2367,6 +2370,14 @@ def _save_checkpoint(self, model, metrics=None):
 
             if self.do_grad_scaling:
                 paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+        else:
+            if self.args.unified_checkpoint and not self.args.use_hybrid_parallel:
+                if "async_save" in self.args.unified_checkpoint_config:
+                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+                    os.makedirs(output_dir, exist_ok=True)
+                    paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
+                    if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+                        paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
 
         self.runtime_timer.stop()
         # Determine the new best metric / best model checkpoint
diff --git a/paddlenlp/trainer/training_args.py b/paddlenlp/trainer/training_args.py
index 3ce797341b37..e38a528a7775 100644
--- a/paddlenlp/trainer/training_args.py
+++ b/paddlenlp/trainer/training_args.py
@@ -1133,7 +1133,7 @@ def split_parallel_config(parallel_config):
                     "dp_comm_overlap": enable_dp_comm_overlap,
                     "sharding_comm_overlap": enable_sharding_comm_overlap,
                     "enable_timer": "enable_timer" in pipeline_parallel_config,
-                    "release_gradients": "enable_release_grads" in pipeline_parallel_config,
+                    "release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
                     "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                     "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
                     "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
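
For reference, the async-save hunks above only write per-rank signal files such as ".optimizer_weight.done.{global_rank}"; how those markers are consumed is outside this patch. Below is a minimal, hypothetical sketch of a consumer that waits for them. The helper name wait_for_async_save_signals, the polling interval, and the timeout are assumptions for illustration and are not PaddleNLP APIs.

import os
import time


def wait_for_async_save_signals(output_dir, world_size, prefix=".optimizer_weight.done", timeout=600.0):
    """Poll until every expected rank has written its done-marker, or give up after `timeout` seconds."""
    # The patched trainer writes rank -1 when the world size is 1, and one marker per rank otherwise.
    ranks = [-1] if world_size <= 1 else list(range(world_size))
    expected = [os.path.join(output_dir, f"{prefix}.{rank}") for rank in ranks]
    deadline = time.time() + timeout
    while time.time() < deadline:
        if all(os.path.exists(path) for path in expected):
            return True
        time.sleep(1.0)
    return False


# Example usage (assumed checkpoint directory): block until all 8 ranks have signalled.
# wait_for_async_save_signals("output/checkpoint-500", world_size=8)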