Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Unified Checkpoint] Fix uc lora config, fix release_grads #9082

Merged
merged 6 commits into the base branch from the contributor's branch on
Sep 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions paddlenlp/trainer/plugins/unified_checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,10 @@
# save the config
config_to_save = save_config(model_to_save)
# Attach architecture to the config
config_to_save.architectures = [model_to_save.__class__.__name__]
if isinstance(model_to_save, LoRAModel) or isinstance(model_to_save, PrefixModelForCausalLM):
config_to_save.architectures = [model_to_save.model.__class__.__name__]

Check warning on line 353 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L352-L353

Added lines #L352 - L353 were not covered by tests
else:
config_to_save.architectures = [model_to_save.__class__.__name__]

Check warning on line 355 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L355

Added line #L355 was not covered by tests
if self.args.should_save:
config_to_save.save_pretrained(save_directory)
paddle.device.cuda.empty_cache()
Expand Down Expand Up @@ -560,7 +563,11 @@
optim_state_dict = load_single_card_optimizer(self.args, model, optimizer, resume_from_checkpoint)
return optim_state_dict

if "ignore_merge_optimizer" in self.args.unified_checkpoint_config:
has_merge_optimizer_safetensors = distributed_isfile(

Check warning on line 566 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L566

Added line #L566 was not covered by tests
os.path.join(resume_from_checkpoint, SAFE_OPTIMIZER_INDEX_NAME)
)
# If not having merge optimizer, then load non-merge optimizer.
if not has_merge_optimizer_safetensors:

Check warning on line 570 in paddlenlp/trainer/plugins/unified_checkpoint.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/plugins/unified_checkpoint.py#L570

Added line #L570 was not covered by tests
if self.args.data_parallel_rank == 0:
returned_optim_state_dict = self.load_non_merge_optimizer(
model,
Expand Down
11 changes: 11 additions & 0 deletions paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2241,6 +2241,7 @@
self.args.unified_checkpoint_config = unified_checkpoint_config_backup
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
os.makedirs(output_dir, exist_ok=True)

Check warning on line 2244 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2244

Added line #L2244 was not covered by tests
if self.is_in_train:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
paddle.save(global_rank, os.path.join(output_dir, f".model_weight.done.{global_rank}"))
Expand All @@ -2259,6 +2260,7 @@
and "async_save" in self.args.unified_checkpoint_config
and not self.is_in_train
):
os.makedirs(output_dir, exist_ok=True)

Check warning on line 2263 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2263

Added line #L2263 was not covered by tests
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
paddle.save(self.state.global_step, os.path.join(output_dir, f".model_weight.done.{global_rank}"))

Expand Down Expand Up @@ -2336,6 +2338,7 @@
else:
if self.args.unified_checkpoint and "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
os.makedirs(output_dir, exist_ok=True)

Check warning on line 2341 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2341

Added line #L2341 was not covered by tests
paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))
Expand Down Expand Up @@ -2367,6 +2370,14 @@

if self.do_grad_scaling:
paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
else:
if self.args.unified_checkpoint and not self.args.use_hybrid_parallel:
if "async_save" in self.args.unified_checkpoint_config:
global_rank = paddle.distributed.get_rank() if paddle.distributed.get_world_size() > 1 else -1
os.makedirs(output_dir, exist_ok=True)
paddle.save(global_rank, os.path.join(output_dir, f".optimizer_weight.done.{global_rank}"))
if "skip_save_model_weight" not in self.args.unified_checkpoint_config:
paddle.save(global_rank, os.path.join(output_dir, f".master_weight.done.{global_rank}"))

Check warning on line 2380 in paddlenlp/trainer/trainer.py

View check run for this annotation

Codecov / codecov/patch

paddlenlp/trainer/trainer.py#L2374-L2380

Added lines #L2374 - L2380 were not covered by tests

self.runtime_timer.stop()
# Determine the new best metric / best model checkpoint
Expand Down
2 changes: 1 addition & 1 deletion paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,7 +1133,7 @@ def split_parallel_config(parallel_config):
"dp_comm_overlap": enable_dp_comm_overlap,
"sharding_comm_overlap": enable_sharding_comm_overlap,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config or self.release_grads,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
"clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
"use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
Expand Down
Loading