Commit

[AutoParallel][PIR] Fit pir grad merge (#8985)
* fit pir llm grad merge

* fit code style

* add grad merge llama test

* Update run_pretrain_auto.py

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* Update ci_case_auto.sh

* remove test

* Update ci_case_auto.sh

* Update ci_case_auto.sh
AndSonder authored Sep 2, 2024
1 parent cda9594 commit e204b6d
Showing 2 changed files with 10 additions and 1 deletion.
10 changes: 9 additions & 1 deletion paddlenlp/trainer/auto_trainer.py

@@ -71,6 +71,7 @@ def loss_func(loss, outputs):
 
         self.global_mesh = fleet.auto.get_mesh()
         self.comm_group_in_pp = fleet.get_hybrid_communicate_group().get_pipe_parallel_group()
+        self._in_pir_mode = paddle.base.framework.get_flags("FLAGS_enable_pir_api")["FLAGS_enable_pir_api"]
 
     def _nested_gather(self, tensors):
         """
@@ -164,6 +165,9 @@ def _split_batches_for_accumulation(self, inputs):
         if self.args.to_static and self.args.pipeline_parallel_degree > 1:
             return [inputs]
 
+        if self.args.to_static and self._in_pir_mode and self.args.gradient_accumulation_steps > 1:
+            return [inputs]
+
         local_batches = [{} for i in range(self.args.gradient_accumulation_steps)]
         assert isinstance(inputs, dict)
 
@@ -345,7 +349,11 @@ def _inner_training_loop(
                 with _exec_mode_guard("dynamic"):
                     tr_loss += tr_loss_step
 
-                disable_accumulation = self.args.pipeline_parallel_degree > 1 and self.args.to_static
+                disable_accumulation = False
+                if self.args.pipeline_parallel_degree > 1 and self.args.to_static:
+                    disable_accumulation = True
+                if self.args.to_static and self._in_pir_mode and self.args.gradient_accumulation_steps > 1:
+                    disable_accumulation = True
                 # disable_accumulation = self.args.to_static
 
                 if (step_control + 1) % args.gradient_accumulation_steps == 0 or (
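The auto_trainer.py change comes down to one decision: when the model is compiled to a static (to_static) PIR graph and gradient_accumulation_steps > 1, the trainer hands the whole batch to the executor and skips Python-side accumulation, because the gradient-merge pass inside the compiled program performs the accumulation itself. The sketch below mirrors the two conditions from the diff; the helper names (GradAccumArgs, graph_handles_grad_merge) and the example values are illustrative, not part of the commit.

from dataclasses import dataclass


@dataclass
class GradAccumArgs:
    # Stand-in for the TrainingArguments fields the diff reads.
    to_static: bool
    pipeline_parallel_degree: int
    gradient_accumulation_steps: int


def graph_handles_grad_merge(args: GradAccumArgs, in_pir_mode: bool) -> bool:
    """Return True when accumulation runs inside the compiled graph.

    In the trainer, in_pir_mode is read once via
    paddle.base.framework.get_flags("FLAGS_enable_pir_api"), as shown in the
    diff above.
    """
    # Static pipeline-parallel programs already micro-batch internally.
    if args.to_static and args.pipeline_parallel_degree > 1:
        return True
    # PIR grad merge: the pass accumulates gradients inside the program,
    # so the trainer must not split the batch or re-accumulate in Python.
    if args.to_static and in_pir_mode and args.gradient_accumulation_steps > 1:
        return True
    return False


if __name__ == "__main__":
    args = GradAccumArgs(to_static=True, pipeline_parallel_degree=1, gradient_accumulation_steps=4)
    print(graph_handles_grad_merge(args, in_pir_mode=True))   # True: whole batch goes to the graph
    print(graph_handles_grad_merge(args, in_pir_mode=False))  # False: split into 4 micro-batches

When this returns True, _split_batches_for_accumulation returns [inputs] unchanged and disable_accumulation is set, matching the two hunks above.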
1 change: 1 addition & 0 deletions paddlenlp/trainer/training_args.py

@@ -1462,6 +1462,7 @@ def is_segment_parallel_supported():
             pipeline.accumulate_steps = self.gradient_accumulation_steps
             pipeline.micro_batch_size = self.per_device_train_batch_size
             pipeline.schedule_mode = self.pipeline_schedule_mode
+            pipeline.pp_degree = self.pipeline_parallel_degree
 
             logger.info(f"PP configs:{strategy.pipeline}, use master_grad: {self.amp_master_grad}")
 
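For context, the pipeline strategy block above is where the trainer's arguments are mapped onto the distributed strategy, and this commit additionally records the pipeline degree on it. A rough sketch of how those fields relate, with purely illustrative numbers (none of them come from the commit):

# Hypothetical values; only the field names mirror the training_args.py mapping.
per_device_train_batch_size = 2   # -> pipeline.micro_batch_size
gradient_accumulation_steps = 8   # -> pipeline.accumulate_steps
pipeline_parallel_degree = 4      # -> pipeline.pp_degree (new in this commit)

# Samples processed per data-parallel rank per optimizer step: the micro batch
# is replayed accumulate_steps times before each optimizer update.
effective_batch_per_rank = per_device_train_batch_size * gradient_accumulation_steps
print(effective_batch_per_rank)  # 16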
