[Trainer] ignore_save_lr_and_optim (#7978)
* ignore_save_lr_and_optim

* Only enable ignore_save_lr_and_optim for the finetune tests
JunnYu authored Feb 22, 2024
1 parent b2be2fc commit 2edcd08
Showing 5 changed files with 36 additions and 26 deletions.
54 changes: 28 additions & 26 deletions paddlenlp/trainer/trainer.py
@@ -2100,36 +2100,38 @@ def _save_checkpoint(self, model, metrics=None):
         else:
             self.save_model(output_dir)

-        optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)
+        # only save model state dict, ignore optimizer and scheduler
+        if not self.args.ignore_save_lr_and_optim:
+            optimizer_name = _add_variant(OPTIMIZER_NAME, self.args.optimizer_name_suffix)

-        if self.args.use_hybrid_parallel:
-            if self.dp_group.rank <= 0:
-                os.makedirs(output_dir, exist_ok=True)
-                logger.info("Saving optimizer files.")
-                if self.args.unified_checkpoint:
-                    save_unified_optimizer(
-                        self.args,
-                        self.model,
-                        self.optimizer,
-                        output_dir,
-                        safe_serialization=True,
-                    )
-                else:
-                    self._save_ckpt_func(
-                        self.optimizer.state_dict(),
-                        os.path.join(output_dir, optimizer_name),
-                    )
+            if self.args.use_hybrid_parallel:
+                if self.dp_group.rank <= 0:
+                    os.makedirs(output_dir, exist_ok=True)
+                    logger.info("Saving optimizer files.")
+                    if self.args.unified_checkpoint:
+                        save_unified_optimizer(
+                            self.args,
+                            self.model,
+                            self.optimizer,
+                            output_dir,
+                            safe_serialization=True,
+                        )
+                    else:
+                        self._save_ckpt_func(
+                            self.optimizer.state_dict(),
+                            os.path.join(output_dir, optimizer_name),
+                        )

-        if self.args.should_save:
-            if not self.args.use_hybrid_parallel:
-                logger.info("Saving optimizer files.")
-                self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))
+            if self.args.should_save:
+                if not self.args.use_hybrid_parallel:
+                    logger.info("Saving optimizer files.")
+                    self._save_ckpt_func(self.optimizer.state_dict(), os.path.join(output_dir, OPTIMIZER_NAME))

-            # FIXME: maybe only save one copy
-            paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))
+                # FIXME: maybe only save one copy
+                paddle.save(self.lr_scheduler.state_dict(), os.path.join(output_dir, SCHEDULER_NAME))

-            if self.do_grad_scaling:
-                paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))
+                if self.do_grad_scaling:
+                    paddle.save(self.scaler.state_dict(), os.path.join(output_dir, SCALER_NAME))

         self.runtime_timer.stop()
         # Determine the new best metric / best model checkpoint

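In effect, the new guard wraps the entire optimizer/scheduler/scaler branch of _save_checkpoint: the model state dict is still written on every checkpoint, while the rest is skipped when the flag is set. A minimal sketch of the resulting control flow, with hypothetical helper names standing in for the real save calls (not the actual Trainer code):

    # Simplified sketch of the checkpoint logic after this change (illustrative only).
    def save_checkpoint_sketch(args, save_model, save_optimizer_and_scheduler):
        save_model()  # model state dict is always saved
        if not args.ignore_save_lr_and_optim:
            # optimizer state, lr scheduler state and (when grad scaling is on) the scaler
            save_optimizer_and_scheduler()
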
4 changes: 4 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -751,6 +751,10 @@ class TrainingArguments:
         default=False,
         metadata={"help": "whether to ignore load optimizer and scheduler."},
     )
+    ignore_save_lr_and_optim: Optional[bool] = field(
+        default=False,
+        metadata={"help": "whether to ignore save optimizer and scheduler."},
+    )
     force_reshard_pp: Optional[bool] = field(
         default=False,
         metadata={"help": "reshard pp even if pp degree in the model and pp degree in script match"},

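For reference, a minimal usage sketch of the new argument, assuming it is set like any other TrainingArguments field; the output_dir value is illustrative:

    # Hypothetical usage: checkpoints will contain only model weights,
    # with optimizer and lr scheduler state omitted.
    from paddlenlp.trainer import TrainingArguments

    args = TrainingArguments(
        output_dir="./checkpoints",     # illustrative path
        ignore_save_lr_and_optim=True,  # flag added by this commit
    )
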
2 changes: 2 additions & 0 deletions tests/fixtures/llm/finetune.yaml
@@ -26,6 +26,8 @@ finetune:
     save_total_limit: 1
     tensor_parallel_degree: 1
     pipeline_parallel_degree: 1
+    ignore_save_lr_and_optim: 1
+
   default:
     llama:
       model_name_or_path: __internal_testing__/tiny-random-llama

1 change: 1 addition & 0 deletions tests/fixtures/llm/pretrain.yaml
@@ -18,6 +18,7 @@ pretrain:
     use_flash_attention: 0
     use_fused_rms_norm: 0
     continue_training: 1
+
   default:
     llama:
       model_name_or_path: __internal_testing__/tiny-random-llama

1 change: 1 addition & 0 deletions tests/fixtures/llm/ptq.yaml
@@ -12,6 +12,7 @@ ptq:
     eval_with_do_generation: false
     do_ptq: true
     ptq_step: 4
+
   default:
     llama:
       model_name_or_path: __internal_testing__/tiny-fused-llama-inference5.2

