[Distributed] Support pp non batch comm (#8097)
* add disable_batch_p2p_comm to pipeline_parallel_config
SylarTiaNII authored Mar 14, 2024
1 parent dc19e4d commit d7b5939
5 changes: 5 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -245,6 +245,7 @@ class TrainingArguments:
enable_dp_comm_overlap, fuse data parallel gradient communication.
enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication.
enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode.
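
For context, a minimal sketch of enabling this new option from user code, assuming a typical PaddleNLP setup (the output_dir and degree values are illustrative; `pipeline_parallel_degree` and `pipeline_parallel_config` are the fields documented in this diff, and such a job is normally launched under paddle.distributed.launch):

    from paddlenlp.trainer import TrainingArguments

    # Illustrative values; pipeline_parallel_config is the field this commit extends.
    args = TrainingArguments(
        output_dir="./checkpoints",
        pipeline_parallel_degree=4,
        # Space-separated option string; the new flag switches pipeline
        # stages from batched to individual send/recv communication.
        pipeline_parallel_config="disable_batch_p2p_comm",
    )
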
sharding_parallel_config (`str`, *optional*):
Some additional configs which affect the usage of sharding parallel; we provide some options to configure them.
The following configs are supported:
@@ -591,6 +592,7 @@ class TrainingArguments:
"enable_delay_scale_loss, accumulate gradients util optimizer step, all gradients div by inner pipeline accumute step. instead of div accumute step on loss directly.\n"
"enable_dp_comm_overlap, fuse data parallel gradient communication. \n"
"enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
"disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode. \n"
)
},
)
@@ -950,6 +952,7 @@ def __post_init__(self):
"enable_sharding_comm_overlap",
"enable_timer",
"enable_release_grads",
"disable_batch_p2p_comm",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accpet config is disable_p2p_cache_shape, disable_partial_send_recv."
@@ -971,6 +974,7 @@ def __post_init__(self):
and self.sharding_parallel_degree > 1,
"enable_timer": "enable_timer" in pipeline_parallel_config,
"release_gradients": "enable_release_grads" in pipeline_parallel_config,
"use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
}
if dygraph_pp_configs["dp_comm_overlap"]:
raise ValueError("overlap has accuracy issue") # TODO: fix `overalap` + `delay_scale` issue
@@ -1165,6 +1169,7 @@ def is_segment_parallel_supported():
# "enable_dp_comm_overlap", # no implemenation for auto_parallel
# "enable_sharding_comm_overlap", # no implemenation for auto_parallel
# "enable_timer", # no implemenation for auto_parallel
# "disable_batch_p2p_comm", # no implemenation for auto_parallel
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accpet config is enable_send_recv_overlap."
