Commit aef9f5e

[fluid_ops] collective_global_gather.py remove dynamic_static_unified_comm (#70713)
1 parent f747eef · commit aef9f5e

2 files changed: 8 additions, 18 deletions

test/collective/collective_global_gather.py

Lines changed: 4 additions & 9 deletions

@@ -62,10 +62,8 @@ def run_trainer(self, args):
         endpoints = args["endpoints"].split(",")
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
-        if args["dynamic_static_unified_comm"]:
-            paddle.distributed.collective._init_parallel_env(args["backend"])
-        else:
-            paddle.distributed.init_parallel_env()
+
+        paddle.distributed.collective._init_parallel_env(args["backend"])
         nranks = 2
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
@@ -112,11 +110,8 @@ def run_trainer(self, args):
         )

         if args['static_mode']:
-            result = (
-                self.get_model(train_prog, startup_prog, rank)
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(train_prog, startup_prog, rank)
-            )
+            result = self.get_model(train_prog, startup_prog, rank)
+
             fetch_list = []
             for elem in result:
                 fetch_list.append(elem.name)
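
After this change, run_trainer no longer branches on dynamic_static_unified_comm: it always initializes communication through paddle.distributed.collective._init_parallel_env, and the paddle.distributed.init_parallel_env() fallback is gone. A minimal sketch of the resulting setup, assuming the same launcher-provided args dict as the test; the helper name init_test_comm and the non-NCCL CPU fallback are illustrative, not taken from the file:

import os

import paddle.distributed as dist
from paddle import base


def init_test_comm(args):
    # Unconditionally use the collective initializer the test keeps; the
    # removed else-branch used to call paddle.distributed.init_parallel_env().
    dist.collective._init_parallel_env(args["backend"])

    nranks = 2  # the test always runs with two trainers
    if args["backend"] == "nccl":
        # Pick the GPU assigned to this trainer by the launcher.
        device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
        place = base.CUDAPlace(device_id)
    else:
        # Assumed CPU fallback for non-NCCL backends (not shown in the diff).
        place = base.CPUPlace()
    return nranks, place

The only behavioral change for the test is that the dynamic-graph initialization path is no longer exercised.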

test/collective/collective_global_scatter.py

Lines changed: 4 additions & 9 deletions

@@ -63,10 +63,8 @@ def run_trainer(self, args):
         rank = args["trainerid"]
         current_endpoint = args["currentendpoint"]
         nranks = 2
-        if args["dynamic_static_unified_comm"]:
-            paddle.distributed.collective._init_parallel_env(args["backend"])
-        else:
-            paddle.distributed.init_parallel_env()
+
+        paddle.distributed.collective._init_parallel_env(args["backend"])
         if args['backend'] == 'nccl':
             device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
             place = base.CUDAPlace(
@@ -90,11 +88,8 @@ def run_trainer(self, args):
             "float32"
         )
         if args['static_mode']:
-            result = (
-                self.get_model(train_prog, startup_prog, rank)
-                if args["dynamic_static_unified_comm"]
-                else self.get_model(train_prog, startup_prog, rank)
-            )
+            result = self.get_model(train_prog, startup_prog, rank)
+
             exe = base.Executor(place)
             exe.run(startup_prog)
             fetch_list = []
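
Note that in both files the removed conditional expression called self.get_model(train_prog, startup_prog, rank) in both of its branches, so collapsing it to a single call does not change which model is built; only the initialization path changes. A rough sketch of the simplified static-mode flow, condensed from the surrounding context; run_static_mode and the feed parameter are hypothetical names, and static mode is assumed to be enabled by the caller:

from paddle import base


def run_static_mode(get_model, train_prog, startup_prog, rank, place, feed):
    # get_model is now called unconditionally; the removed ternary on
    # dynamic_static_unified_comm selected the same call either way.
    result = get_model(train_prog, startup_prog, rank)

    # Run the startup program once, then fetch every output the model returned.
    exe = base.Executor(place)
    exe.run(startup_prog)
    fetch_list = [elem.name for elem in result]
    return exe.run(train_prog, feed=feed, fetch_list=fetch_list)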
