37 changes: 24 additions & 13 deletions python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -256,14 +256,7 @@ def is_pipeline_last_stage(self, ignore_virtual=False):
def set_virtual_pipeline_rank(self, rank):
self._virtual_pp_rank = rank

def bw_hook_func(self, buffer, param):
@paddle.autograd.no_grad()
def fused_allreduce(*_):
buffer.add_grad(param)

return fused_allreduce

def register_allreduce_overlap_hook(
def fused_gradient(
self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
):
if model.get_num_virtual_stages() > 1:
@@ -275,7 +268,7 @@ def register_allreduce_overlap_hook(
assert hasattr(self, "optimizer")
assert hasattr(self.optimizer, "_param2rank")
_param2rank = self.optimizer._param2rank
# Note: after sharding change to reduce operation, here need to be cleared

act = (
HOOK_ACTION.ALL_REDUCE
if (dp or not g_shard_use_reduce)
@@ -311,16 +304,34 @@
if not dp:
# parse the relative dst rank to absolute dst rank for sharding
dst = comm_group.ranks[dst]

var_groups = assign_group_by_size(parameter_list, group_size)
for group_idx, parameters in var_groups.items():
buffer = FusedCommBuffer(
group_idx, parameters, comm_group, acc_steps, act, dst
)
self._chunk_2_comm_buffers[chunk_idx].append(buffer)
for param in parameters:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
)

return self._chunk_2_comm_buffers

def bw_hook_func(self, buffer, param):
@paddle.autograd.no_grad()
def fused_allreduce(*_):
buffer.add_grad(param)

return fused_allreduce

def register_allreduce_overlap_hook(
self, model, comm_group, acc_steps, dp, group_size=128 * 1024 * 1024
):
# register hook
self.fused_gradient(model, comm_group, acc_steps, dp, group_size)
for _, buffers in self._chunk_2_comm_buffers.items():
for buffer in buffers:
for param in buffer._params:
param._register_backward_hook(
self.bw_hook_func(buffer, param)
)

def timer_printer(self):
if not self._enable_timer:
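For context, a minimal usage sketch of the two entry points after this change. This is not part of the diff: `pp`, `model`, and `comm_group` stand in for a pipeline-parallel instance, its model, and its communication group, and the argument values are illustrative assumptions.

```python
# Minimal sketch, assuming `pp` is the object whose methods are shown in the
# diff above, and `model` / `comm_group` are its model and comm group.
# The two paths below are alternatives, not meant to be called together,
# since register_allreduce_overlap_hook() builds the buffers itself.

# Path 1: build the per-chunk FusedCommBuffer groups without registering
# autograd hooks. fused_gradient() groups parameters by size, wraps each
# group in a FusedCommBuffer, and returns self._chunk_2_comm_buffers so the
# caller can drive the fused reduce / all-reduce on its own.
chunk_2_comm_buffers = pp.fused_gradient(model, comm_group, acc_steps=4, dp=True)

# Path 2 (the previous behavior, unchanged from the caller's point of view):
# register_allreduce_overlap_hook() delegates buffer creation to
# fused_gradient() and then registers one backward hook per parameter, so
# each gradient is added into its buffer and communication overlaps backward.
pp.register_allreduce_overlap_hook(model, comm_group, acc_steps=4, dp=True)
```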