【AutoParallel】Change llama in auto-parallel (#8151)
* change llama in auto

* change ci

* polish
heavyrain-lzy authored Mar 26, 2024
1 parent 89bff20 commit 4d49a3e
Showing 3 changed files with 3 additions and 9 deletions.
8 changes: 0 additions & 8 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -1071,14 +1071,6 @@ def forward(self, prediction_scores, masked_lm_labels):
                 )
                 self.loss_func = paddle.nn.CrossEntropyLoss(reduction="none", ignore_index=self.ignore_index)
 
-        # Force Replicated to match dy & st
-        prediction_scores = dist.reshard(
-            prediction_scores,
-            get_mesh(-1),
-            [dist.Replicate(), dist.Replicate(), dist.Replicate()],
-        )
-        masked_lm_labels = dist.reshard(masked_lm_labels, get_mesh(-1), [dist.Replicate(), dist.Replicate()])
-
         # Force entropy same kernel
         with paddle.amp.auto_cast(False):
             if isinstance(prediction_scores, paddle.Tensor):
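The deleted block forced prediction_scores and masked_lm_labels back to fully replicated placements (via the file's get_mesh(-1) helper) before computing the loss; after this change the cross-entropy runs on the tensors with whatever placements they already carry. For reference, the following is a minimal standalone sketch of what dist.reshard with dist.Replicate() placements does in Paddle's auto-parallel API; the mesh, shapes, and values are made up for illustration and are not part of this PR (it would be launched with two ranks, e.g. via python -m paddle.distributed.launch):

    # Hypothetical example, not from this PR: reshard a distributed tensor
    # back to fully replicated placements so every rank holds the full tensor.
    import paddle
    import paddle.distributed as dist

    mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])   # assumed 2-rank mesh
    x = paddle.rand([4, 8])
    dx = dist.shard_tensor(x, mesh, [dist.Shard(0)])    # shard dim 0 across the mesh
    rx = dist.reshard(dx, mesh, [dist.Replicate()])     # back to replicated on all ranks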
2 changes: 1 addition & 1 deletion scripts/distribute/ci_case_auto.sh
@@ -1503,7 +1503,7 @@ function llama_dygraph_auto_bs8_fp16_DP2-MP2-PP2() {
     ips=-1
     mem=-1
     echo "result: loss=$loss ips=$ips mem=$mem"
-    loss_base=9.38257694
+    loss_base=9.38256836
     ips_base=-1
     mem_base=-1
     check_result $FUNCNAME ${loss_base} ${loss} ${ips_base} ${ips} ${mem_base} ${mem}
2 changes: 2 additions & 0 deletions scripts/distribute/run_ci.sh
@@ -27,7 +27,9 @@ target_lists_for_gpt=(

target_lists_for_llama=(
    "llm/llama/auto_parallel"
    "paddlenlp/trainer/auto_trainer.py"
    "paddlenlp/transformers/llama/modeling_auto_static.py"
    "paddlenlp/transformers/llama/modeling_auto.py"
    "scripts/distribute"
)

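target_lists_for_llama enumerates path prefixes that gate the llama auto-parallel CI cases, so commits touching the auto-parallel llama code (including modeling_auto.py, which this PR modifies) trigger them. As a rough illustration only (a hedged Python sketch, not the actual shell logic of run_ci.sh), the decision amounts to a prefix match between the PR's changed files and the target list:

    # Hedged sketch (assumption, not the real run_ci.sh implementation):
    # run the llama auto-parallel CI cases if any changed file falls under
    # one of the target path prefixes.
    target_lists_for_llama = [
        "llm/llama/auto_parallel",
        "paddlenlp/trainer/auto_trainer.py",
        "paddlenlp/transformers/llama/modeling_auto_static.py",
        "paddlenlp/transformers/llama/modeling_auto.py",
        "scripts/distribute",
    ]

    def should_run_llama_ci(changed_files):
        return any(f.startswith(t) for f in changed_files for t in target_lists_for_llama)

    print(should_run_llama_ci(["paddlenlp/transformers/llama/modeling_auto.py"]))  # True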
