diff --git a/models/benchmark/ctr_dnn/backend.yaml b/benchmark/ctr_dnn/backend.yaml
similarity index 100%
rename from models/benchmark/ctr_dnn/backend.yaml
rename to benchmark/ctr_dnn/backend.yaml
diff --git a/models/benchmark/ctr_dnn/config.yaml b/benchmark/ctr_dnn/config.yaml
similarity index 87%
rename from models/benchmark/ctr_dnn/config.yaml
rename to benchmark/ctr_dnn/config.yaml
index 5741a7dce..b6501df89 100644
--- a/models/benchmark/ctr_dnn/config.yaml
+++ b/benchmark/ctr_dnn/config.yaml
@@ -25,25 +25,22 @@ hyper_parameters:
   dense_feature_dim: 13
   fc_sizes: [400, 400, 400]
 
-mode: [local_train]
+mode: [collective]
 runner:
 - name: ps_cpu
-  class: cluster_train
-  epochs: 10
+  class: local_cluster_train
+  epochs: 1
   device: cpu
   fleet_mode: ps
-  save_checkpoint_interval: 1
-  save_checkpoint_path: "increment_dnn"
-  print_interval: 1
+  print_interval: 10
   phases: [phase1]
 
-- name: ps_gpu
-  class: cluster_train
+- name: collective
+  class: single
   epochs: 10
   device: gpu
-  fleet_mode: ps
-  save_checkpoint_interval: 1
-  save_checkpoint_path: "increment_dnn"
+  fleet_mode: collective
+  selected_gpus: "0,1"
   print_interval: 1
   phases: [phase1]
 
@@ -74,7 +71,7 @@ runner:
 phase:
 - name: phase1
   model: "{workspace}/model.py"
-  dataset_name: dataset_train
+  dataset_name: dataloader_train
   thread_num: 1
 
 - name: phase2
diff --git a/models/benchmark/ctr_dnn/dataset_generator.py b/benchmark/ctr_dnn/dataset_generator.py
similarity index 100%
rename from models/benchmark/ctr_dnn/dataset_generator.py
rename to benchmark/ctr_dnn/dataset_generator.py
diff --git a/models/benchmark/ctr_dnn/download_data.sh b/benchmark/ctr_dnn/download_data.sh
similarity index 100%
rename from models/benchmark/ctr_dnn/download_data.sh
rename to benchmark/ctr_dnn/download_data.sh
diff --git a/models/benchmark/ctr_dnn/model.py b/benchmark/ctr_dnn/model.py
similarity index 97%
rename from models/benchmark/ctr_dnn/model.py
rename to benchmark/ctr_dnn/model.py
index 9f9b43bda..a0ed93e85 100644
--- a/models/benchmark/ctr_dnn/model.py
+++ b/benchmark/ctr_dnn/model.py
@@ -69,7 +69,6 @@ def embedding_layer(input):
 
         sparse_embed_seq = list(map(embedding_layer, self.sparse_input))
         concated = paddle.concat(sparse_embed_seq + [self.dense_input], axis=1)
-        fluid.layers.Print(concated, message="concated")
 
         fc1 = paddle.static.nn.fc(
             x=concated,
@@ -78,7 +77,6 @@ def embedding_layer(input):
             name="fc1",
             weight_attr=paddle.ParamAttr(initializer=fluid.initializer.Normal(
                 scale=1.0 / math.sqrt(concated.shape[1]))))
-        fluid.layers.Print(fc1, message="fc1")
 
         fc2 = paddle.static.nn.fc(
             x=fc1,
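For context on the config change above: in PaddleRec's yaml layout, the top-level `mode` list names which entries under `runner` actually execute, so switching `mode` to `[collective]` makes the benchmark run the GPU runner on cards 0 and 1 instead of the parameter-server runner. A minimal sketch of that selection logic, written against plain PyYAML rather than PaddleRec's own `envs` loader (the path below is illustrative):

```python
import yaml  # PyYAML; PaddleRec's real loader lives in paddlerec.core.utils.envs


def active_runners(config_path):
    """Return the runner definitions named by the top-level `mode` list."""
    with open(config_path) as f:
        conf = yaml.safe_load(f)
    selected = set(conf.get("mode", []))
    return [r for r in conf.get("runner", []) if r.get("name") in selected]


# With the config above this yields only the "collective" runner,
# i.e. single-node data-parallel training on GPUs "0,1".
for runner in active_runners("benchmark/ctr_dnn/config.yaml"):
    print(runner["name"], runner["class"], runner.get("device"))
```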
diff --git a/models/benchmark/simnet_bow/config.yaml b/benchmark/simnet_bow/config.yaml
similarity index 100%
rename from models/benchmark/simnet_bow/config.yaml
rename to benchmark/simnet_bow/config.yaml
diff --git a/benchmark/simnet_bow/dataset_generator.py b/benchmark/simnet_bow/dataset_generator.py
new file mode 100644
index 000000000..abf198b97
--- /dev/null
+++ b/benchmark/simnet_bow/dataset_generator.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/benchmark/simnet_bow/model.py b/benchmark/simnet_bow/model.py
new file mode 100644
index 000000000..74721e038
--- /dev/null
+++ b/benchmark/simnet_bow/model.py
@@ -0,0 +1,147 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import paddle.fluid as fluid
+
+from paddlerec.core.utils import envs
+from paddlerec.core.model import ModelBase
+
+
+class Model(ModelBase):
+    def __init__(self, config):
+        ModelBase.__init__(self, config)
+
+    def _init_hyper_parameters(self):
+        # Read the hyper parameters this model actually uses; the expected
+        # keys under hyper_parameters in config.yaml are dict_dim, emb_dim,
+        # hid_dim, margin and is_sparse.
+        self.dict_dim = envs.get_global_env("hyper_parameters.dict_dim")
+        self.emb_dim = envs.get_global_env("hyper_parameters.emb_dim")
+        self.hid_dim = envs.get_global_env("hyper_parameters.hid_dim")
+        self.margin = envs.get_global_env("hyper_parameters.margin", 0.1)
+        self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
+                                             False)
+        self.learning_rate = envs.get_global_env(
+            "hyper_parameters.optimizer.learning_rate")
+
+    def input_data(self, is_infer=False, **kwargs):
+        q = fluid.layers.data(
+            name="query", shape=[1], dtype="int64", lod_level=1)
+        pt = fluid.layers.data(
+            name="pos_title", shape=[1], dtype="int64", lod_level=1)
+        nt = fluid.layers.data(
+            name="neg_title", shape=[1], dtype="int64", lod_level=1)
+
+        inputs = [q, pt, nt]
+        return inputs
+
+    def net(self, input, is_infer=False):
+        dict_dim = self.dict_dim
+        emb_dim = self.emb_dim
+        hid_dim = self.hid_dim
+        is_sparse = self.is_sparse
+        base_lr = self.learning_rate
+        emb_lr = self.learning_rate * 3
+
+        q = input[0]
+        pt = input[1]
+        nt = input[2]
+
+        # query tower: embedding -> sum pool -> softsign -> fc
+        q_emb = fluid.layers.embedding(
+            input=q,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                name="__emb__", learning_rate=emb_lr),
+            is_sparse=is_sparse)
+        # vsum
+        q_sum = fluid.layers.sequence_pool(input=q_emb, pool_type='sum')
+        q_ss = fluid.layers.softsign(q_sum)
+        # fc layer
+        q_fc = fluid.layers.fc(input=q_ss,
+                               size=hid_dim,
+                               param_attr=fluid.ParamAttr(
+                                   name="__q_fc__",
+                                   learning_rate=base_lr,
+                                   initializer=fluid.initializer.Xavier()))
+        # positive-title tower, sharing the "__emb__" table
+        pt_emb = fluid.layers.embedding(
+            input=pt,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                name="__emb__",
+                learning_rate=emb_lr,
+                initializer=fluid.initializer.Xavier()),
+            is_sparse=is_sparse)
+        # vsum
+        pt_sum = fluid.layers.sequence_pool(input=pt_emb, pool_type='sum')
+        pt_ss = fluid.layers.softsign(pt_sum)
+        # fc layer
+        pt_fc = fluid.layers.fc(input=pt_ss,
+                                size=hid_dim,
+                                param_attr=fluid.ParamAttr(
+                                    name="__fc__",
+                                    learning_rate=base_lr,
+                                    initializer=fluid.initializer.Xavier()),
+                                bias_attr=fluid.ParamAttr(
+                                    name="__fc_b__",
+                                    initializer=fluid.initializer.Xavier()))
+        # negative-title tower, sharing both "__emb__" and "__fc__"
+        nt_emb = fluid.layers.embedding(
+            input=nt,
+            size=[dict_dim, emb_dim],
+            param_attr=fluid.ParamAttr(
+                name="__emb__",
+                learning_rate=emb_lr,
+                initializer=fluid.initializer.Xavier()),
+            is_sparse=is_sparse)
+        # vsum
+        nt_sum = fluid.layers.sequence_pool(input=nt_emb, pool_type='sum')
+        nt_ss = fluid.layers.softsign(nt_sum)
+        # fc layer
+        nt_fc = fluid.layers.fc(input=nt_ss,
+                                size=hid_dim,
+                                param_attr=fluid.ParamAttr(
+                                    name="__fc__",
+                                    learning_rate=base_lr,
+                                    initializer=fluid.initializer.Xavier()),
+                                bias_attr=fluid.ParamAttr(
+                                    name="__fc_b__",
+                                    initializer=fluid.initializer.Xavier()))
+        cos_q_pt = fluid.layers.cos_sim(q_fc, pt_fc)
+        cos_q_nt = fluid.layers.cos_sim(q_fc, nt_fc)
+        # pairwise hinge loss; the framework picks the loss up from self._cost
+        avg_cost = self.get_loss(cos_q_pt, cos_q_nt)
+        self._cost = avg_cost
+
+    def get_loss(self, cos_q_pt, cos_q_nt):
+        loss_op1 = fluid.layers.elementwise_sub(
+            fluid.layers.fill_constant_batch_size_like(
+                input=cos_q_pt,
+                shape=[-1, 1],
+                value=self.margin,
+                dtype='float32'),
+            cos_q_pt)
+        loss_op2 = fluid.layers.elementwise_add(loss_op1, cos_q_nt)
+        loss_op3 = fluid.layers.elementwise_max(
+            fluid.layers.fill_constant_batch_size_like(
+                input=loss_op2, shape=[-1, 1], value=0.0, dtype='float32'),
+            loss_op2)
+        avg_cost = fluid.layers.mean(loss_op3)
+        return avg_cost
+
+    def optimizer(self):
+        optimizer = paddle.optimizer.SGD(self.learning_rate)
+        return optimizer
+
+    def infer_net(self):
+        pass
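The `get_loss` chain in the new simnet_bow model is the standard pairwise hinge loss, max(0, margin - cos(q, pos) + cos(q, neg)). A small NumPy sketch of the same computation, handy for sanity-checking the fluid op sequence (the 0.1 margin is illustrative):

```python
import numpy as np


def pairwise_hinge_loss(cos_q_pt, cos_q_nt, margin=0.1):
    """Mirrors loss_op1..loss_op3 above: max(0, margin - pos + neg), then mean."""
    return np.maximum(0.0, margin - cos_q_pt + cos_q_nt).mean()


# A pair whose positive title already beats the negative by more than the
# margin contributes zero; otherwise the shortfall is penalised linearly.
print(pairwise_hinge_loss(np.array([0.9, 0.3]), np.array([0.2, 0.4])))  # ~0.1
```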
diff --git a/models/benchmark/word2vec/config.yaml b/benchmark/word2vec/config.yaml
similarity index 100%
rename from models/benchmark/word2vec/config.yaml
rename to benchmark/word2vec/config.yaml
diff --git a/benchmark/word2vec/dataset_generator.py b/benchmark/word2vec/dataset_generator.py
new file mode 100644
index 000000000..abf198b97
--- /dev/null
+++ b/benchmark/word2vec/dataset_generator.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/benchmark/word2vec/model.py b/benchmark/word2vec/model.py
new file mode 100644
index 000000000..d535b532e
--- /dev/null
+++ b/benchmark/word2vec/model.py
@@ -0,0 +1,140 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+import paddle
+import paddle.fluid as fluid
+
+from paddlerec.core.utils import envs
+from paddlerec.core.model import ModelBase
+
+
+class Model(ModelBase):
+    def __init__(self, config):
+        ModelBase.__init__(self, config)
+
+    def _init_hyper_parameters(self):
+        # Read the hyper parameters this model actually uses; the expected
+        # keys under hyper_parameters in config.yaml are dict_size,
+        # embedding_size, nce_num and is_sparse.
+        self.dict_size = envs.get_global_env("hyper_parameters.dict_size")
+        self.embedding_size = envs.get_global_env(
+            "hyper_parameters.embedding_size")
+        self.nce_num = envs.get_global_env("hyper_parameters.nce_num")
+        self.is_sparse = envs.get_global_env("hyper_parameters.is_sparse",
+                                             False)
+        self.learning_rate = envs.get_global_env(
+            "hyper_parameters.optimizer.learning_rate")
+
+    def input_data(self, is_infer=False, **kwargs):
+
+        input_word = fluid.layers.data(
+            name="input_word", shape=[1], dtype='int64', lod_level=1)
+        true_word = fluid.layers.data(
+            name='true_label', shape=[1], dtype='int64', lod_level=1)
+        neg_word = fluid.layers.data(
+            name="neg_label", shape=[1], dtype='int64', lod_level=1)
+        inputs = [input_word, true_word, neg_word]
+        return inputs
+
+    def net(self, inputs, is_infer=False):
+
+        init_width = 0.5 / self.embedding_size
+        input_emb = fluid.layers.embedding(
+            input=inputs[0],
+            is_sparse=self.is_sparse,
+            size=[self.dict_size, self.embedding_size],
+            param_attr=fluid.ParamAttr(
+                name='emb',
+                initializer=fluid.initializer.Uniform(-init_width,
+                                                      init_width)))
+
+        true_emb_w = fluid.layers.embedding(
+            input=inputs[1],
+            is_sparse=self.is_sparse,
+            size=[self.dict_size, self.embedding_size],
+            param_attr=fluid.ParamAttr(
+                name='emb_w',
+                initializer=fluid.initializer.Constant(value=0.0)))
+
+        true_emb_b = fluid.layers.embedding(
+            input=inputs[1],
+            is_sparse=self.is_sparse,
+            size=[self.dict_size, 1],
+            param_attr=fluid.ParamAttr(
+                name='emb_b',
+                initializer=fluid.initializer.Constant(value=0.0)))
+
+        neg_word_reshape = fluid.layers.reshape(inputs[2], shape=[-1, 1])
+        neg_word_reshape.stop_gradient = True
+
+        neg_emb_w = fluid.layers.embedding(
+            input=neg_word_reshape,
+            is_sparse=self.is_sparse,
+            size=[self.dict_size, self.embedding_size],
+            param_attr=fluid.ParamAttr(
+                name='emb_w', learning_rate=1.0))
+
+        neg_emb_w_re = fluid.layers.reshape(
+            neg_emb_w, shape=[-1, self.nce_num, self.embedding_size])
+
+        neg_emb_b = fluid.layers.embedding(
+            input=neg_word_reshape,
+            is_sparse=self.is_sparse,
+            size=[self.dict_size, 1],
+            param_attr=fluid.ParamAttr(
+                name='emb_b', learning_rate=1.0))
+
+        neg_emb_b_vec = fluid.layers.reshape(
+            neg_emb_b, shape=[-1, self.nce_num])
+
+        true_logits = fluid.layers.elementwise_add(
+            fluid.layers.reduce_sum(
+                fluid.layers.elementwise_mul(input_emb, true_emb_w),
+                dim=1,
+                keep_dim=True),
+            true_emb_b)
+
+        input_emb_re = fluid.layers.reshape(
+            input_emb, shape=[-1, 1, self.embedding_size])
+
+        neg_matmul = fluid.layers.matmul(
+            input_emb_re, neg_emb_w_re, transpose_y=True)
+        neg_matmul_re = fluid.layers.reshape(
+            neg_matmul, shape=[-1, self.nce_num])
+        neg_logits = fluid.layers.elementwise_add(neg_matmul_re,
+                                                  neg_emb_b_vec)
+        # nce loss
+
+        label_ones = fluid.layers.fill_constant_batch_size_like(
+            true_logits, shape=[-1, 1], value=1.0, dtype='float32')
+        label_zeros = fluid.layers.fill_constant_batch_size_like(
+            true_logits,
+            shape=[-1, self.nce_num],
+            value=0.0,
+            dtype='float32')
+
+        true_xent = fluid.layers.sigmoid_cross_entropy_with_logits(true_logits,
+                                                                   label_ones)
+        neg_xent = fluid.layers.sigmoid_cross_entropy_with_logits(neg_logits,
+                                                                  label_zeros)
+        cost = fluid.layers.elementwise_add(
+            fluid.layers.reduce_sum(
+                true_xent, dim=1),
+            fluid.layers.reduce_sum(
+                neg_xent, dim=1))
+        avg_cost = fluid.layers.reduce_mean(cost)
+        # the framework picks the loss up from self._cost
+        self._cost = avg_cost
+
+    def optimizer(self):
+        optimizer = paddle.optimizer.Adam(self.learning_rate, lazy_mode=True)
+        return optimizer
+
+    def infer_net(self):
+        pass
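The loss in the new word2vec model is negative sampling rather than a full softmax: one sigmoid cross-entropy term pushes the true-word logit toward 1 and another pushes `nce_num` sampled-word logits toward 0. A shape-level NumPy sketch of how the two logit tensors are assembled (all sizes illustrative):

```python
import numpy as np

batch, dim, nce_num = 4, 8, 5
input_emb = np.random.rand(batch, dim)               # centre-word embeddings
true_emb_w = np.random.rand(batch, dim)              # target-word output weights
true_emb_b = np.random.rand(batch, 1)                # target-word output bias
neg_emb_w_re = np.random.rand(batch, nce_num, dim)   # negative-sample weights
neg_emb_b_vec = np.random.rand(batch, nce_num)       # negative-sample biases

# true logit: row-wise dot product plus bias -> [batch, 1]
true_logits = (input_emb * true_emb_w).sum(axis=1, keepdims=True) + true_emb_b

# negative logits: batched matmul against nce_num samples -> [batch, nce_num]
neg_logits = np.einsum("bd,bnd->bn", input_emb, neg_emb_w_re) + neg_emb_b_vec

print(true_logits.shape, neg_logits.shape)  # (4, 1) (4, 5)
```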
diff --git a/core/engine/local_cluster.py b/core/engine/local_cluster.py
index b6ff736ee..dab3eb583 100755
--- a/core/engine/local_cluster.py
+++ b/core/engine/local_cluster.py
@@ -20,6 +20,8 @@ import sys
 import subprocess
 import logging
 
+import tempfile
+import shutil
 from paddlerec.core.engine.engine import Engine
 from paddlerec.core.utils import envs
 
@@ -45,6 +47,10 @@ def start_procs(self):
         procs = []
         log_fns = []
 
+        self.gloo_rendezvous_dir = tempfile.mkdtemp()
+        gloo_http_port = str(envs.find_free_port())
+        self.gloo_endpoints = ":".join(["127.0.0.1", gloo_http_port])
+
         if fleet_mode.upper() == "PS":
             for i in range(server_num - 1):
                 while True:
@@ -70,7 +76,11 @@ def start_procs(self):
                     "PADDLE_PORT": user_endpoints_port[i],
                     "TRAINING_ROLE": "PSERVER",
                     "PADDLE_TRAINERS_NUM": str(worker_num),
-                    "POD_IP": user_endpoints_ips[i]
+                    "POD_IP": user_endpoints_ips[i],
+                    "PADDLE_WITH_GLOO": "1",
+                    "PADDLE_GLOO_RENDEZVOUS": "3",
+                    "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
+                    "PADDLE_GLOO_HTTP_ENDPOINT": self.gloo_endpoints
                 })
 
                 os.system("mkdir -p {}".format(logs_dir))
@@ -89,7 +99,11 @@ def start_procs(self):
                     "PADDLE_PSERVERS_IP_PORT_LIST": user_endpoints,
                     "PADDLE_TRAINERS_NUM": str(worker_num),
                     "TRAINING_ROLE": "TRAINER",
-                    "PADDLE_TRAINER_ID": str(i)
+                    "PADDLE_TRAINER_ID": str(i),
+                    "PADDLE_WITH_GLOO": "1",
+                    "PADDLE_GLOO_RENDEZVOUS": "3",
+                    "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
+                    "PADDLE_GLOO_HTTP_ENDPOINT": self.gloo_endpoints
                 })
 
                 os.system("mkdir -p {}".format(logs_dir))
@@ -116,8 +130,8 @@ def start_procs(self):
                 cuda_visible_devices_list = cuda_visible_devices.split(',')
                 for x in self.envs["selected_gpus"].split(","):
                     assert x in cuda_visible_devices_list, "Can't find "\
-                    "your selected_gpus %s in CUDA_VISIBLE_DEVICES[%s]."\
-                    % (x, cuda_visible_devices)
+                        "your selected_gpus %s in CUDA_VISIBLE_DEVICES[%s]."\
+                        % (x, cuda_visible_devices)
                 selected_gpus = [
                     cuda_visible_devices_list.index(x.strip())
                     for x in self.envs["selected_gpus"].split(",")
@@ -153,7 +167,10 @@ def start_procs(self):
                 "TRAINING_ROLE": "TRAINER",
                 "PADDLE_TRAINER_ID": str(i),
                 "FLAGS_selected_gpus": str(selected_gpus[i]),
-                "PADDLEREC_GPU_NUMS": str(selected_gpus_num)
+                "PADDLEREC_GPU_NUMS": str(selected_gpus_num),
+                "PADDLE_WITH_GLOO": "1",
+                "PADDLE_GLOO_RENDEZVOUS": "3",
+                "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
             })
 
             os.system("mkdir -p {}".format(logs_dir))
@@ -183,6 +200,8 @@ def start_procs(self):
             "all workers already completed, you can view logs under the `{}` directory".
             format(logs_dir),
             file=sys.stderr)
+        if os.path.exists(self.gloo_rendezvous_dir):
+            shutil.rmtree(self.gloo_rendezvous_dir)
 
     def run(self):
         self.start_procs()
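For context on the new environment variables: `PADDLE_WITH_GLOO=1` enables Gloo-backed barriers and collectives in Fleet, and the remaining variables tell each spawned process how to rendezvous, via a shared temp directory (`PADDLE_GLOO_FS_PATH`) plus an HTTP endpoint on a free local port (`PADDLE_GLOO_HTTP_ENDPOINT`); the `"3"` rendezvous type mirrors what Paddle's own launch utilities set. A sketch of the extra environment a spawned trainer ends up with (the port value is illustrative):

```python
import os
import tempfile

# Mirrors what start_procs() now injects into each child process.
gloo_rendezvous_dir = tempfile.mkdtemp()
trainer_env = dict(os.environ)
trainer_env.update({
    "TRAINING_ROLE": "TRAINER",
    "PADDLE_TRAINER_ID": "0",
    "PADDLE_WITH_GLOO": "1",                   # enable gloo barriers/collectives
    "PADDLE_GLOO_RENDEZVOUS": "3",             # rendezvous type, as in fleet launch
    "PADDLE_GLOO_FS_PATH": gloo_rendezvous_dir,
    "PADDLE_GLOO_HTTP_ENDPOINT": "127.0.0.1:29700",  # 127.0.0.1:<free port>
})
# subprocess.Popen(cmd, env=trainer_env, ...) then launches the trainer.
```

One thing worth double-checking: the collective branch above sets `PADDLE_GLOO_FS_PATH` but, unlike the two PS branches, not `PADDLE_GLOO_HTTP_ENDPOINT`.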
diff --git a/core/trainers/framework/network.py b/core/trainers/framework/network.py
index 8e1072197..2bd259b2a 100644
--- a/core/trainers/framework/network.py
+++ b/core/trainers/framework/network.py
@@ -194,7 +194,6 @@ def build_network(self, context):
 class FleetNetwork(NetworkBase):
     def __init__(self, context):
         print("Running FleetNetwork.")
-        pass
 
     def build_network(self, context):
         context["model"] = {}
diff --git a/core/trainers/framework/runner.py b/core/trainers/framework/runner.py
index eda4b2ab7..33a671bbe 100644
--- a/core/trainers/framework/runner.py
+++ b/core/trainers/framework/runner.py
@@ -134,7 +134,8 @@ def _executor_dataset_train(self, model_dict, context):
 
     def _executor_dataloader_train(self, model_dict, context):
         model_name = model_dict["name"]
-        program = self._get_dataloader_program(model_dict, context)
+        model_class = context["model"][model_dict["name"]]["model"]
+        program = context["model"][model_name]["main_program"]
 
         fetch_period = int(
             envs.get_global_env("runner." + context["runner_name"] +
@@ -181,8 +182,7 @@ def _executor_dataloader_train(self, model_dict, context):
             try:
                 while True:
                     metrics_tensors = context["exe"].run(
-                        program=context["model"][model_dict["name"]][
-                            "main_program"],
+                        program=program,
                         fetch_list=metrics_varnames,
                         return_numpy=False)
diff --git a/core/trainers/framework/startup.py b/core/trainers/framework/startup.py
index d22c967e6..cf53288c0 100644
--- a/core/trainers/framework/startup.py
+++ b/core/trainers/framework/startup.py
@@ -15,7 +15,7 @@
 from __future__ import print_function
 
 import warnings
-
+import paddle
 import paddle.fluid as fluid
 import paddle.fluid.core as core
 from paddlerec.core.utils import envs
@@ -53,17 +53,16 @@ class SingleStartup(StartupBase):
 
     def __init__(self, context):
         print("Running SingleStartup.")
-        pass
 
     def startup(self, context):
         for model_dict in context["phases"]:
-            with fluid.scope_guard(context["model"][model_dict["name"]][
-                    "scope"]):
+            with paddle.static.scope_guard(context["model"][model_dict["name"]]
+                                           ["scope"]):
                 train_prog = context["model"][model_dict["name"]][
                     "main_program"]
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
-                with fluid.program_guard(train_prog, startup_prog):
+                with paddle.static.program_guard(train_prog, startup_prog):
                     context["exe"].run(startup_prog)
                 self.load(context, main_program=train_prog)
         context["status"] = "train_pass"
@@ -187,18 +186,18 @@ def startup(self, context):
 
 class FleetStartup(StartupBase):
     def __init__(self, context):
-        print("Running PSStartup.")
+        print("Running FleetStartup.")
         pass
 
     def startup(self, context):
         model_dict = context["env"]["phase"][0]
-        with fluid.scope_guard(context["model"][model_dict["name"]]["scope"]):
-
+        with paddle.static.scope_guard(context["model"][model_dict["name"]][
+                "scope"]):
             train_prog = context["model"][model_dict["name"]][
                 "default_main_program"]
             startup_prog = context["model"][model_dict["name"]][
                 "startup_program"]
-            with fluid.program_guard(train_prog, startup_prog):
+            with paddle.static.program_guard(train_prog, startup_prog):
                 context["exe"].run(startup_prog)
 
         if context['fleet'].is_worker():
             # for parameter-server worker
@@ -216,12 +215,12 @@ def __init__(self, context):
 
     def startup(self, context):
         for model_dict in context["phases"]:
-            with fluid.scope_guard(context["model"][model_dict["name"]][
-                    "scope"]):
+            with paddle.static.scope_guard(context["model"][model_dict["name"]]
+                                           ["scope"]):
                 train_prog = context["model"][model_dict["name"]][
                     "main_program"]
                 startup_prog = context["model"][model_dict["name"]][
                     "startup_program"]
-                with fluid.program_guard(train_prog, startup_prog):
+                with paddle.static.program_guard(train_prog, startup_prog):
                     context["exe"].run(startup_prog)
         context["status"] = "train_pass"
diff --git a/core/trainers/general_trainer.py b/core/trainers/general_trainer.py
index b974748bb..951a842d7 100644
--- a/core/trainers/general_trainer.py
+++ b/core/trainers/general_trainer.py
@@ -52,12 +52,10 @@ def instance(self, context):
         else:
             if self.engine == EngineMode.SINGLE:
                 instance_class_name = "SingleInstance"
-            elif self.fleet_mode == FleetMode.PSLIB:
-                instance_class_name = "PslibInstance"
-            elif self.fleet_mode == FleetMode.PS:
-                instance_class_name = "PSInstance"
-            elif self.fleet_mode == FleetMode.COLLECTIVE:
-                instance_class_name = "CollectiveInstance"
+            elif self.fleet_mode in [
+                    FleetMode.PSLIB, FleetMode.PS, FleetMode.COLLECTIVE
+            ]:
+                instance_class_name = "FleetInstance"
             else:
                 raise ValueError("Instance Init Error")
             instance_path = os.path.join(self.abs_dir, "framework",
@@ -76,12 +74,10 @@ def network(self, context):
         else:
             if self.engine == EngineMode.SINGLE:
                 network_class_name = "SingleNetwork"
-            elif self.fleet_mode == FleetMode.PSLIB:
-                network_class_name = "PslibNetwork"
-            elif self.fleet_mode == FleetMode.PS:
-                network_class_name = "PSNetwork"
-            elif self.fleet_mode == FleetMode.COLLECTIVE:
-                network_class_name = "CollectiveNetwork"
+            elif self.fleet_mode in [
+                    FleetMode.PSLIB, FleetMode.PS, FleetMode.COLLECTIVE
+            ]:
+                network_class_name = "FleetNetwork"
             else:
                 raise ValueError("NetWork Init Error")
             network_path = os.path.join(self.abs_dir, "framework",
@@ -102,10 +98,10 @@ def startup(self, context):
             startup_class_name = "SingleInferStartup"
         elif self.engine == EngineMode.SINGLE and not context["is_infer"]:
             startup_class_name = "SingleStartup"
-        elif self.fleet_mode == FleetMode.PS or self.fleet_mode == FleetMode.PSLIB:
-            startup_class_name = "PSStartup"
-        elif self.fleet_mode == FleetMode.COLLECTIVE:
-            startup_class_name = "CollectiveStartup"
+        elif self.fleet_mode in [
+                FleetMode.PSLIB, FleetMode.PS, FleetMode.COLLECTIVE
+        ]:
+            startup_class_name = "FleetStartup"
         else:
             raise ValueError("Startup Init Error")
         startup_path = os.path.join(self.abs_dir, "framework",
@@ -125,12 +121,10 @@ def runner(self, context):
             runner_class_name = "SingleInferRunner"
         elif self.engine == EngineMode.SINGLE and not context["is_infer"]:
             runner_class_name = "SingleRunner"
-        elif self.fleet_mode == FleetMode.PSLIB:
-            runner_class_name = "PslibRunner"
-        elif self.fleet_mode == FleetMode.PS:
-            runner_class_name = "PSRunner"
-        elif self.fleet_mode == FleetMode.COLLECTIVE:
-            runner_class_name = "CollectiveRunner"
+        elif self.fleet_mode in [
+                FleetMode.PSLIB, FleetMode.PS, FleetMode.COLLECTIVE
+        ]:
+            runner_class_name = "FleetRunner"
         else:
             raise ValueError("Runner Init Error")
         runner_path = os.path.join(self.abs_dir, "framework", "runner.py")
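With this change, all three fleet modes funnel into a single `Fleet*` class for each pipeline stage, so the four hand-written `elif` ladders are really one lookup. A self-contained sketch of the equivalent dispatch, using stand-in enums rather than the trainer's real `EngineMode`/`FleetMode` definitions:

```python
from enum import Enum


class EngineMode(Enum):
    SINGLE = 1
    CLUSTER = 2
    LOCAL_CLUSTER = 3


class FleetMode(Enum):
    PS = 1
    PSLIB = 2
    COLLECTIVE = 3


FLEET_MODES = (FleetMode.PS, FleetMode.PSLIB, FleetMode.COLLECTIVE)


def resolve_class(component, engine, fleet_mode):
    """component is one of 'Instance', 'Network', 'Startup', 'Runner'."""
    if engine == EngineMode.SINGLE:
        return "Single" + component
    if fleet_mode in FLEET_MODES:
        return "Fleet" + component  # one implementation for every fleet mode
    raise ValueError("%s Init Error" % component)


assert resolve_class("Runner", EngineMode.CLUSTER, FleetMode.PS) == "FleetRunner"
```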