[Cherry-pick] add comm buffer size (#8963) (#9031)
* add comm buffer size (#8963)

* add doc
ForFishes authored Aug 29, 2024
1 parent 5eb8d03 commit ae691e2
Showing 2 changed files with 24 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/trainer.md
@@ -506,6 +506,14 @@ Trainer is a simple yet fully functional Paddle training and evaluation module, and
with 8 cards, then set sharding_degree=8, so sharding communication stays inside the machine.
The default of -1 means parameters are sharded across all workers. (`int`, *optional*, defaults to `-1`)
--sharding_comm_buffer_size_MB
Size (in MB) of the fused gradient buffer used for sharding communication. This option only takes effect when sharding is enabled.
The default value of -1 keeps the default fused gradient buffer size of 256 MB.
(`int`, optional, defaults to `-1`)
--tensor_parallel_degree
Tensor parallelism is the tensor-slicing method proposed in the Megatron paper for the Transformer architecture.
It partitions the computation of a single transformer layer across different cards.
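The documentation entry above maps one-to-one onto a `TrainingArguments` field. Below is a minimal sketch (not part of this commit) of setting the option programmatically; the output directory, sharding stage, parallel degree, and 512 MB value are illustrative only.

```python
from paddlenlp.trainer import TrainingArguments

# Illustrative values; sharding must be enabled for the buffer size to apply.
args = TrainingArguments(
    output_dir="./checkpoints",        # hypothetical output path
    sharding="stage1",                 # enable sharding so the option takes effect
    sharding_parallel_degree=8,        # shard within one 8-card machine
    sharding_comm_buffer_size_MB=512,  # fuse gradients into 512 MB communication buffers
)
print(args.sharding_comm_buffer_size_MB)  # -> 512
```

Passing `--sharding_comm_buffer_size_MB 512` on the command line sets the same field, since it is an ordinary dataclass field parsed by the argument parser.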
16 changes: 16 additions & 0 deletions paddlenlp/trainer/training_args.py
@@ -549,6 +549,17 @@ class TrainingArguments:
            )
        },
    )
    sharding_comm_buffer_size_MB: int = field(
        default=-1,
        metadata={
            "help": (
                "Set the size (in MB) of the fused gradient buffer used in sharding communication. This option "
                "only takes effect when sharding is enabled. The default value of -1 keeps the default "
                "fused gradient buffer size of 256 MB."
            )
        },
    )

    save_sharded_model: bool = field(
        default=False,
        metadata={
@@ -1293,6 +1304,11 @@ def is_segment_parallel_supported():
)

            try:
                if self.sharding_comm_buffer_size_MB > 0:
                    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(
                        self.sharding_comm_buffer_size_MB
                    )

                if "split_param" in sharding_parallel_config:
                    strategy.hybrid_configs["sharding_configs"].split_param = True
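For context on where the value lands, here is a rough sketch of configuring Paddle's hybrid sharding strategy directly, mirroring the assignment in the hunk above. The parallel degrees and the 512 MB value are illustrative, and it assumes a Paddle version whose sharding config exposes `comm_buffer_size_MB`.

```python
from paddle.distributed import fleet

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 8,
}

sharding_comm_buffer_size_MB = 512  # illustrative; values <= 0 keep the 256 MB default
if sharding_comm_buffer_size_MB > 0:
    strategy.hybrid_configs["sharding_configs"].comm_buffer_size_MB = int(sharding_comm_buffer_size_MB)
```

Larger buffers fuse more gradients per communication call, trading GPU memory for fewer, larger sharding all-reduce/reduce-scatter operations.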
