Skip to content

Commit

Permalink
[Distributed] Support pp non batch comm (#8097) (#8222)
Browse files Browse the repository at this point in the history
* add disable_non_batch_p2p_comm to pipeline_parallel_config
  • Loading branch information
SylarTiaNII authored Apr 2, 2024
1 parent 7b493a8 commit 2273ee7
Showing 1 changed file with 5 additions and 0 deletions.
5 changes: 5 additions & 0 deletions paddlenlp/trainer/training_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,7 @@ class TrainingArguments:
enable_release_grads, reduce peak memory usage by releasing gradients after each iteration. The creation of gradients will be postponed until backward propagation of the next iteration.
enable_overlap_p2p_comm, overlap p2p communication with computation.
enable_clear_every_step_cache, clear every step cache for pipeline parallel.
disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode.
sharding_parallel_config (`str`, *optional*)(
Some additional configs that highly affect the usage of sharding parallel; we provide some options to configure them.
The following configs are supported:
Expand Down Expand Up @@ -616,6 +617,7 @@ class TrainingArguments:
"enable_sharding_comm_overlap, fuse sharding stage 1 parallel gradient communication. \n"
"enable_overlap_p2p_comm, overlap p2p communication with computation. \n"
"enable_clear_every_step_cache, clear every step cache for pipeline parallel. \n"
"disable_batch_p2p_comm, disable batched send/recv in pipeline parallel mode. \n"
)
},
)
Expand Down Expand Up @@ -993,6 +995,7 @@ def __post_init__(self):
"enable_dp_comm_overlap",
"enable_clear_every_step_cache",
"enable_overlap_p2p_comm",
"disable_batch_p2p_comm",
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accept config is disable_p2p_cache_shape, disable_partial_send_recv."
Expand Down Expand Up @@ -1025,6 +1028,7 @@ def __post_init__(self):
"release_gradients": "enable_release_grads" in pipeline_parallel_config,
"overlap_p2p_comm": "enable_overlap_p2p_comm" in pipeline_parallel_config,
"clear_every_step_cache": "enable_clear_every_step_cache" in pipeline_parallel_config,
"use_batch_p2p_comm": "disable_batch_p2p_comm" not in pipeline_parallel_config,
}
if dygraph_pp_configs["dp_comm_overlap"]:
raise ValueError("overlap has accuracy issue") # TODO: fix `overalap` + `delay_scale` issue
Expand Down Expand Up @@ -1249,6 +1253,7 @@ def is_segment_parallel_supported():
# "enable_dp_comm_overlap", # no implementation for auto_parallel
# "enable_sharding_comm_overlap", # no implementation for auto_parallel
# "enable_timer", # no implementation for auto_parallel
# "disable_batch_p2p_comm", # no implementation for auto_parallel
]:
raise ValueError(
f"Found unknown pipeline mode config {x}, accept config is enable_send_recv_overlap."
Expand Down

0 comments on commit 2273ee7

Please sign in to comment.