From 2191a08317e0465a7f3db094ecfac269a43f3285 Mon Sep 17 00:00:00 2001 From: 123malin Date: Fri, 7 Aug 2020 10:40:14 +0800 Subject: [PATCH] =?UTF-8?q?=E3=80=90paddle.fleet=E3=80=91fleet=5Futil=20mo?= =?UTF-8?q?ve=20to=20paddle.fleet=20(#25805)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * test=develop,test=document_fix, remove the out args * fleet_util move to paddle.fleet Co-authored-by: WuHaobo Co-authored-by: tangwei12 --- paddle/fluid/framework/fleet/gloo_wrapper.cc | 13 +- python/paddle/fleet/base/role_maker.py | 520 +++++++++++++++++- python/paddle/fleet/base/util_factory.py | 414 +++++++++++++- python/paddle/fleet/utils/__init__.py | 18 + python/paddle/fleet/utils/fs.py | 382 +++++++++++++ python/paddle/fleet/utils/http_server.py | 195 +++++++ .../fluid/incubate/fleet/base/fleet_base.py | 7 +- .../fluid/tests/unittests/CMakeLists.txt | 1 - .../fluid/tests/unittests/dist_fleet_ctr.py | 11 +- .../tests/unittests/test_dist_fleet_base.py | 64 ++- .../tests/unittests/test_dist_fleet_ctr.py | 2 +- .../tests/unittests/test_fleet_rolemaker_4.py | 7 +- .../unittests/test_fleet_rolemaker_new.py | 171 ++++++ .../fluid/tests/unittests/test_fleet_util.py | 273 ++++++++- .../tests/unittests/test_fs_interface.py | 4 +- .../paddle/fluid/tests/unittests/test_hdfs.py | 4 +- python/requirements.txt | 1 + python/setup.py.in | 1 + 18 files changed, 2014 insertions(+), 74 deletions(-) create mode 100644 python/paddle/fleet/utils/__init__.py create mode 100644 python/paddle/fleet/utils/fs.py create mode 100644 python/paddle/fleet/utils/http_server.py create mode 100644 python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py diff --git a/paddle/fluid/framework/fleet/gloo_wrapper.cc b/paddle/fluid/framework/fleet/gloo_wrapper.cc index 49181cd05f3fa..bb958f1ac015b 100644 --- a/paddle/fluid/framework/fleet/gloo_wrapper.cc +++ b/paddle/fluid/framework/fleet/gloo_wrapper.cc @@ -54,10 +54,9 @@ void HdfsStore::set(const std::string& key, const std::vector& data) { paddle::framework::fs_remove(tmp); if (i == retry_times_) { VLOG(0) << "fs_open_write failed, retry times reaches limit"; - // PADDLE_THROW(platform::errors::PreconditionNotMet( - // "fs_open_write failed, retry times reaches" - // " limit ", - // retry_times_)); + PADDLE_THROW(paddle::platform::errors::PreconditionNotMet( + "fs_open_write failed, retry times reaches %d limit.", + retry_times_)); } } else { break; @@ -143,9 +142,9 @@ void HdfsStore::wait(const std::vector& keys, break; } } - // PADDLE_THROW(platform::errors::ExecutionTimeout( - VLOG(0) << "TIMEOUT self_rank = " << self_rank_ - << " pair_rank = " << last_check_rank; + PADDLE_THROW(paddle::platform::errors::ExecutionTimeout( + "TIMEOUT self_rank = %d pair_rank = %d", self_rank_, + last_check_rank)); } std::this_thread::sleep_for(std::chrono::milliseconds(wait_sleep_ms_)); } diff --git a/python/paddle/fleet/base/role_maker.py b/python/paddle/fleet/base/role_maker.py index f6b5c8ac12e92..b3e8120af6f85 100644 --- a/python/paddle/fleet/base/role_maker.py +++ b/python/paddle/fleet/base/role_maker.py @@ -12,5 +12,523 @@ # See the License for the specific language governing permissions and # limitations under the License. """Defination of Role Makers.""" +import os +import numpy as np +from multiprocessing import Process, Manager +import paddle.fluid as fluid -# __all__ = ['RoleMakerBase', 'UserDefinedRoleMaker', 'PaddleCloudRoleMaker'] +__all__ = ['RoleMakerBase', 'UserDefinedRoleMaker', 'PaddleCloudRoleMaker'] + + +class Role: + WORKER = 1 + SERVER = 2 + + +class RoleMakerBase(object): + """ + RoleMakerBase is a base class for assigning a role to current process + in distributed training. + A paddle developer can implement RoleMakerBase to design a role maker + for worker or pserver assignment. + """ + + def __init__(self): + self._worker_endpoints = [] + self._server_endpoints = [] + self._role_is_generated = False + self._role = None + self._current_id = -1 + + self._node_type = None + self._node_type_comm = None + self._all_comm = None + + def is_worker(self): + """ + return is_worker() of current process + """ + raise NotImplementedError("Please implement this method in child class") + + def is_server(self): + """ + return is_server() of current process + """ + raise NotImplementedError("Please implement this method in child class") + + def is_first_worker(self): + """ + Check whether the node is the first instance of worker. + Returns: + bool: True if this is the first node of worker, + False if not. + """ + raise NotImplementedError("Please implement this method in child class") + + def worker_num(self): + """ + Get current total worker number. + + Returns: + int: worker number + """ + raise NotImplementedError("Please implement this method in child class") + + def server_num(self): + """ + Get current total server number. + + Returns: + int: server number + """ + raise NotImplementedError("Please implement this method in child class") + + def worker_index(self): + """ + Get current worker id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def server_index(self): + """ + Get current server id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def role_id(self): + """ + Get current id. + + Returns: + int: node id + """ + raise NotImplementedError("Please implement this method in child class") + + def get_trainer_endpoints(self): + """ + return trainer endpoints + """ + return self._worker_endpoints + + def get_pserver_endpoints(self): + """ + return pserver endpoints + """ + return self._server_endpoints + + def to_string(self): + return "role: {}, current_id: {}, worker_endpoints: {}, server_endpoints: {}".format( + self._role, self._current_id, self._worker_endpoints, + self._server_endpoints) + + def _all_gather(self, comm_world, input): + """ + + Args: + input(int|float): input value + + Returns: + return a list of values + """ + print("warning: RoleMakerBase does not have all gather.") + return None + + def _all_reduce(self, comm_world, input, mode="sum"): + """ + Args: + input(list/numpy.array): array of one dim + output(list/numpy.array): array of one dim + mode(str): "sum" or "min" or "max" + """ + print("warning: RoleMakerBase does not have all reduce worker.") + return None + + def _barrier(self, comm_world): + """ + barrier between trainers if current role is TRAINER + """ + print("warning: RoleMakerBase does not have barrier worker.") + + +class PaddleCloudRoleMaker(RoleMakerBase): + def __init__(self, is_collective=False, init_gloo=True, **kwargs): + super(PaddleCloudRoleMaker, self).__init__() + self._is_collective = is_collective + self._init_gloo = init_gloo + self._kwargs = kwargs + + self._role_is_generated = False + + self._server_endpoints = None + self._worker_endpoints = None + + self._node_type_comm = None + self._all_comm = None + + if not self._is_collective: + self._hdfs_name = kwargs.get("hdfs_name", "") + self._hdfs_ugi = kwargs.get("hdfs_ugi", "") + self._hdfs_path = kwargs.get("path", "").rstrip("/") + self._init_timeout_seconds = kwargs.get("init_timeout_seconds", + 3600) + self._run_timeout_seconds = kwargs.get("run_timeout_seconds", + 9999999) + ip_port = kwargs.get("http_ip_port", "") + self._http_ip_port = [] + self._http_server = None + # if ip_port is not empty, it will use http instead of hdfs + if ip_port != "": + self._http_ip_port = ip_port.split(":") + # it's for communication between processes + self._manager = Manager() + # global dict to store status + self._http_server_d = self._manager.dict() + # set running status of http server + self._http_server_d["running"] = False + self._iface = self.__get_default_iface() + # this environment variable can be empty + self._prefix = os.getenv("SYS_JOB_ID", "") + + def _barrier(self, comm_world): + if comm_world: + comm_world.barrier() + + def _all_gather(self, comm_world, input): + if comm_world: + self._barrier(comm_world) + output = comm_world.all_gather(input) + return output + else: + return None + + def _all_reduce(self, comm_world, input, mode="sum"): + if not comm_world: + return None + + input = np.array(input) + + input_shape = input.shape + input_list = input.reshape(-1).tolist() + + self._barrier(comm_world) + ans = comm_world.all_reduce(input_list, mode) + output = np.array(ans).reshape(input_shape) + return output + + def is_worker(self): + """ + whether current process is worker + """ + if not self._role_is_generated: + self.generate_role() + return self._role == Role.WORKER + + def is_server(self): + """ + whether current process is server + """ + if not self._role_is_generated: + self.generate_role() + return self._role == Role.SERVER + + def is_first_worker(self): + """ + whether current process is worker of rank 0 + """ + if not self._role_is_generated: + self.generate_role() + return self._role == Role.WORKER and self._current_id == 0 + + def worker_index(self): + """ + get index of current worker + """ + if not self._role_is_generated: + self.generate_role() + return self._current_id + + def server_index(self): + """ + get index of current server + """ + if not self._role_is_generated: + self.generate_role() + return self._current_id + + def role_id(self): + """ + get index of current node + """ + if self.is_server(): + return self.server_index() + elif self.is_worker(): + return self.worker_index() + + def worker_num(self): + """ + retrun the current number of worker + """ + if not self._role_is_generated: + self.generate_role() + return self._trainers_num + + def server_num(self): + """ + return the current number of server + """ + if not self._role_is_generated: + self.generate_role() + return self._trainers_num + + def get_trainer_endpoints(self): + """ + get endpoint of all trainers + """ + if not self._role_is_generated: + self.generate_role() + return self._worker_endpoints + + def get_pserver_endpoints(self): + """ + get endpoint of all pservers + """ + if not self._role_is_generated: + self.generate_role() + return self._server_endpoints + + def _get_rank(self): + """ + get current rank in all workers and pservers + """ + if not self._role_is_generated: + self.generate_role() + return self._rank + + def _get_size(self): + """ + get total num of all workers and pservers + """ + if not self._role_is_generated: + self.generate_role() + return self._size + + def _ps_env(self): + try: + # Environment variable PADDLE_PSERVERS_IP_PORT_LIST must be set + # format: string(ip:port), eg. 127.0.0.1:6001 + self._server_endpoints = os.environ[ + "PADDLE_PSERVERS_IP_PORT_LIST"].split(",") + self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", + "").split(",") + + trainers_num = int(os.environ["PADDLE_TRAINERS_NUM"]) + training_role = os.environ["TRAINING_ROLE"] + + if training_role not in ["TRAINER", "PSERVER"]: + raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER") + + if training_role == "TRAINER": + role = Role.WORKER + current_id = int(os.environ["PADDLE_TRAINER_ID"]) + if len(self._worker_endpoints) > 0: + self._cur_endpoint = self._worker_endpoints[current_id] + elif training_role == "PSERVER": + role = Role.SERVER + port = os.environ["PADDLE_PORT"] + ip = os.environ["POD_IP"] + self._cur_endpoint = ip + ":" + port + current_id = self._server_endpoints.index(self._cur_endpoint) + else: + raise ValueError("TRAINING_ROLE must be PSERVER or TRAINER") + except ValueError as ve: + raise ValueError( + "something wrong with PaddleCloud, please check environment") + + self._trainers_num = trainers_num + self._role = role + self._current_id = current_id + + def _collective_env(self): + self._current_id = int(os.getenv("PADDLE_TRAINER_ID", "0")) + self._training_role = os.getenv("PADDLE_TRAINING_ROLE", "TRAINER") + assert (self._training_role == "TRAINER") + self._worker_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS") + self._cur_endpoint = os.getenv("PADDLE_CURRENT_ENDPOINT") + assert self._worker_endpoints is not None, "can't find PADDLE_TRAINER_ENDPOINTS" + self._worker_endpoints = self._worker_endpoints.split(",") + self._trainers_num = len(self._worker_endpoints) + + def _init_gloo_env(self): + def init_gloo_instance(role="trainer"): + role = role.lower() + assert role in ["trainer", "pserver", "all"] + if role == "trainer": + all_list = self._worker_endpoints + rank = self._current_id + elif role == "pserver": + all_list = self._server_endpoints + rank = self._current_id + else: + all_list = self._worker_endpoints + self._server_endpoints + rank = all_list.index(self._cur_endpoint) + gloo = fluid.core.Gloo() + gloo.set_rank(rank) + gloo.set_size(len(all_list)) + gloo.set_prefix(self._prefix) + gloo.set_iface(self._iface) + gloo.set_timeout_seconds(self._init_timeout_seconds, + self._run_timeout_seconds) + if len(self._http_ip_port) != 0: + gloo.set_http_store(self._http_ip_port[0], + int(self._http_ip_port[1]), role) + else: + gloo.set_hdfs_store(self._hdfs_path + "/" + role, + self._hdfs_name, self._hdfs_ugi) + gloo.init() + return gloo + + # paddlecloud support gloo + if self._role == Role.WORKER: + if self._current_id == 0 and len(self._http_ip_port) != 0: + size_d = { + "trainer": len(self._worker_endpoints), + "pserver": len(self._server_endpoints), + "all": + len(self._worker_endpoints) + len(self._server_endpoints) + } + # child process for http server + self._http_server = Process( + target=self.__start_kv_server, + args=(self._http_server_d, size_d)) + self._http_server.daemon = True + # set running status to True + self._http_server_d["running"] = True + # start child process + self._http_server.start() + self._node_type = 1 + gloo = init_gloo_instance("trainer") + self._node_type_comm = gloo + else: + assert self._role == Role.SERVER + self._node_type = 0 + gloo = init_gloo_instance("pserver") + self._node_type_comm = gloo + + all_list = self._worker_endpoints + self._server_endpoints + self._rank = all_list.index(self._cur_endpoint) + self._size = len(all_list) + + gloo = init_gloo_instance("all") + self._all_comm = gloo + + if self._http_server is not None: + # set running status to False + self._http_server_d["running"] = False + # wait until child process exits + self._http_server.join() + + def generate_role(self): + """ + generate role for role maker + """ + if not self._role_is_generated: + if not self._is_collective: + self._ps_env() + if self._init_gloo: + self._init_gloo_env() + else: + self._collective_env() + self._role_is_generated = True + + def __get_default_iface(self): + """ + get default physical interface + """ + default1 = self.__get_default_iface_from_gateway() + default2 = self.__get_default_iface_from_interfaces() + return default2 if default1 == "lo" else default1 + + def __get_default_iface_from_gateway(self): + """ + get default physical interface + """ + import netifaces + gateways = netifaces.gateways() + if gateways.get(netifaces.AF_INET) != None: + gateway = gateways[netifaces.AF_INET] + if len(gateway) > 0 and len(gateway[0]) > 1: + return gateway[0][1] + return "lo" + + def __get_default_iface_from_interfaces(self): + """ + get default physical interface + """ + import netifaces + for intf_name in netifaces.interfaces(): + addresses = netifaces.ifaddresses(intf_name) + if netifaces.AF_INET in addresses: + ipv4_addresses = addresses[netifaces.AF_INET] + for ipv4_address in ipv4_addresses: + if 'broadcast' in ipv4_address: + return intf_name + return "lo" + + def __start_kv_server(self, http_server_d, size_d): + from paddle.fleet.utils import KVServer + http_server = KVServer(int(self._http_ip_port[1]), size_d) + http_server.start() + wait_seconds = 5 + while http_server_d.get("running", + False) and not http_server.shoud_stop(): + time.sleep(wait_seconds) + http_server.stop() + + +class UserDefinedRoleMaker(PaddleCloudRoleMaker): + def __init__(self, is_collective=False, init_gloo=False, **kwargs): + super(UserDefinedRoleMaker, self).__init__( + is_collective=is_collective, init_gloo=init_gloo, **kwargs) + + def _user_defined_ps_env(self): + self._server_endpoints = self._kwargs.get("server_endpoints") + self._worker_endpoints = self._kwargs.get("worker_endpoints", []) + self._trainers_num = self._kwargs.get("worker_num", 0) + + if self._trainers_num == 0: + assert (len(self._worker_endpoints) > 0) + self._trainers_num = len(self._worker_endpoints) + + self._role = self._kwargs.get("role") + self._current_id = self._kwargs.get("current_id") + + if self._role == Role.WORKER and len( + self._worker_endpoints) > self._current_id: + self._cur_endpoint = self._worker_endpoints[self._current_id] + elif self._role == Role.SERVER: + self._cur_endpoint = self._server_endpoints[self._current_id] + + def _user_defined_collective_env(self): + self._worker_endpoints = self._kwargs.get("worker_endpoints") + self._current_id = self._kwargs.get("current_id") + self._trainers_num = len(self._worker_endpoints) + self._training_role = Role.Worker + + def generate_role(self): + """ + generate role for role maker + """ + if not self._role_is_generated: + if not self._is_collective: + self._user_defined_ps_env() + if self._init_gloo: + self._init_gloo_env() + else: + self._user_defined_collective_env() + self._role_is_generated = True diff --git a/python/paddle/fleet/base/util_factory.py b/python/paddle/fleet/base/util_factory.py index 385500de8c018..ed2a8db586aa9 100644 --- a/python/paddle/fleet/base/util_factory.py +++ b/python/paddle/fleet/base/util_factory.py @@ -18,12 +18,27 @@ __all__ = ['UtilBase'] +import numpy as np +import os + +import subprocess +from paddle.fluid import core +from collections import OrderedDict +import paddle.fluid as fluid +from google.protobuf import text_format +from paddle.fluid import debugger +from paddle.fluid.framework import Program +from paddle.fluid.proto import framework_pb2 +from ..utils.fs import FS, LocalFS, HDFSClient + class UtilFactory(object): - def _create_util(self, context): + def _create_util(self, context=None): util = UtilBase() - util._set_strategy(context["valid_strategy"]) - util._set_role_maker(context["role_maker"]) + if context is not None and "valid_strategy" in context: + util._set_strategy(context["valid_strategy"]) + if context is not None and "role_maker" in context: + util._set_role_maker(context["role_maker"]) return util @@ -38,43 +53,390 @@ def _set_strategy(self, dist_strategy): def _set_role_maker(self, role_maker): self.role_maker = role_maker - ''' def set_file_system(self, fs_client): + assert isinstance( + fs_client, + FS), "fs_client must be the instance of paddle.fleet.utils.FS" self.fs_client = fs_client - def broadcast(self): - pass + def __check_comm_world(self, comm_world="worker"): + if not self.role_maker._role_is_generated: + self.role_maker.generate_role() - def all_gather(self): - pass + _comm_world = None + comm_world_upper = comm_world.upper() + if comm_world_upper == "WORKER": + if not self.role_maker.is_worker(): + print( + "warning: current role is not worker in collective_func(comm_world=\"worker\")" + ) + _comm_world = self.role_maker._node_type_comm + elif comm_world_upper == "SERVER": + if not self.role_maker.is_server(): + print( + "warning: current role is not server in collective_func(comm_world=\"server\")" + ) + _comm_world = self.role_maker._node_type_comm + elif comm_world_upper == "ALL": + _comm_world = self.role_maker._all_comm + else: + raise ValueError( + "not support comm_world, please choose one from [worker, server, all]" + ) - def all_reduce(self): - pass + return _comm_world - def reduce_scatter(self): + def all_reduce(self, input, mode, comm_world="worker"): + _comm_world = self.__check_comm_world(comm_world) + return self.role_maker._all_reduce(_comm_world, input, mode) + + def barrier(self, comm_world="worker"): + _comm_world = self.__check_comm_world(comm_world) + self.role_maker._barrier(_comm_world) + + def all_gather(self, input, comm_world="worker"): + _comm_world = self.__check_comm_world(comm_world) + return self.role_maker._all_gather(_comm_world, input) + + def broadcast(self): pass - def reduce(self): + def scatter(self): pass def get_file_shard(self, files): - pass + """ + split files before distributed training, + example 1: files is [a, b, c ,d, e] and trainer_num = 2, then trainer + 0 gets [a, b, c] and trainer 1 gets [d, e]. + example 2: files is [a, b], and trainer_num = 3, then trainer 0 gets + [a], trainer 1 gets [b], trainer 2 gets [] - def feed_gen(self, batch_size, feed_vars_dims, feeded_vars_filelist): - pass + Args: + files(list): file list need to be read. - def save_program(program, output_dir): - pass + Returns: + list: files belongs to this worker. + """ + if not isinstance(files, list): + raise TypeError("files should be a list of file need to be read.") - def load_program(input_dir): - pass + trainer_id = self.role_maker.worker_index() + trainers = self.role_maker.worker_num() - def load_var(): - pass + remainder = len(files) % trainers + blocksize = int(len(files) / trainers) - def save_var(): - pass + blocks = [blocksize] * trainers + for i in range(remainder): + blocks[i] += 1 - def print_on_rank(self): - pass - ''' + trainer_files = [[]] * trainers + begin = 0 + for i in range(trainers): + trainer_files[i] = files[begin:begin + blocks[i]] + begin += blocks[i] + + return trainer_files[trainer_id] + + def print_on_rank(self, message, rank_id): + if self.role_maker.worker_index() != rank_id: + return + print(message) + + def _save_program(self, program, model_filename='__model__', is_text=False): + if is_text: + with open(model_filename, "w") as f: + f.write(str(program)) + else: + with open(model_filename, "wb") as f: + f.write(program.desc.serialize_to_string()) + + def _load_program(self, path, is_text): + def load_program_binary(path): + """load program from binary string file""" + with open(path, "rb") as f: + program_desc_str = f.read() + return Program.parse_from_string(program_desc_str) + + def load_program_text(path): + """load program from human-readable text file""" + with open(path, "r") as f: + program_desc_text = f.read() + + prog_desc = framework_pb2.ProgramDesc() + text_format.Merge(program_desc_text, prog_desc) + return Program.parse_from_string(prog_desc.SerializeToString()) + + if is_text: + return load_program_text(path) + else: + return load_program_binary(path) + + def _program_type_trans(self, prog_dir, prog_fn, is_text): + prog = self._load_program(os.path.join(prog_dir, prog_fn), is_text) + prog_out_fn = prog_fn + ".bin" if is_text else prog_fn + ".pbtxt" + self._save_program(prog, + os.path.join(prog_dir, prog_out_fn), 1 - is_text) + return prog_out_fn + + def _visualize_graphviz(self, program, output_dir, output_filename): + block = program.global_block() + dot_path = os.path.join(output_dir, output_filename + '.dot') + pdf_path = os.path.join(output_dir, output_filename + '.pdf') + debugger.draw_block_graphviz(block, path=dot_path) + cmd = ["dot", "-Tpdf", dot_path, "-o", pdf_path] + p = subprocess.Popen( + cmd, + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE) + p.wait() + + def _proto_check(self, config): + train_prog = self._load_program(config.train_prog_path, + config.is_text_train_program) + pruned_prog = self._load_program(config.pruned_prog_path, + config.is_text_pruned_program) + + is_match = True + + pruned_vars = [(v.name, v) for v in pruned_prog.list_vars() + if fluid.io.is_persistable(v)] + pruned_vars = OrderedDict(pruned_vars) + pruned_vars_name = [name for name in pruned_vars] + print("persistable vars in pruned program: {}".format(pruned_vars_name)) + + # feed and fetch op is added in pruned program when pruning, not need to be found in train program + feed_fetch_type_list = [ + core.VarDesc.VarType.FEED_MINIBATCH, core.VarDesc.VarType.FETCH_LIST + ] + + for var_name in pruned_vars: + var = pruned_vars[var_name] + # feed and fetch op is added in pruned program when pruning, not need to be found in train program + if var.type in feed_fetch_type_list: + break + try: + train_prog_var = train_prog.global_block().var(var_name) + except ValueError as e: + print( + "Not find variable '%s' in train program. please check pruning." + % var_name) + is_match = False + continue + if var.shape != train_prog_var.shape or var.dtype != train_prog_var.dtype: + print( + "variable: {} not match. in pruned program shape: {} dtype:{}, in train program shape: {} dtype: {}". + format(var_name, var.shape, var.dtype, train_prog_var.shape, + train_prog_var.dtype)) + is_match = False + return is_match + + def _params_check(self, config): + def feed_gen(batch_size, feeded_vars_dims, feeded_vars_filelist): + def reader(batch_size, fn, dim): + data = [] + if isinstance(dim, list) or isinstance(dim, tuple): + shape = list(dim) + _temp = 1 + for x in dim: + _temp = _temp * x + dim = _temp + else: + shape = [dim] + + shape = [batch_size] + shape + dim = dim * batch_size + + for line in open(fn, 'r'): + fields = line.strip().split(' ') + fields = [float(d) for d in fields] + while len(fields) >= dim: + tmp = fields[:dim] + fields = fields[dim:] + data.append(np.array(tmp).reshape(shape)) + return data + + batch_feed = [] + for i, fn in enumerate(feeded_vars_filelist): + batch_feed.append(reader(batch_size, fn, feeded_vars_dims[i])) + return batch_feed + + prog = self._load_program( + os.path.join(config.dump_model_dir, config.dump_program_filename), + config.is_text_dump_program) + if config.is_text_dump_program: + model_filename = self._program_type_trans( + config.dump_model_dir, config.dump_program_filename, + config.is_text_dump_program) + + saved_params = [ + v for v in prog.list_vars() if fluid.io.is_persistable(v) + ] + print("persistable vars in dump program: {}".format( + [v.name for v in saved_params])) + + def check_not_expected_ops(prog, not_expected_op_types): + op_types_set = set() + for op in prog.global_block().ops: + if op.type in not_expected_op_types and op.type not in op_types_set: + op_types_set.add(op.type) + return op_types_set + + not_expected_op_types = check_not_expected_ops(prog, ["lookup_table"]) + if len(not_expected_op_types) > 0: + print( + "find op type '{}' in program, please check if your program is pruned correctly !". + format(list(not_expected_op_types))) + return False + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.core.Scope() + with fluid.scope_guard(scope): + inference_program, feed_target_names, fetch_targets = \ + fluid.io.load_inference_model(config.dump_model_dir, exe, model_filename=model_filename, + params_filename=config.save_params_filename) + + # check program vars and saved vars shape + orig_para_shape = { + each_var.name: tuple(each_var.desc.shape()) + for each_var in saved_params + } + for each_var in saved_params: + var_temp = fluid.global_scope().find_var(each_var.name) + assert var_temp != None, "can't not find var: " + each_var.name + new_shape = (np.array(var_temp.get_tensor())).shape + assert each_var.name in orig_para_shape, each_var.name + "MUST in var list" + orig_shape = orig_para_shape.get(each_var.name) + if new_shape != orig_shape: + raise RuntimeError( + "Shape not matching: the Program requires a parameter with a shape of ({}), " + "while the loaded parameter (namely [ {} ]) has a shape of ({}).". + format(orig_shape, each_var.name, new_shape)) + + # check feed/fetch vars in program and config + feed_config = config.feed_config + fetch_config = config.fetch_config + fetch_targets_names = [v.name for v in fetch_targets] + if not feed_target_names: + print("warning! no feed targets in program.") + if not fetch_targets_names: + print("warning! no fetch targets in program.") + fetch_list = fetch_targets + feed_name_list = feed_target_names + if feed_config.feeded_vars_names is not None and feed_target_names != feed_config.feeded_vars_names: + print( + "warning! feed vars in program and config are diff: feed in program: {}. feed in config {}.". + format(feed_target_names, feed_config.feeded_vars_names)) + feed_name_list = feed_config.feeded_vars_names + # remove feed op in inference_program. new feed op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed": # only remove feed op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + if fetch_config.fetch_vars_names is not None and fetch_targets_names != fetch_config.fetch_vars_names: + print( + "warning! fetch vars in program and config are diff: fetch in program: {}. fetch in config {}.". + format(fetch_targets_names, fetch_config.fetch_vars_names)) + fetch_list = [ + inference_program.global_block().var(i) + for i in fetch_config.fetch_vars_names + ] + # remove fetch op in inference_program. new fetch op will be added in exe.run + global_block = inference_program.global_block() + need_to_remove_op_index = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "fetch": # only remove fetch op here + need_to_remove_op_index.append(i) + for index in need_to_remove_op_index[::-1]: + global_block._remove_op(index) + + # if fetch_list have lod tensor + return_numpy = all([v.lod_level == 0 for v in fetch_list]) + + # try dump fetch_targets + feed_tensors = [] + assert len(feed_config.feeded_vars_names) == len( + feed_config.feeded_vars_dims) == len( + feed_config.feeded_vars_types) + # check program vars and feed tensor shape in config + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + if not isinstance(feed_config.feeded_vars_dims[i], + (list, tuple)): + tensor_shape = (feed_config.feeded_vars_dims[i], ) + else: + tensor_shape = tuple(feed_config.feeded_vars_dims[i]) + feed_config.feeded_vars_dims[i] = tensor_shape + var_shape = var.shape[1:] + if tensor_shape != var_shape: + raise RuntimeError( + "feed variable '{}' shape not match. infer program shape: {}. feed tensor shape: {}". + format(feed_config.feeded_vars_names[i], var_shape, + tensor_shape)) + + if not feed_config.feeded_vars_filelist: + print("generate random feed vars.") + for i in range(len(feed_config.feeded_vars_names)): + var = inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + # create fake feed tensor. if lod_level > 1, should create_lod_tensor() + if var.lod_level == 0: + feed_tensors.append( + np.array( + np.random.random( + tuple([config.batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i])) + elif var.lod_level == 1: + t = np.array( + np.random.random( + tuple([config.batch_size] + list( + feed_config.feeded_vars_dims[i]))), + dtype=feed_config.feeded_vars_types[i]) + feed_tensors.append( + fluid.create_lod_tensor(t, [[1] * config.batch_size + ], place)) + else: + raise RuntimeError( + "vars with lod_level >= 2 is not supported now in this infer program check tool." + ) + results = exe.run(inference_program, + feed={ + name: feed_tensors[i] + for i, name in enumerate(feed_name_list) + }, + fetch_list=fetch_list, + return_numpy=return_numpy) + else: + print("load feed vars from files: {}.".format( + feed_config.feeded_vars_filelist)) + feed_vars = [ + inference_program.global_block().var( + feed_config.feeded_vars_names[i]) + for i in range(len(feed_config.feeded_vars_names)) + ] + feeder = fluid.DataFeeder(feed_list=feed_vars, place=place) + batch_feed = feed_gen(config.batch_size, + feed_config.feeded_vars_dims, + feed_config.feeded_vars_filelist) + slots = [batch_feed] + results = exe.run(inference_program, + feed=feeder.feed(slots), + fetch_list=fetch_list, + return_numpy=return_numpy) + for i, v in enumerate(fetch_list): + print("fetch_targets name: %s" % v.name) + print("fetch_targets: {}".format(results[i])) + return results + + +fleet_util = UtilFactory()._create_util(None) diff --git a/python/paddle/fleet/utils/__init__.py b/python/paddle/fleet/utils/__init__.py new file mode 100644 index 0000000000000..212308159aabb --- /dev/null +++ b/python/paddle/fleet/utils/__init__.py @@ -0,0 +1,18 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .fs import * +from .http_server import KVHandler, KVHTTPServer, KVServer + +__all__ = ['KVHandler', 'KVHTTPServer', 'KVServer'] + fs.__all__ diff --git a/python/paddle/fleet/utils/fs.py b/python/paddle/fleet/utils/fs.py new file mode 100644 index 0000000000000..3fec773f27318 --- /dev/null +++ b/python/paddle/fleet/utils/fs.py @@ -0,0 +1,382 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import sys +import subprocess +import multiprocessing +from datetime import datetime + +import re +import copy +import errno +import time +import logging +import six +import abc +import paddle.fluid as fluid +import functools + +from pathlib import PurePosixPath, Path +import shutil + +__all__ = [ + 'FS', 'LocalFS', 'HDFSClient', 'ExecuteError', 'FSTimeOut', + 'FSFileExistsError', 'FSFileNotExistsError' +] + + +class ExecuteError(Exception): + pass + + +class FSFileExistsError(Exception): + pass + + +class FSFileNotExistsError(Exception): + pass + + +class FSTimeOut(Exception): + pass + + +class FS(object): + @abc.abstractmethod + def ls_dir(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_file(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_dir(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def is_exist(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def upload(self, local_path, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def download(self, fs_path, local_path): + raise NotImplementedError + + @abc.abstractmethod + def mkdirs(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def delete(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def need_upload_download(self): + raise NotImplementedError + + @abc.abstractmethod + def rename(self, fs_src_path, fs_dst_path): + raise NotImplementedError + + @abc.abstractmethod + def mv(self, fs_src_path, fs_dst_path): + raise NotImplementedError + + @abc.abstractmethod + def upload_dir(self, local_dir, dest_dir): + raise NotImplementedError + + @abc.abstractmethod + def glob(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def stat(self, fs_path): + raise NotImplementedError + + @abc.abstractmethod + def walk(self, fs_path): + raise NotImplementedError + + +class LocalFS(FS): + def ls_dir(self, fs_path): + if not self.is_exist(fs_path): + return [], [] + + dirs = [] + files = [] + for f in os.listdir(fs_path): + if os.path.isdir(fs_path + "/" + f): + dirs.append(f) + else: + files.append(f) + + return dirs, files + + def mkdirs(self, fs_path): + assert not os.path.isfile(fs_path), "{} is already a file".format( + fs_path) + os.system("mkdir -p {}".format(fs_path)) + + def is_file(self, fs_path): + return os.path.isfile(fs_path) + + def is_dir(self, fs_path): + return os.path.isdir(fs_path) + + def is_exist(self, fs_path): + return os.path.exists(fs_path) + + def _rmr(self, fs_path): + shutil.rmtree(fs_path) + + def _rm(self, fs_path): + os.remove(fs_path) + + def delete(self, fs_path): + if not self.is_exist(fs_path): + return + + if os.path.isfile(fs_path): + return self._rm(fs_path) + + return self._rmr(fs_path) + + def rename(self, fs_src_path, fs_dst_path): + os.rename(fs_src_path, fs_dst_path) + + def need_upload_download(self): + return False + + def touch(self, fs_path): + return Path(fs_path).touch() + + def mv(self, src_path, dst_path): + if not self.is_exist(src_path): + raise FSFileNotExistsError + + if self.is_exist(dst_path): + raise FSFileExistsError + + return self.rename(src_path, dst_path) + + +"""HDFS Utils.""" + + +def _handle_errors(f): + def handler(*args, **kwargs): + start = time.time() + while True: + try: + return f(*args, **kwargs) + except ExecuteError as e: + o = args[0] + time_out = float(o._time_out) / 1000.0 + inter = float(o._sleep_inter) / 1000.0 + if time.time() - start >= time_out: + raise FSTimeOut + time.sleep(inter) + + return functools.wraps(f)(handler) + + +class HDFSClient(FS): + def __init__( + self, + hadoop_home, + configs, + time_out=5 * 60 * 1000, #ms + sleep_inter=1000): #ms + # Raise exception if JAVA_HOME not exists. + java_home = os.environ["JAVA_HOME"] + + self.pre_commands = [] + hadoop_bin = '%s/bin/hadoop' % hadoop_home + self.pre_commands.append(hadoop_bin) + dfs = 'fs' + self.pre_commands.append(dfs) + + if configs: + for k, v in six.iteritems(configs): + self.pre_commands.append('-D%s=%s' % (k, v)) + + self._time_out = time_out + self._sleep_inter = sleep_inter + self._base_cmd = " ".join(self.pre_commands) + self._bd_err_re = re.compile( + r'\s?responseErrorMsg\s?\:.*, errorCode\:\s?[0-9]+, path\:') + + def _run_cmd(self, cmd, redirect_stderr=False): + ret, output = fluid.core.shell_execute_cmd(cmd, 0, 0, redirect_stderr) + return int(ret), output.splitlines() + + @_handle_errors + def ls_dir(self, fs_path): + """ + list directory under fs_path, and only give the pure name, not include the fs_path + """ + if not self.is_exist(fs_path): + return [], [] + + cmd = "{} -ls {}".format(self._base_cmd, fs_path) + ret, lines = self._run_cmd(cmd) + + if ret != 0: + raise ExecuteError + + dirs = [] + files = [] + for line in lines: + arr = line.split() + if len(arr) != 8: + continue + + if fs_path not in arr[7]: + continue + + p = PurePosixPath(arr[7]) + if arr[0][0] == 'd': + dirs.append(p.name) + else: + files.append(p.name) + + return dirs, files + + def _test_match(self, lines): + for l in lines: + m = self._bd_err_re.match(l) + if m != None: + return m + + return None + + @_handle_errors + def is_dir(self, fs_path): + if not self.is_exist(fs_path): + return False + + cmd = "{} -test -d {}".format( + self._base_cmd, fs_path, redirect_stderr=True) + ret, lines = self._run_cmd(cmd) + if ret: + # other error + if self._test_match(lines) != None: + raise ExecuteError + + return False + + return True + + def is_file(self, fs_path): + if not self.is_exist(fs_path): + return False + + return not self.is_dir(fs_path) + + @_handle_errors + def is_exist(self, fs_path): + cmd = "{} -ls {} ".format(self._base_cmd, fs_path) + ret, out = self._run_cmd(cmd, redirect_stderr=True) + if ret != 0: + for l in out: + if "No such file or directory" in l: + return False + raise ExecuteError + + return True + + @_handle_errors + def upload(self, local_path, fs_path): + if self.is_exist(fs_path): + raise FSFileExistsError + + local = LocalFS() + if not local.is_exist(local_path): + raise FSFileNotExistsError + + cmd = "{} -put {} {}".format(self._base_cmd, local_path, fs_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def download(self, fs_path, local_path): + if self.is_exist(local_path): + raise FSFileExistsError + + if not self.is_exist(fs_path): + raise FSFileNotExistsError + + cmd = "{} -get {} {}".format(self._base_cmd, fs_path, local_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def mkdirs(self, fs_path): + if self.is_exist(fs_path): + return + + cmd = "{} -mkdir {}".format(self._base_cmd, fs_path) + ret, lines = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def mv(self, fs_src_path, fs_dst_path, test_exists=True): + if test_exists: + if not self.is_exist(fs_src_path): + raise FSFileNotExistsError + + if self.is_exist(fs_dst_path): + raise FSFileExistsError + + cmd = "{} -mv {} {}".format(self._base_cmd, fs_src_path, fs_dst_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def _rmr(self, fs_path): + cmd = "{} -rmr {}".format(self._base_cmd, fs_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + @_handle_errors + def _rm(self, fs_path): + cmd = "{} -rm {}".format(self._base_cmd, fs_path) + ret, _ = self._run_cmd(cmd) + if ret != 0: + raise ExecuteError + + def delete(self, fs_path): + if not self.is_exist(fs_path): + return + + is_dir = self.is_dir(fs_path) + if is_dir: + return self._rmr(fs_path) + + return self._rm(fs_path) + + def need_upload_download(self): + return True diff --git a/python/paddle/fleet/utils/http_server.py b/python/paddle/fleet/utils/http_server.py new file mode 100644 index 0000000000000..78e310b0a5a51 --- /dev/null +++ b/python/paddle/fleet/utils/http_server.py @@ -0,0 +1,195 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Http Server.""" + +import logging + +import six +# NOTE: HTTPServer has a different name in python2 and python3 +if six.PY2: + from BaseHTTPServer import HTTPServer + import SimpleHTTPServer +else: + from http.server import HTTPServer + import http.server as SimpleHTTPServer + +import time +import threading +import socket + + +def get_logger(name, level, fmt): + logger = logging.getLogger(name) + logger.setLevel(level) + handler = logging.FileHandler('http.log', mode='w') + formatter = logging.Formatter(fmt=fmt) + handler.setFormatter(formatter) + logger.addHandler(handler) + return logger + + +_http_server_logger = get_logger( + __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s') + + +class KVHandler(SimpleHTTPServer.SimpleHTTPRequestHandler): + """ + kv handler class for kv http server, + it defines the way to get/set kv in server. + """ + + def do_GET(self): + """ + get method for kv handler, get value according to key. + """ + log_str = "GET " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + with self.server.kv_lock: + value = self.server.kv.get(scope, {}).get(key) + if value is None: + log_str += ' , key not found: ' + key + self.send_status_code(404) + else: + log_str += ' , key found: ' + key + self.send_response(200) + self.send_header("Content-Length", str(len(value))) + self.end_headers() + self.wfile.write(value) + _http_server_logger.info(log_str) + + def do_PUT(self): + """ + put method for kv handler, set value according to key. + """ + log_str = "PUT " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + content_length = int(self.headers['Content-Length']) + try: + value = self.rfile.read(content_length) + except: + print("receive error invalid request") + self.send_status_code(404) + return + with self.server.kv_lock: + if self.server.kv.get(scope) is None: + self.server.kv[scope] = {} + self.server.kv[scope][key] = value + self.send_status_code(200) + _http_server_logger.info(log_str) + + def do_DELETE(self): + """ + delete method for kv handler, set value according to key. + """ + log_str = "DELETE " + self.address_string() + self.path + paths = self.path.split('/') + if len(paths) < 3: + print('len of request path must be 3: ' + self.path) + self.send_status_code(400) + return + _, scope, key = paths + with self.server.delete_kv_lock: + if self.server.delete_kv.get(scope) is None: + self.server.delete_kv[scope] = [] + self.server.delete_kv[scope].append(key) + self.send_status_code(200) + _http_server_logger.info(log_str) + + def log_message(self, format, *args): + """ + ignore all logging messages in kv handler. + """ + pass + + def send_status_code(self, code): + """ + send status code back to client. + """ + self.send_response(code) + self.send_header("Content-Length", 0) + self.end_headers() + + +class KVHTTPServer(HTTPServer, object): + """ + it is a http server storing kv pairs. + """ + + def __init__(self, port, handler): + """Init.""" + super(KVHTTPServer, self).__init__(('', port), handler) + self.delete_kv_lock = threading.Lock() + self.delete_kv = {} + self.kv_lock = threading.Lock() + self.kv = {} + + def get_deleted_size(self, key): + """ + get deleted size in key. + """ + ret = 0 + with self.delete_kv_lock: + ret = self.delete_kv.get(key, 0) + return ret + + +class KVServer: + """ + it is a server storing kv pairs, has a http server inside. + """ + + def __init__(self, port, size={}): + """Init.""" + self.http_server = KVHTTPServer(port, KVHandler) + self.listen_thread = None + self.size = {} + + def start(self): + """ + start server until user calls stop to let it quit. + """ + self.listen_thread = threading.Thread( + target=lambda: self.http_server.serve_forever()) + self.listen_thread.start() + + def stop(self): + """ + stop server and clear its resources. + """ + self.http_server.shutdown() + self.listen_thread.join() + self.http_server.server_close() + + def shoud_stop(self): + """ + return whether the server should stop. + + Returns: + ret(bool): whether the server should stop + """ + for key in self.size: + s = self.http_server.get_deleted_size(key) + if s != self.size.get(key, 0): + return False + return True diff --git a/python/paddle/fluid/incubate/fleet/base/fleet_base.py b/python/paddle/fluid/incubate/fleet/base/fleet_base.py index 9be1fe92d1d0c..f236a3e98c61b 100644 --- a/python/paddle/fluid/incubate/fleet/base/fleet_base.py +++ b/python/paddle/fluid/incubate/fleet/base/fleet_base.py @@ -21,7 +21,7 @@ from paddle.fluid.optimizer import SGD from paddle.fluid.incubate.fleet.base.mode import Mode -from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase +from paddle.fleet.base.role_maker import RoleMakerBase from paddle.fluid.contrib.mixed_precision.decorator import OptimizerWithMixedPrecision from . import mode @@ -209,7 +209,10 @@ def init(self, role_maker=None): self._executor = Executor(fluid.CPUPlace()) if role_maker and not isinstance(role_maker, RoleMakerBase): - raise TypeError("role_maker must be an instance of RoleMakerBase") + from paddle.fluid.incubate.fleet.base.role_maker import RoleMakerBase as RoleMakerBaseIncubate + if role_maker and not isinstance(role_maker, RoleMakerBaseIncubate): + raise TypeError( + "role_maker must be an instance of RoleMakerBase") self._role_maker = role_maker self._role_maker.generate_role() diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index d73b9511b76ed..686844fea76c0 100755 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -345,7 +345,6 @@ if(WITH_DISTRIBUTE) # FIXME(typhoonzero): add these tests back list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transformer") list(REMOVE_ITEM DIST_TEST_OPS "test_dist_transpiler") - list(REMOVE_ITEM DIST_TEST_OPS "test_dist_fleet_ctr") #not need list(REMOVE_ITEM DIST_TEST_OPS "test_dist_base") diff --git a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py index 56ca3105dea79..033bc38500521 100644 --- a/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/dist_fleet_ctr.py @@ -28,6 +28,7 @@ import ctr_dataset_reader from test_dist_fleet_base import runtime_main, FleetDistRunnerBase +from paddle.fleet.base.util_factory import fleet_util # Fix seed for test fluid.default_startup_program().random_seed = 1 @@ -181,8 +182,14 @@ def do_pyreader_training(self, fleet): loss_val = exe.run(program=compiled_prog, fetch_list=[self.avg_cost.name]) loss_val = np.mean(loss_val) - print("TRAIN ---> pass: {} loss: {}\n".format(epoch_id, - loss_val)) + reduce_output = fleet_util.all_reduce( + np.array(loss_val), mode="sum") + loss_all_trainer = fleet_util.all_gather(float(loss_val)) + loss_val = float(reduce_output) / len(loss_all_trainer) + message = "TRAIN ---> pass: {} loss: {}\n".format(epoch_id, + loss_val) + fleet_util.print_on_rank(message, 0) + pass_time = time.time() - pass_start except fluid.core.EOFException: self.reader.reset() diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py index 16f0fc0a35e61..8b2f7118ea766 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_base.py @@ -21,6 +21,9 @@ import sys import subprocess +import six +import shutil +import numpy as np import argparse from contextlib import closing import socket @@ -29,7 +32,8 @@ import unittest import paddle.fluid as fluid -import paddle.fluid.incubate.fleet.base.role_maker as role_maker +import paddle.fleet.base.role_maker as role_maker +from paddle.fleet.base.util_factory import fleet_util from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler.distributed_strategy import StrategyFactory @@ -48,18 +52,26 @@ class FleetDistRunnerBase(object): """ def build_role(self, args): + if args.role.upper() == "PSERVER": role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=True, + path=args.gloo_path, current_id=args.current_id, role=role_maker.Role.SERVER, - worker_num=args.trainers, + worker_endpoints=args.trainer_endpoints.split(","), server_endpoints=args.endpoints.split(",")) else: role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=True, + path=args.gloo_path, current_id=args.current_id, role=role_maker.Role.WORKER, - worker_num=args.trainers, + worker_endpoints=args.trainer_endpoints.split(","), server_endpoints=args.endpoints.split(",")) + self.role = role return role def build_strategy(self, args): @@ -114,26 +126,13 @@ def build_optimizer(self, avg_cost, strategy): optimizer.minimize(avg_cost) def run_pserver(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) - fleet.init_server() fleet.run_server() def run_dataset_trainer(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) out = self.do_dataset_training(fleet) def run_pyreader_trainer(self, args): - fleet.init(self.build_role(args)) - strategy = self.build_strategy(args) - avg_cost = self.net(args) - self.build_optimizer(avg_cost, strategy) out = self.do_pyreader_training(fleet) def net(self, args, batch_size=4, lr=0.01): @@ -173,10 +172,14 @@ def setUp(self): print("set begin_port:", DIST_UT_PORT) self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( DIST_UT_PORT, DIST_UT_PORT + 1) - DIST_UT_PORT += 2 + self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + DIST_UT_PORT + 2, DIST_UT_PORT + 3) + DIST_UT_PORT += 4 else: self._ps_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( self._find_free_port(), self._find_free_port()) + self._tr_endpoints = "127.0.0.1:%s,127.0.0.1:%s" % ( + self._find_free_port(), self._find_free_port()) self._python_interp = sys.executable self._geo_sgd_need_push_nums = 5 @@ -236,18 +239,22 @@ def _start_trainer(self, cmd, required_envs): def _run_cluster(self, model, envs): env = {'GRAD_CLIP': str(self._grad_clip_mode)} python_path = self._python_interp + gloo_path = tempfile.mkdtemp() + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': envs['COVERAGE_FILE'] = os.getenv('COVERAGE_FILE', '') python_path += " -m coverage run --branch -p" env.update(envs) - tr_cmd = "{0} {1} --role trainer --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( - python_path, model, self._ps_endpoints, self._trainers, self._mode, - self._geo_sgd_need_push_nums, self._reader) + tr_cmd = "{0} {1} --role trainer --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format( + python_path, model, self._ps_endpoints, self._tr_endpoints, + self._trainers, self._mode, self._geo_sgd_need_push_nums, + self._reader, gloo_path) - ps_cmd = "{0} {1} --role pserver --endpoints {2} --current_id {{}} --trainers {3} --mode {4} --geo_sgd_need_push_nums {5} --reader {6}".format( - python_path, model, self._ps_endpoints, self._trainers, self._mode, - self._geo_sgd_need_push_nums, self._reader) + ps_cmd = "{0} {1} --role pserver --endpoints {2} --trainer_endpoints {3} --current_id {{}} --trainers {4} --mode {5} --geo_sgd_need_push_nums {6} --reader {7} --gloo_path {8}".format( + python_path, model, self._ps_endpoints, self._tr_endpoints, + self._trainers, self._mode, self._geo_sgd_need_push_nums, + self._reader, gloo_path) # Run dist train to compare with local results ps0, ps1, ps0_pipe, ps1_pipe = self._start_pserver(ps_cmd, env) @@ -284,6 +291,7 @@ def _run_cluster(self, model, envs): ps0.terminate() ps1.terminate() + shutil.rmtree(gloo_path) return 0, 0 def check_with_place(self, @@ -313,6 +321,9 @@ def runtime_main(test_class): parser.add_argument( '--role', type=str, required=True, choices=['pserver', 'trainer']) parser.add_argument('--endpoints', type=str, required=False, default="") + parser.add_argument( + '--trainer_endpoints', type=str, required=False, default="") + parser.add_argument('--gloo_path', type=str, required=False, default="") parser.add_argument('--current_id', type=int, required=False, default=0) parser.add_argument('--trainers', type=int, required=False, default=1) parser.add_argument('--mode', type=str, required=False, default='geo') @@ -322,6 +333,13 @@ def runtime_main(test_class): args = parser.parse_args() model = test_class() + role = model.build_role(args) + fleet.init(role) + strategy = model.build_strategy(args) + avg_cost = model.net(args) + model.build_optimizer(avg_cost, strategy) + fleet_util._set_strategy(strategy) + fleet_util._set_role_maker(role) if args.role == "pserver": model.run_pserver(args) else: diff --git a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py index 5fc37335b2153..18629c4f996a6 100644 --- a/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py +++ b/python/paddle/fluid/tests/unittests/test_dist_fleet_ctr.py @@ -22,7 +22,7 @@ class TestDistMnistSync2x2(TestFleetBase): def _setup_config(self): - self._mode = "sync" + self._mode = "async" self._reader = "pyreader" def check_with_place(self, diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py index dd5cd715ecd1e..a91f6cbd69e18 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_4.py @@ -40,10 +40,9 @@ def test_pslib_1(self): from paddle.fluid.incubate.fleet.parameter_server.pslib import PSLib from paddle.fluid.incubate.fleet.base.role_maker import \ GeneralRoleMaker - from paddle.fluid.incubate.fleet.utils.http_server import KVHandler - from paddle.fluid.incubate.fleet.utils.http_server import KVServer - from paddle.fluid.incubate.fleet.utils.http_server import \ - KVHTTPServer + from paddle.fleet.utils import KVHandler + from paddle.fleet.utils import KVServer + from paddle.fleet.utils import KVHTTPServer except: print("warning: no fleet, skip test_pslib_4") return diff --git a/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py new file mode 100644 index 0000000000000..659cc34b54958 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fleet_rolemaker_new.py @@ -0,0 +1,171 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""Test cloud role maker.""" + +from __future__ import print_function +import os +import unittest +import paddle.fleet.base.role_maker as role_maker + + +class TestRoleMakerBase(unittest.TestCase): + """ + Test cases for RoleMakerBase + """ + + def test_rolemaker_base(self): + role = role_maker.RoleMakerBase() + self.assertRaises(Exception, role.is_worker) + self.assertRaises(Exception, role.is_server) + self.assertRaises(Exception, role.is_first_worker) + self.assertRaises(Exception, role.worker_num) + self.assertRaises(Exception, role.server_num) + self.assertRaises(Exception, role.worker_index) + self.assertRaises(Exception, role.server_index) + self.assertRaises(Exception, role.role_id) + + trainer_endpoints = role.get_trainer_endpoints() + self.assertTrue(len(trainer_endpoints) == 0) + pserver_endpoints = role.get_pserver_endpoints() + self.assertTrue(len(pserver_endpoints) == 0) + + print(role.to_string()) + self.assertTrue(role._all_gather(role._node_type_comm, 1) is None) + self.assertTrue(role._all_reduce(role._node_type_comm, 1) is None) + role._barrier(role._node_type_comm) + + +class TestCloudRoleMaker(unittest.TestCase): + """ + Test cases for PaddleCloudRoleMaker. + """ + + def setUp(self): + """Set up, set envs.""" + os.environ["PADDLE_TRAINERS_NUM"] = "2" + os.environ[ + "PADDLE_PSERVERS_IP_PORT_LIST"] = "127.0.0.1:36001,127.0.0.2:36001" + os.environ[ + "PADDLE_TRAINER_ENDPOINTS"] = "127.0.0.1:36001,127.0.0.2:36001" + os.environ["POD_IP"] = "127.0.0.1" + + def test_tr_rolemaker(self): + """Test tr rolenamer.""" + os.environ["TRAINING_ROLE"] = "TRAINER" + os.environ["PADDLE_TRAINER_ID"] = "0" + + try: + import netifaces + except: + print("warning: no netifaces, skip test_tr_rolemaker") + return + + ro = role_maker.PaddleCloudRoleMaker( + is_collective=False, init_gloo=False) + self.assertTrue(ro.is_worker()) + self.assertFalse(ro.is_server()) + self.assertEqual(ro.worker_num(), 2) + self.assertTrue(ro.is_first_worker()) + worker_endpoints = ro.get_trainer_endpoints() + self.assertEqual(worker_endpoints[0], '127.0.0.1:36001') + self.assertEqual(ro.role_id(), 0) + + def test_tr_rolemaker_collective(self): + ro = role_maker.PaddleCloudRoleMaker(is_collective=True) + self.assertEqual(ro.worker_num(), 2) + + def test_ps_rolemaker(self): + """Test ps rolemaker.""" + os.environ["TRAINING_ROLE"] = "PSERVER" + os.environ["POD_IP"] = "127.0.0.1" + os.environ["PADDLE_PORT"] = "36001" + + try: + import netifaces + except: + print("warning: no netifaces, skip test_ps_rolemaker") + return + + ro = role_maker.PaddleCloudRoleMaker( + is_collective=False, init_gloo=False) + self.assertEqual(ro.server_index(), 0) + self.assertFalse(ro.is_worker()) + self.assertTrue(ro.is_server()) + self.assertEqual(ro.server_num(), 2) + pserver_endpoints = ro.get_pserver_endpoints() + self.assertEqual(pserver_endpoints[0], '127.0.0.1:36001') + self.assertTrue(ro._all_gather(ro._all_comm, 1) is None) + self.assertTrue(ro._all_reduce(ro._all_comm, 1) is None) + + def test_traing_role(self): + """Test training role.""" + os.environ["TRAINING_ROLE"] = "TEST" + try: + import netifaces + except: + print("warning: no netifaces, skip test_training_role") + return + + ro = role_maker.PaddleCloudRoleMaker(is_collective=False) + self.assertRaises(ValueError, ro.generate_role) + + +class TestUserDefinedRoleMaker(unittest.TestCase): + """ + Test cases for UserDefinedRoleMaker. + """ + + def setUp(self): + pass + + def test_ps_rolemaker(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_ps_rolemaker") + return + + ro = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + server_endpoints="127.0.0.1:36001,127.0.0.1:36001", + role=role_maker.Role.SERVER, + current_id=0, + worker_num=2) + self.assertEqual(ro.server_num(), 2) + ro.generate_role() + self.assertTrue(ro.is_server()) + self.assertEqual(ro.role_id(), 0) + + def test_tr_rolemaker(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_tr_rolemaker") + return + + ro = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + server_endpoints="127.0.0.1:36001,127.0.0.1:36001", + role=role_maker.Role.WORKER, + current_id=0, + worker_num=2) + self.assertIn("127.0.0.1:36001", ro.get_pserver_endpoints()) + self.assertTrue(ro.is_worker()) + self.assertEqual(ro.role_id(), 0) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_util.py b/python/paddle/fluid/tests/unittests/test_fleet_util.py index 427e077416e97..e52cb5f920c2e 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_util.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_util.py @@ -12,12 +12,27 @@ # See the License for the specific language governing permissions and # limitations under the License. -import unittest +from __future__ import print_function import paddle +import paddle.fluid as fluid +import unittest +import numpy as np +import tarfile +import tempfile import os +import sys +from paddle.dataset.common import download, DATA_HOME +from paddle.fleet.base.util_factory import fleet_util +import paddle.fleet.base.role_maker as role_maker class TestFleetUtil(unittest.TestCase): + proto_data_url = "https://fleet.bj.bcebos.com/fleet_util_data.tgz" + proto_data_md5 = "59b7f12fd9dc24b64ae8e4629523a92a" + module_name = "fleet_util_data" + pruned_dir = os.path.join("fleet_util_data", "pruned_model") + train_dir = os.path.join("fleet_util_data", "train_program") + def test_util_base(self): import paddle.fleet as fleet util = fleet.UtilBase() @@ -65,6 +80,262 @@ def get_user_id(self): user_id = fleet.util.get_user_id() self.assertEqual(user_id, 10) + def test_fs(self): + from paddle.fleet.utils import LocalFS + fs = LocalFS() + dirs, files = fs.ls_dir("test_tmp") + dirs, files = fs.ls_dir("./") + self.assertFalse(fs.need_upload_download()) + fleet_util.set_file_system(fs) + + def test_barrier(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_barrier") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_barrier", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.SERVER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + fleet_util.barrier("worker") + + def test_all_reduce(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_all_reduce") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.WORKER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + output = fleet_util.all_reduce(1, "sum", comm_world="server") + print(output) + + # self.assertEqual(output, 1) + + def test_all_gather(self): + try: + import netifaces + except: + print("warning: no netifaces, skip test_all_gather") + return + + gloo = fluid.core.Gloo() + gloo.set_rank(0) + gloo.set_size(1) + gloo.set_prefix("123") + gloo.set_iface("lo") + gloo.set_hdfs_store("./tmp_test_fleet_reduce", "", "") + gloo.init() + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.SERVER, + worker_endpoints=["127.0.0.1:6003"], + server_endpoints=["127.0.0.1:6001"]) + role._node_type_comm = gloo + role._all_comm = gloo + role._role_is_generated = True + fleet_util._set_role_maker(role) + + output = fleet_util.all_gather(1, comm_world="all") + print(output) + # self.assertTrue(len(output) == 1 and output[0] == 1) + self.assertRaises(Exception, fleet_util.all_gather, 1, "test") + + def download_files(self): + path = download(self.proto_data_url, self.module_name, + self.proto_data_md5) + print('data is downloaded at ' + path) + tar = tarfile.open(path) + unzip_folder = tempfile.mkdtemp() + tar.extractall(unzip_folder) + return unzip_folder + + def test_get_file_shard(self): + self.assertRaises(Exception, fleet_util.get_file_shard, "files") + try: + import netifaces + except: + print("warning: no netifaces, skip test_get_file_shard") + return + + role = role_maker.UserDefinedRoleMaker( + is_collective=False, + init_gloo=False, + current_id=0, + role=role_maker.Role.WORKER, + worker_endpoints=["127.0.0.1:6003", "127.0.0.1:6004"], + server_endpoints=["127.0.0.1:6001", "127.0.0.1:6002"]) + fleet_util._set_role_maker(role) + files = fleet_util.get_file_shard(["1", "2", "3"]) + self.assertTrue(len(files) == 2 and "1" in files and "2" in files) + + def test_program_type_trans(self): + data_dir = self.download_files() + program_dir = os.path.join(data_dir, self.pruned_dir) + text_program = "pruned_main_program.pbtxt" + binary_program = "pruned_main_program.bin" + text_to_binary = fleet_util._program_type_trans(program_dir, + text_program, True) + binary_to_text = fleet_util._program_type_trans(program_dir, + binary_program, False) + self.assertTrue( + os.path.exists(os.path.join(program_dir, text_to_binary))) + self.assertTrue( + os.path.exists(os.path.join(program_dir, binary_to_text))) + + def test_prams_check(self): + data_dir = self.download_files() + + class config: + pass + + feed_config = config() + feed_config.feeded_vars_names = ['concat_1.tmp_0', 'concat_2.tmp_0'] + feed_config.feeded_vars_dims = [682, 1199] + feed_config.feeded_vars_types = [np.float32, np.float32] + feed_config.feeded_vars_filelist = [ + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_1")), + os.path.join(data_dir, os.path.join(self.pruned_dir, "concat_2")) + ] + + fetch_config = config() + fetch_config.fetch_vars_names = ['similarity_norm.tmp_0'] + + conf = config() + conf.batch_size = 1 + conf.feed_config = feed_config + conf.fetch_config = fetch_config + conf.dump_model_dir = os.path.join(data_dir, self.pruned_dir) + conf.dump_program_filename = "pruned_main_program.pbtxt" + conf.is_text_dump_program = True + conf.save_params_filename = None + + # test saved var's shape + conf.dump_program_filename = "pruned_main_program.save_var_shape_not_match" + + self.assertRaises(Exception, fleet_util._params_check) + + # test program.proto without feed_op and fetch_op + conf.dump_program_filename = "pruned_main_program.no_feed_fetch" + results = fleet_util._params_check(conf) + self.assertTrue(len(results) == 1) + np.testing.assert_array_almost_equal( + results[0], np.array( + [[3.0590223e-07]], dtype=np.float32)) + + # test feed_var's shape + conf.dump_program_filename = "pruned_main_program.feed_var_shape_not_match" + self.assertRaises(Exception, fleet_util._params_check) + + # test correct case with feed_vars_filelist + conf.dump_program_filename = "pruned_main_program.pbtxt" + results = fleet_util._params_check(conf) + self.assertTrue(len(results) == 1) + np.testing.assert_array_almost_equal( + results[0], np.array( + [[3.0590223e-07]], dtype=np.float32)) + + # test correct case without feed_vars_filelist + conf.feed_config.feeded_vars_filelist = None + # test feed var with lod_level >= 2 + conf.dump_program_filename = "pruned_main_program.feed_lod2" + self.assertRaises(Exception, fleet_util._params_check) + + conf.dump_program_filename = "pruned_main_program.pbtxt" + results = fleet_util._params_check(conf) + self.assertTrue(len(results) == 1) + + def test_proto_check(self): + data_dir = self.download_files() + + class config: + pass + + conf = config() + conf.train_prog_path = os.path.join( + data_dir, os.path.join(self.train_dir, "join_main_program.pbtxt")) + conf.is_text_train_program = True + + # test not match + conf.pruned_prog_path = os.path.join( + data_dir, + os.path.join(self.pruned_dir, + "pruned_main_program.save_var_shape_not_match")) + conf.is_text_pruned_program = True + conf.draw = False + res = fleet_util._proto_check(conf) + self.assertFalse(res) + + # test match + conf.pruned_prog_path = os.path.join( + data_dir, + os.path.join(self.pruned_dir, "pruned_main_program.pbtxt")) + if sys.platform == 'win32' or sys.platform == 'sys.platform': + conf.draw = False + else: + conf.draw = True + conf.draw_out_name = "pruned_check" + res = fleet_util._proto_check(conf) + self.assertTrue(res) + + def test_visualize(self): + if sys.platform == 'win32' or sys.platform == 'sys.platform': + pass + else: + data_dir = self.download_files() + program_path = os.path.join( + data_dir, + os.path.join(self.train_dir, "join_main_program.pbtxt")) + is_text = True + program = fleet_util._load_program(program_path, is_text) + output_dir = os.path.join(data_dir, self.train_dir) + output_filename = "draw_prog" + fleet_util._visualize_graphviz(program, output_dir, output_filename) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename + ".dot"))) + self.assertTrue( + os.path.exists( + os.path.join(output_dir, output_filename + ".pdf"))) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fs_interface.py b/python/paddle/fluid/tests/unittests/test_fs_interface.py index 0d87b94538f05..7f780bd44f7e2 100644 --- a/python/paddle/fluid/tests/unittests/test_fs_interface.py +++ b/python/paddle/fluid/tests/unittests/test_fs_interface.py @@ -20,9 +20,7 @@ import sys import inspect -from paddle.fluid.incubate.fleet.utils.fs import LocalFS, FS -from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient -from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError +from paddle.fleet.utils import LocalFS, FS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError class FSTest(unittest.TestCase): diff --git a/python/paddle/fluid/tests/unittests/test_hdfs.py b/python/paddle/fluid/tests/unittests/test_hdfs.py index 9826542cee373..80c7fd4ad57d1 100644 --- a/python/paddle/fluid/tests/unittests/test_hdfs.py +++ b/python/paddle/fluid/tests/unittests/test_hdfs.py @@ -19,9 +19,7 @@ import os import sys -from paddle.fluid.incubate.fleet.utils.fs import LocalFS -from paddle.fluid.incubate.fleet.utils.hdfs import HDFSClient -from paddle.fluid.incubate.fleet.utils.hdfs import FSTimeOut, FSFileExistsError, FSFileNotExistsError +from paddle.fleet.utils import LocalFS, HDFSClient, FSTimeOut, FSFileExistsError, FSFileNotExistsError java_home = os.environ["JAVA_HOME"] diff --git a/python/requirements.txt b/python/requirements.txt index 5e081f5e85b6e..13a1c9a9d638d 100644 --- a/python/requirements.txt +++ b/python/requirements.txt @@ -21,3 +21,4 @@ prettytable objgraph astor pathlib +netifaces diff --git a/python/setup.py.in b/python/setup.py.in index df200da2cfc5b..72819a7b9eed3 100644 --- a/python/setup.py.in +++ b/python/setup.py.in @@ -152,6 +152,7 @@ packages=['paddle', 'paddle.fleet.dataset', 'paddle.fleet.metrics', 'paddle.fleet.proto', + 'paddle.fleet.utils', 'paddle.framework', 'paddle.fluid', 'paddle.fluid.dygraph',