Refine position_ids for auto parallel training of llama (#8363)
* fix

* fix

* fix
zhangbo9674 authored May 10, 2024
1 parent 16ef8f4 commit ac117a1
Showing 2 changed files with 37 additions and 13 deletions.
19 changes: 19 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -594,6 +594,16 @@ class TrainingArguments:
)
},
)
sequence_parallel_config: str = field(
default="",
metadata={
"help": (
"Some additional configs which affect sequence parallel performance, we provide some option to config it."
"following config is support:\n"
"enable_allreduce_avg_in_gradinent_scale, it replace `allreduce_sum + scale` pattern with `allreduce_avg` when scale gradient in sequence_parallel, which improve the performance. ONLY supported for auto mode now. \n"
)
},
)
tensor_parallel_config: str = field(
default="",
metadata={
@@ -1270,6 +1280,15 @@ def is_segment_parallel_supported():
strategy.gradient_scale_using_allreduce_avg = True
if "gradient_sync_after_accumulate" in data_parallel_config:
strategy.dp_optimization.gradient_sync_after_accumulate = True
sequence_parallel_config = set(self.sequence_parallel_config.split(" "))
for x in sequence_parallel_config:
if len(x) > 0:
if x not in ["enable_allreduce_avg_in_gradinent_scale"]:
raise ValueError(
f"Found unknown sequence parallel config {x}, accpet config is enable_allreduce_avg_in_gradinent_scale."
)
if "enable_allreduce_avg_in_gradinent_scale" in sequence_parallel_config:
strategy.gradient_scale_using_allreduce_avg = True

# naive-pp: pipeline_parallel_degree > 1 and gradient_accumulation_steps == 1
if self.pipeline_parallel_degree > 1 and self.gradient_accumulation_steps > 1:
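For reference, the new sequence_parallel_config option is a space-separated string whose only accepted token flips a single strategy flag. Below is a minimal standalone sketch of that parsing pattern; SimpleStrategy and apply_sequence_parallel_config are illustrative stand-ins rather than PaddleNLP APIs, and the option name is kept exactly as it is spelled in the diff.

```python
# Minimal sketch of the sequence_parallel_config parsing shown above.
# SimpleStrategy and apply_sequence_parallel_config are illustrative
# stand-ins, not PaddleNLP APIs; only the parsing pattern mirrors the diff.
from dataclasses import dataclass


@dataclass
class SimpleStrategy:
    gradient_scale_using_allreduce_avg: bool = False


def apply_sequence_parallel_config(config_str: str, strategy: SimpleStrategy) -> SimpleStrategy:
    options = set(config_str.split(" "))  # options are space-separated
    for opt in options:
        if len(opt) > 0 and opt not in ["enable_allreduce_avg_in_gradinent_scale"]:
            raise ValueError(
                f"Found unknown sequence parallel config {opt}; "
                "accepted config is enable_allreduce_avg_in_gradinent_scale."
            )
    if "enable_allreduce_avg_in_gradinent_scale" in options:
        # Replace the `allreduce_sum + scale` pattern with `allreduce_avg`
        # when scaling gradients under sequence parallelism (auto mode only).
        strategy.gradient_scale_using_allreduce_avg = True
    return strategy


strategy = apply_sequence_parallel_config(
    "enable_allreduce_avg_in_gradinent_scale", SimpleStrategy()
)
assert strategy.gradient_scale_using_allreduce_avg
```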
31 changes: 18 additions & 13 deletions paddlenlp/transformers/llama/modeling_auto.py
@@ -534,13 +534,15 @@ def forward(
else:
attn_output = outputs

if self.config.sequence_parallel:
attn_output = paddle.transpose(attn_output, [1, 0, 2])

# [bs, q_len, num_head * head_dim]
attn_output = self.o_proj(attn_output)

# enter sp region
if self.config.sequence_parallel:
# [bs, q_len, num_head * head_dim] -> [q_len / n, bs, num_head * head_dim]
attn_output = paddle.transpose(attn_output, [1, 0, 2])
attn_output = dist.reshard(
attn_output,
get_mesh(self.ipp),
@@ -953,14 +955,14 @@ def forward(
inputs_embeds = paddle.transpose(inputs_embeds, [1, 0, 2])

global_mesh = global_mesh_starts_with_pp()
if position_ids is None:
if position_ids is None and self.config.sep_parallel_degree > 1:
position_ids = paddle.arange(seq_length, dtype="int64").expand((batch_size, seq_length))

position_ids = dist.shard_tensor(
position_ids,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)
if position_ids is not None:
position_ids = dist.shard_tensor(
position_ids,
global_mesh,
[dist.Replicate() for _ in range(len(global_mesh._shape))],
)

# embed positions
if attention_mask is None:
@@ -1005,11 +1007,14 @@ def forward(
position_ids_input = position_ids
attention_mask_input = attention_mask
else:
position_ids_input = dist.reshard(
position_ids,
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)
if position_ids is not None:
position_ids_input = dist.reshard(
position_ids,
get_mesh(ipp),
[dist.Replicate(), dist.Replicate()],
)
else:
position_ids_input = position_ids
attention_mask_input = (
dist.reshard(
attention_mask,
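The position_ids refinement amounts to three guarded steps: synthesize position_ids only when sep_parallel_degree > 1, mark them as replicated on the global mesh only when they exist, and reshard them onto a pipeline stage's mesh only when they are not None (otherwise None is passed through). The sketch below mirrors that flow on a hypothetical single-card 1-D mesh, so it uses a single Replicate placement where the model code uses one per mesh dimension; global_mesh and stage_mesh stand in for global_mesh_starts_with_pp() and get_mesh(ipp), and it assumes a Paddle auto-parallel environment.

```python
# Sketch of the refined position_ids flow on a hypothetical single-card mesh.
# global_mesh / stage_mesh stand in for global_mesh_starts_with_pp() and
# get_mesh(ipp); sep_parallel_degree, batch_size and seq_length are assumed.
import paddle
import paddle.distributed as dist

global_mesh = dist.ProcessMesh([0], dim_names=["pp"])
stage_mesh = dist.ProcessMesh([0], dim_names=["pp"])

batch_size, seq_length = 2, 16
sep_parallel_degree = 2  # pretend sep parallelism is on so the branches run
position_ids = None

# Only synthesize position_ids when sep parallelism needs them.
if position_ids is None and sep_parallel_degree > 1:
    position_ids = paddle.arange(seq_length, dtype="int64").expand(
        (batch_size, seq_length)
    )

# Mark them as fully replicated on the global mesh only if they exist.
if position_ids is not None:
    position_ids = dist.shard_tensor(
        position_ids,
        global_mesh,
        [dist.Replicate() for _ in range(len(global_mesh.shape))],
    )

# Per pipeline stage: reshard only when position_ids is not None,
# otherwise pass None straight through to the decoder layers.
if position_ids is not None:
    position_ids_input = dist.reshard(position_ids, stage_mesh, [dist.Replicate()])
else:
    position_ids_input = position_ids
```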
