 from .common import DistributedOperatorImpl
 from .common import register_distributed_operator
 from .common import register_distributed_operator_impl
+from .common import copy_distributed_attr_for_var
+from .common import copy_distributed_attr_for_dist_op
 from ..utils import is_dim_shard
 from ..utils import is_dim_replicate
 from ..utils import is_valid_list_index
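
The two new helpers come from `.common` alongside the existing registration utilities. As a rough sketch of what a `copy_distributed_attr_for_var`-style helper has to accomplish (my own minimal illustration, not the Paddle implementation; the `ToyOpDistAttr` container and its `dims_mapping` layout are assumptions):

```python
# Illustrative sketch only. Assumes a minimal dist-attr container keyed by
# variable name, holding a dims_mapping: one mesh-axis index per tensor
# dimension, with -1 meaning "replicated on that dimension".
class ToyOpDistAttr:
    def __init__(self):
        self._dims_mappings = {}

    def get_dims_mapping(self, var_name):
        return self._dims_mappings.get(var_name)

    def set_dims_mapping(self, var_name, mapping):
        self._dims_mappings[var_name] = mapping

def copy_dist_attr_for_var(op_dist_attr, dst_var_name, src_var_name):
    """Give a freshly created intermediate variable the same sharding
    (dims_mapping) as the variable it was derived from."""
    mapping = op_dist_attr.get_dims_mapping(src_var_name)
    op_dist_attr.set_dims_mapping(dst_var_name, list(mapping))

attr = ToyOpDistAttr()
attr.set_dims_mapping("X", [-1, 0])  # X sharded on mesh axis 0 along dim 1
copy_dist_attr_for_var(attr, "c_identity.tmp_0", "X")
print(attr.get_dims_mapping("c_identity.tmp_0"))  # [-1, 0]
```
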
@@ -223,13 +225,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)

             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')

-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -250,12 +255,18 @@ def static_handle(dst_block,
                 'alpha': 1,
             }
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
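
The hunk above is the column-parallel branch: `c_identity` hands the replicated `X` to the model-parallel group, and the `matmul` then runs against this rank's column shard of the weight, so the forward pass needs no allreduce. A quick numpy check of that property (my own illustration, independent of Paddle):

```python
# Column-parallel pattern: each rank multiplies the full X by its column
# shard of W; concatenating the per-rank outputs reproduces the serial matmul.
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4, 8))
W = rng.standard_normal((8, 6))

shards = np.split(W, 2, axis=1)      # W column-sharded over 2 ranks
partial = [X @ w for w in shards]    # per-rank matmul on the full X
assert np.allclose(np.concatenate(partial, axis=1), X @ W)
```
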
@@ -369,13 +380,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_op = dst_block.append_op(
                 type='matmul',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)

-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -385,6 +400,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
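
This is the row-parallel counterpart: each rank's `matmul` writes a partial product into `intermediate_var_0`, and `c_allreduce_sum` sums the partials across the model-parallel group into `Out_var`. The same kind of numpy sanity check (again my own illustration, not Paddle code):

```python
# Row-parallel pattern: X is sharded along the reduction dimension, W along
# its rows; summing the per-rank partial products (what c_allreduce_sum does)
# recovers the serial result.
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((4, 8))
W = rng.standard_normal((8, 6))

X_shards = np.split(X, 2, axis=1)    # reduction dim split across ranks
W_shards = np.split(W, 2, axis=0)
partials = [x @ w for x, w in zip(X_shards, W_shards)]
assert np.allclose(sum(partials), X @ W)   # allreduce-sum == X @ W
```
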
@@ -540,15 +561,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])

             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 5", group)

             intermediate_var_0 = dst_block.create_var(
                 name=unique_name.generate_with_ignorable_key(".".join(
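
Besides dropping the stray debug print, this hunk reorders `_get_comm_group`'s arguments to `(process_group, topology, axis, rank_id)` and removes the two handler-local imports (presumably now imported at module level). A hedged sketch of what such a helper plausibly computes, based only on the call site (the body below is my guess, not the Paddle source):

```python
# Hypothetical comm_group: the set of ranks that share every mesh coordinate
# with rank_id except the one on the model-parallel axis.
import numpy as np

def comm_group(process_group, topology, axis, rank_id):
    mesh = np.array(process_group).reshape(topology)
    coord = list(zip(*np.where(mesh == rank_id)))[0]  # rank's mesh coordinate
    index = list(coord)
    index[axis] = slice(None)                         # vary only that axis
    return list(mesh[tuple(index)])

# 2x4 mesh, axis 1 as the model-parallel axis: rank 5 groups with its row.
print(comm_group(list(range(8)), [2, 4], 1, 5))  # [4, 5, 6, 7]
```
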
@@ -558,13 +576,16 @@ def static_handle(dst_block,
                 type=core.VarDesc.VarType.LOD_TENSOR,
                 persistable=False,
                 stop_gradient=X_var.stop_gradient)
+            # copy X_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          X_var)

             check_variable_and_dtype(
                 X_var, 'tensor',
                 ['float16', 'float32', 'float64', 'int32', 'int64'],
                 '_c_identity')

-            dst_block.append_op(
+            c_identity_op = dst_block.append_op(
                 type='c_identity',
                 inputs={'X': [X_var]},
                 outputs={'Out': intermediate_var_0},
@@ -581,12 +602,18 @@ def static_handle(dst_block,
                 ['float16', 'float32', 'float64'], 'linear')
             attrs = {'trans_x': False, 'trans_y': False}
             inputs = {'X': [intermediate_var_0], 'Y': [Weight_var]}
-            dst_block.append_op(
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': Out_var},
                 attrs=attrs)

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(c_identity_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
@@ -675,15 +702,12 @@ def static_handle(dst_block,
             Out_var = dst_block.var(output_name_mapping['Out'][0])

             # TODO infer logic comm presentation
-            from ..process import new_process_group
-            from ..transpiler import _get_comm_group
             model_parallel_axis, process_mesh = op_dist_attr.get_owner_context(
             )._get_model_parallel_info()
-            group_ranks = _get_comm_group(process_mesh.topology,
-                                          model_parallel_axis,
-                                          process_mesh.process_group, rank_id)
+            group_ranks = _get_comm_group(process_mesh.process_group,
+                                          process_mesh.topology,
+                                          model_parallel_axis, rank_id)
             group = new_process_group(group_ranks)
-            # print("@@@@@@@@@@@@@@@@@@@@@ 4", group)

             check_variable_and_dtype(
                 X_var, 'x', ['float16', 'float32', 'float64'], 'linear')
@@ -699,13 +723,17 @@ def static_handle(dst_block,
                 persistable=False,
                 is_data=False,
                 need_check_feed=Out_var.desc.need_check_feed())
-            dst_block.append_op(
+            # copy Out_var's dist_attr to intermediate_var_0's dist_attr
+            copy_distributed_attr_for_var(op_dist_attr, intermediate_var_0,
+                                          Out_var)
+
+            matmul_v2_op = dst_block.append_op(
                 type='matmul_v2',
                 inputs=inputs,
                 outputs={'Out': intermediate_var_0},
                 attrs=attrs)

-            dst_block.append_op(
+            c_allreduce_sum_op = dst_block.append_op(
                 type='c_allreduce_sum',
                 inputs={'X': intermediate_var_0},
                 outputs={'Out': Out_var},
@@ -715,6 +743,12 @@ def static_handle(dst_block,
                     'use_model_parallel': True
                 })

+            # copy serial op's dist_attr to dist op's dist_attr
+            copy_distributed_attr_for_dist_op(matmul_v2_op, dst_block,
+                                              op_dist_attr)
+            copy_distributed_attr_for_dist_op(c_allreduce_sum_op, dst_block,
+                                              op_dist_attr)
+
         if in_dygraph_mode():
             raise NotImplementedError(
                 "Dist op for [{}] with idx [{}] is NOT implemented yet.".format(
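
Taken together, the two branches compose into the usual model-parallel pairing: a column-parallel matmul feeding a row-parallel one, with the single `c_allreduce_sum` as the only forward communication. A numpy sanity check of that composition (illustration only, not Paddle code; a real MLP would insert an elementwise activation between the two, which stays rank-local):

```python
# Column-parallel matmul into row-parallel matmul: one allreduce at the end.
import numpy as np

rng = np.random.default_rng(2)
X = rng.standard_normal((4, 8))
W1 = rng.standard_normal((8, 16))
W2 = rng.standard_normal((16, 6))

W1_shards = np.split(W1, 2, axis=1)   # column-parallel weight
W2_shards = np.split(W2, 2, axis=0)   # row-parallel weight
partials = [(X @ w1) @ w2 for w1, w2 in zip(W1_shards, W2_shards)]
assert np.allclose(sum(partials), (X @ W1) @ W2)
```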