[Unified Checkpoint] Fix uc lora config, fix release_grads (#9082)
* [Unified checkpoint] update optimizer async save signal
---------

Co-authored-by: gongenlei <gongenlei@baidu.com>
DesmonDay and gongel authored Sep 12, 2024
1 parent d3302c5 commit 93a9b2c
Showing 3 changed files with 21 additions and 3 deletions.
11 changes: 9 additions & 2 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
@@ -349,7 +349,10 @@ def save_unified_checkpoint(self, model, optimizer, output_dir):
         # save the config
         config_to_save = save_config(model_to_save)
         # Attach architecture to the config
-        config_to_save.architectures = [model_to_save.__class__.__name__]
+        if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
+            config_to_save.architectures = [model_to_save.model.__class__.__name__]
+        else:
+            config_to_save.architectures = [model_to_save.__class__.__name__]
         if self.args.should_save:
             config_to_save.save_pretrained(save_directory)
         paddle.device.cuda.empty_cache()
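
Background for this hunk: `LoRAModel` and `PrefixModelForCausalLM` are wrappers that keep the underlying network in their `.model` attribute, so writing the wrapper's class name into `config.architectures` misidentifies the checkpoint; the fix records the wrapped base model's class instead. A minimal illustrative sketch of that distinction, with hypothetical class names standing in for the real wrappers:

```python
# Hypothetical classes; they only mimic the wrapper/.model relationship
# that LoRAModel and PrefixModelForCausalLM expose in the diff above.
class TinyBaseModel:
    pass

class TinyLoRAWrapper:
    def __init__(self, model):
        self.model = model  # wrapped base model

wrapped = TinyLoRAWrapper(TinyBaseModel())

# Old behaviour: the wrapper's own class name would land in config.architectures.
assert wrapped.__class__.__name__ == "TinyLoRAWrapper"

# New behaviour for peft-style wrappers: record the base model's architecture.
architectures = [wrapped.model.__class__.__name__]
assert architectures == ["TinyBaseModel"]
```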
@@ -560,7 +563,11 @@ def load_unified_optimizer(self, args, model, optimizer, resume_from_checkpoint)
             optim_state_dict = load_single_card_optimizer(self.args, model, optimizer, resume_from_checkpoint)
             return optim_state_dict
 
-        if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
+        has_merge_optimizer_safetensors = distributed_isfile(
+            os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME)
+        )
+        # If not having merge optimizer, then load non-merge optimizer.
+        if not has_merge_optimizer_safetensors:
             if self.args.data_parallel_rank == 0:
                 returned_optim_state_dict = self.load_non_merge_optimizer(
                     model,
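The second hunk swaps the `ignore_merge_optimizer` config switch for a file-existence probe: the loader checks whether the merged safetensors optimizer index (the file named by `SAFE_OPTIMIZER_INDEX_NAME`) exists in the checkpoint directory and, if not, falls back to the non-merge optimizer path. A rough single-process sketch of that decision, using `os.path.isfile` in place of the `distributed_isfile` helper and an assumed index filename:

```python
import os

# Assumed constant value for illustration; the real name comes from
# paddlenlp's SAFE_OPTIMIZER_INDEX_NAME.
SAFE_OPTIMIZER_INDEX_NAME = "optimizer.safetensors.index.json"

def choose_optimizer_load_path(resume_from_checkpoint: str) -> str:
    """Return which optimizer-loading branch the checkpoint supports."""
    has_merge_optimizer_safetensors = os.path.isfile(
        os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME)
    )
    # No merged index file -> load the per-rank (non-merge) optimizer shards.
    return "merge" if has_merge_optimizer_safetensors else "non_merge"
```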
11 changes: 11 additions & 0 deletions paddlenlp/trainer/trainer.py
@@ -2241,6 +2241,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
                 self.args.unified_checkpoint_config = unified_checkpoint_config_backup
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
+                os.makedirs(output_dir, exist_ok=True)
                 if self.is_in_train:
                     global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
                     paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
@@ -2259,6 +2260,7 @@ def save_model(self, output_dir: Optional[str] = None, merge_tensor_parallel: Op
             and "async_save" in self.args.unified_checkpoint_config
             and not self.is_in_train
         ):
+            os.makedirs(output_dir, exist_ok=True)
             global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
             paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}"))

@@ -2336,6 +2338,7 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
                 global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+                os.makedirs(output_dir, exist_ok=True)
                 paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
                 if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
                     paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
@@ -2367,6 +2370,14 @@ def _save_checkpoint(self, model, metrics=None):
 
             if self.do_grad_scaling:
                 paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+        else:
+            if self.args.unified_checkpoint and not self.args.use_hybrid_parallel:
+                if "async_save" in self.args.unified_checkpoint_config:
+                    global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
+                    os.makedirs(output_dir, exist_ok=True)
+                    paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
+                    if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
+                        paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
 
         self.runtime_timer.stop()
         # Determine the new best metric / best model checkpoint
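All of the trainer.py hunks serve the async_save signal path: the output directory is now created with `os.makedirs(..., exist_ok=True)` before each per-rank `.*.done.{global_rank}` marker is written, and a new branch emits the `.optimizer_weight.done.*` and `.master_weight.done.*` markers when unified checkpoint runs without hybrid parallelism. A simplified sketch of writing one such marker, using a hypothetical `write_done_marker` helper that mirrors the pattern in the diff:

```python
import os

import paddle

def write_done_marker(output_dir: str, name: str) -> None:
    """Drop a per-rank signal file (e.g. name="optimizer_weight") so the
    async checkpoint saver can tell this rank has finished writing."""
    world_size = paddle.distributed.get_world_size()
    global_rank = paddle.distributed.get_rank() if world_size > 1 else -1
    # The fix in this commit: ensure the directory exists before paddle.save.
    os.makedirs(output_dir, exist_ok=True)
    paddle.save(global_rank, os.path.join(output_dir, f".{name}.done.{global_rank}"))
```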
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
@@ -1133,7 +1133,7 @@ def split_parallel_config(parallel_config):
                     "dp_comm_overlap": enable_dp_comm_overlap,
                     "sharding_comm_overlap": enable_sharding_comm_overlap,
                     "enable_timer": "enable_timer" in pipeline_parallel_config,
-                    "release_gradients": "enable_release_grads" in pipeline_parallel_config,
+                    "release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
                     "overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
                     "clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
                     "use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
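This last change lets the standalone `release_grads` training argument enable gradient release under pipeline parallelism, in addition to the existing `enable_release_grads` token in `pipeline_parallel_config`. A trivial sketch of the combined condition, written as a hypothetical helper:

```python
def resolve_release_gradients(pipeline_parallel_config: str, release_grads: bool) -> bool:
    # Enabled either by the config string token or by the standalone flag
    # (the `or self.release_grads` part added in this commit).
    return "enable_release_grads" in pipeline_parallel_config or release_grads

assert resolve_release_gradients("enable_timer enable_release_grads", False)
assert resolve_release_gradients("", True)
assert not resolve_release_gradients("", False)
```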
