Commit e40f69d

Fix

1 parent: 38f6a5a

17 files changed: +155 −114 lines

python/paddle/distributed/auto_parallel/static/cost/base_cost.py

Lines changed: 13 additions & 6 deletions
@@ -29,6 +29,7 @@
     "recv_v2",
     "broadcast",
     "all_gather",
+    "all_reduce",
     "c_allreduce_sum",
     "c_identity",
 ]
@@ -311,7 +312,10 @@ def build_comm_desc_from_dist_op(
         input_list.append((var.dtype, shape))

     # NOTE: The input_name of comm ops used usually is X.
-    desc["inputs"] = {"X": input_list}
+    if op_type == "all_reduce":
+        desc["inputs"] = {"x": input_list}
+    else:
+        desc["inputs"] = {"X": input_list}

     # Get comm group by parallel_axis or the given group_ranks.
     if parallel_axis is not None:
@@ -349,7 +353,10 @@ def build_comm_desc(op_type, group_ranks, dtype, shape, attrs=None):
     desc = {}
     desc["op"] = op_type
     desc["group_ranks"] = group_ranks
-    desc["inputs"] = {"X": [(dtype, shape)]}
+    if op_type == "all_reduce":
+        desc["inputs"] = {"x": [(dtype, shape)]}
+    else:
+        desc["inputs"] = {"X": [(dtype, shape)]}
     desc["attrs"] = attrs
     return desc

@@ -416,19 +423,19 @@ def build_dp_costs(
     if not has_found:
         return

-    c_allreduce_sum_descs = build_comm_desc_from_dist_op(
-        "c_allreduce_sum",
+    all_reduce_sum_descs = build_comm_desc_from_dist_op(
+        "all_reduce",
         dist_op,
         ctx,
         var_names,
         attrs=attrs,
         parallel_axis=parallel_axis,
     )
     comm_cost_list = build_comm_costs_from_descs(
-        _g_op_cost_factory["c_allreduce_sum"],
+        _g_op_cost_factory["all_reduce"],
         ctx,
         processes,
-        c_allreduce_sum_descs,
+        all_reduce_sum_descs,
         cluster,
         is_dp=True,
     )
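
For reference, a small usage sketch of the updated build_comm_desc (not part of the commit); the group ranks, dtype, and shape are made-up values:

import paddle

desc = build_comm_desc(
    "all_reduce", group_ranks=[0, 1], dtype=paddle.float32, shape=[1024, 1024]
)
# New-style comm op, so the descriptor uses the lowercase "x" input key:
# {'op': 'all_reduce', 'group_ranks': [0, 1],
#  'inputs': {'x': [(paddle.float32, [1024, 1024])]}, 'attrs': None}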

python/paddle/distributed/auto_parallel/static/cost/comm_op_cost.py

Lines changed: 33 additions & 1 deletion
@@ -23,7 +23,7 @@

 @register_op_cost
 class AllreduceSumOpCost(CommOpCost):
-    OP_TYPE = "c_allreduce_sum"
+    OP_TYPE = "all_reduce"

     def __init__(self, op=None, op_desc=None, comm_context=None):
         super().__init__(op=op, op_desc=op_desc, comm_context=comm_context)
@@ -82,6 +82,38 @@ def calc_time_tree(self):

         return time

+    @property
+    def comm_count(self):
+        from ..reshard import get_var_with_recursion
+
+        if self._comm_count is None:
+            dtype = None
+            shape = None
+            if self.op is not None:
+                vars = self.op.block.vars
+                try:
+                    var_name = self.op.input("x")[0]
+                except:
+                    var_name = self.op.output("out")[0]
+                var = get_var_with_recursion(
+                    var_name, self.op.block, self.op.block.program
+                )
+                dtype = var.dtype
+                shape = var.shape
+            elif self.op_desc is not None:
+                dtype = self.op_desc["inputs"]["x"][0][0]
+                shape = self.op_desc["inputs"]["x"][0][1]
+
+            factor = None
+            if dtype == paddle.float32 or dtype == paddle.int32:
+                factor = 4
+            else:
+                raise ValueError(f"Unsupported comm dtype {dtype}")
+            comm_count = int(np.prod(shape)) * factor
+            self._comm_count = comm_count
+
+        return self._comm_count
+

 @register_op_cost
 class AllgatherOpCost(CommOpCost):
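
A self-contained sketch of the byte count the new comm_count property produces: number of elements times element size, with 4 bytes assumed for float32/int32. The shape is an illustrative example, not taken from the commit.

import numpy as np

shape = (1024, 1024)                       # shape of the tensor being all-reduced
factor = 4                                 # bytes per element for float32 / int32
comm_count = int(np.prod(shape)) * factor  # 4194304 bytes per all_reduce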

python/paddle/distributed/auto_parallel/static/mapper.py

Lines changed: 1 addition & 1 deletion
@@ -104,7 +104,7 @@ def get_comm_volume(comm_op, src_rank, tgt_rank):
             new_tensor_shape.append(val)
         tensor_size = functools.reduce(operator.mul, new_tensor_shape, 1)
         tensor_bytes = tensor_size * get_dtype_bytes(tensor.dtype)
-        if "c_allreduce" in comm_op_type:
+        if "c_allreduce" in comm_op_type or "all_reduce" in comm_op_type:
             comm_volume = 2 * tensor_bytes
         elif "all_gather" in comm_op_type:
             comm_volume = tensor_bytes
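
Worked example of the rule extended above (illustrative numbers): a ring all-reduce moves roughly twice the tensor size per rank, while an all_gather moves it roughly once.

tensor_bytes = 1024 * 1024 * 4  # e.g. 1M float32 elements
comm_op_type = "all_reduce"     # now matched by the updated check

if "c_allreduce" in comm_op_type or "all_reduce" in comm_op_type:
    comm_volume = 2 * tensor_bytes
elif "all_gather" in comm_op_type:
    comm_volume = tensor_bytes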

python/paddle/distributed/auto_parallel/static/operators/common.py

Lines changed: 14 additions & 11 deletions
@@ -511,17 +511,17 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
     dist_op_context = dist_ctx.dist_op_context
     main_block = dist_op_context.work_block

-    allreduce_type = "c_allreduce_sum"
+    op_type = dist.ReduceOp.SUM
     need_scale = dist_ctx.gradient_scale
     scale_using_allreduce_avg = dist_ctx.gradient_scale_using_allreduce_avg

-    # With nccl_version > 2.10.00, we can use c_allreduce_avg to replace c_allreduce_sum and eliminate the scale op.
+    # With nccl_version > 2.10.00, we can use all_reduce_avg to replace all_reduce_sum and eliminate the scale op.
     if (
         need_scale
         and scale_using_allreduce_avg
         and int(paddle.version.nccl()) > 21000
     ):
-        allreduce_type = "c_allreduce_avg"
+        op_type = dist.ReduceOp.AVG
         need_scale = False

     for group in groups:
@@ -531,12 +531,12 @@ def sync_and_scale_gradients(dist_ctx, op, groups, allreduce_var_names):
         added_ops = []
         grad_var = main_block.var(var_name)
         allreduce_op = main_block.append_op(
-            type=allreduce_type,
-            inputs={'X': [grad_var]},
-            outputs={'Out': [grad_var]},
+            type='all_reduce',
+            inputs={'x': [grad_var]},
+            outputs={'out': [grad_var]},
             attrs={
                 'ring_id': group.id,
-                'use_calc_stream': True,
+                'op_type': op_type,
                 OP_ROLE_KEY: OpRole.Backward,
             },
         )
@@ -670,9 +670,11 @@ def is_data_parallel_scale_op(op):


 def is_data_parallel_reduce_op(op):
-    is_allreduce_op = op.type in [
-        "c_allreduce_sum",
-        "c_allreduce_avg",
+    is_allreduce_op = op.type == "all_reduce" and op.desc.attr(
+        "reduce_type"
+    ) in [
+        dist.ReduceOp.SUM,
+        dist.ReduceOp.AVG,
     ]
     is_reduce_op = op.type == "reduce" and op.desc.attr("reduce_type") in [
         dist.ReduceOp.SUM,
@@ -695,7 +697,8 @@ def is_amp_flag_sync_op(op):

 def is_global_norm_sync_op(op):
     return (
-        op.type == "c_allreduce_sum"
+        op.type == "all_reduce"
+        and op.desc.attr("op_type") == dist.ReduceOp.SUM
         and op.desc.has_attr("op_namescope")
         and SyncMode.GlobalNormSync in op.desc.attr("op_namescope")
     )
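
A minimal sketch of the reduce-type selection now used by sync_and_scale_gradients: with NCCL newer than 2.10.00, averaging is pushed into the collective (ReduceOp.AVG) and the separate scale op is dropped. The flag values below are assumptions for illustration.

import paddle
import paddle.distributed as dist

need_scale = True                 # stands in for dist_ctx.gradient_scale
scale_using_allreduce_avg = True  # stands in for dist_ctx.gradient_scale_using_allreduce_avg

op_type = dist.ReduceOp.SUM
if need_scale and scale_using_allreduce_avg and int(paddle.version.nccl()) > 21000:
    op_type = dist.ReduceOp.AVG   # averaging happens inside the all_reduce
    need_scale = False            # so no explicit scale op is appended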

python/paddle/distributed/auto_parallel/static/operators/dist_default.py

Lines changed: 5 additions & 4 deletions
@@ -13,6 +13,7 @@
 # limitations under the License


+import paddle.distributed as dist
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole

 from ..completion import contains_spmd_rule, get_phi_spmd_rule
@@ -67,12 +68,12 @@ def prim_operator_data_parallel_functor(ctx, src_op):
         sync_group = new_process_group(ctx.data_parallel_group)

         allreduce_op = main_block.append_op(
-            type='c_allreduce_sum',
-            inputs={'X': [var_name]},
-            outputs={'Out': [var_name]},
+            type='all_reduce',
+            inputs={'x': [var_name]},
+            outputs={'out': [var_name]},
             attrs={
                 'ring_id': sync_group.id,
-                'use_calc_stream': True,
+                'op_type': dist.ReduceOp.SUM,
                 OP_ROLE_KEY: OpRole.Backward,
             },
         )
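
The same rewrite pattern appears at every append_op call site touched by this commit; as a hedged summary drawn only from the hunks above:

# Legacy comm op                                    ->  new-style op
#   type:    'c_allreduce_sum' / 'c_allreduce_avg'  ->  'all_reduce'
#   inputs:  {'X': [...]}                           ->  {'x': [...]}
#   outputs: {'Out': [...]}                         ->  {'out': [...]}
#   attrs:   'use_calc_stream': True                ->  'op_type': dist.ReduceOp.SUM / dist.ReduceOp.AVG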

python/paddle/distributed/auto_parallel/static/operators/dist_embedding.py

Lines changed: 12 additions & 12 deletions
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License

+import paddle.distributed as dist
 from paddle.common_ops_import import check_variable_and_dtype
 from paddle.distributed.auto_parallel.static.cost.comm_op_cost import (
     AllreduceSumOpCost,
@@ -246,10 +247,10 @@ def calc_fwd_cost(self, dist_op, ctx, cluster):
         parallel_axis = dist_op.dist_attr.get_input_dims_mapping(
             serial_op.input("W")[0]
         )[0]
-        attrs = {"use_calc_stream": True, "use_model_parallel": True}
+        attrs = {"op_type": dist.ReduceOp.SUM}
         var_names = serial_op.output("Out")
-        c_allreduce_sum_desc_mapping = build_comm_desc_from_dist_op(
-            "c_allreduce_sum",
+        all_reduce_sum_desc_mapping = build_comm_desc_from_dist_op(
+            "all_reduce",
             dist_op,
             ctx,
             var_names,
@@ -261,7 +262,7 @@ def calc_fwd_cost(self, dist_op, ctx, cluster):
             AllreduceSumOpCost,
             ctx,
             processes,
-            c_allreduce_sum_desc_mapping,
+            all_reduce_sum_desc_mapping,
             cluster,
         )

@@ -510,23 +511,22 @@ def forward(ctx, *args, **kwargs):
         naive_copy_op_dist_attr_for_program(c_embedding_op, src_op, ctx)

         # use_model_parallel
-        c_allreduce_sum_op = main_block.append_op(
-            type='c_allreduce_sum',
-            inputs={'X': [Out_var]},
-            outputs={'Out': [Out_var]},
+        all_reduce_sum_op = main_block.append_op(
+            type='all_reduce',
+            inputs={'x': [Out_var]},
+            outputs={'out': [Out_var]},
             attrs={
                 'ring_id': group.id,
-                'use_calc_stream': True,
-                'use_model_parallel': True,
+                'op_type': dist.ReduceOp.SUM,
                 OP_ROLE_KEY: src_op.attr('op_role'),
             },
         )
-        c_allreduce_sum_op._set_attr(
+        all_reduce_sum_op._set_attr(
             'op_namescope', '/' + ParallelMode.TensorParallel
         )
         # allreduce
         set_comm_op_dist_attr_for_program(
-            c_allreduce_sum_op,
+            all_reduce_sum_op,
             op_dist_attr.process_mesh,
             out_var_dist_attr,
             ctx,
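
A hedged numpy sketch (not from the commit) of why an all_reduce with ReduceOp.SUM follows c_embedding: with a vocab-parallel table, each rank looks up only the ids that fall in its shard and contributes zeros elsewhere, so summing the partial outputs across the model-parallel group reconstructs the full lookup. Shapes and values below are illustrative.

import numpy as np

vocab, hidden, ranks = 8, 4, 2
table = np.arange(vocab * hidden, dtype=np.float32).reshape(vocab, hidden)
ids = np.array([1, 6])

partials = []
for r in range(ranks):                 # one partial result per model-parallel rank
    lo, hi = r * vocab // ranks, (r + 1) * vocab // ranks
    shard = table[lo:hi]               # this rank's slice of the embedding table
    out = np.zeros((len(ids), hidden), dtype=np.float32)
    mask = (ids >= lo) & (ids < hi)
    out[mask] = shard[ids[mask] - lo]  # look up only ids owned by this shard
    partials.append(out)

# The appended all_reduce(SUM) performs this sum across the group.
assert np.allclose(sum(partials), table[ids])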
