From c403ab11ab714c16255142ec69005eadb1e3f9ff Mon Sep 17 00:00:00 2001 From: Stephanie Wang Date: Wed, 28 Dec 2016 14:17:29 -0800 Subject: [PATCH] Allow ray.init to take in address information about existing services. (#161) * Refactor ray.init and ray.services to allow processes that are already running * Fix indexing error * Address Robert's comments --- lib/python/ray/services.py | 284 +++++++++++++++++++++++++------------ lib/python/ray/worker.py | 148 ++++++++++++++----- test/array_test.py | 2 +- test/stress_tests.py | 9 +- 4 files changed, 319 insertions(+), 124 deletions(-) diff --git a/lib/python/ray/services.py b/lib/python/ray/services.py index be13720a3c32f..ec28d753133e8 100644 --- a/lib/python/ray/services.py +++ b/lib/python/ray/services.py @@ -12,6 +12,7 @@ import subprocess import sys import time +from collections import namedtuple # Ray modules import photon @@ -28,9 +29,25 @@ RUN_PLASMA_MANAGER_PROFILER = False RUN_PLASMA_STORE_PROFILER = False +# ObjectStoreAddress tuples contain all information necessary to connect to an +# object store. The fields are: +# - name: The socket name for the object store +# - manager_name: The socket name for the object store manager +# - manager_port: The Internet port that the object store manager listens on +ObjectStoreAddress = namedtuple("ObjectStoreAddress", ["name", + "manager_name", + "manager_port"]) + def address(host, port): return host + ":" + str(port) +def get_port(address): + try: + port = int(address.split(":")[1]) + except: + raise Exception("Unable to parse port from address {}".format(address)) + return port + def new_port(): return random.randint(10000, 65535) @@ -120,7 +137,7 @@ def wait_for_redis_to_start(redis_host, redis_port, num_retries=5): if counter == num_retries: raise Exception("Unable to connect to Redis. If the Redis instance is on a different machine, check that your firewall is configured properly.") -def start_redis(num_retries=20, cleanup=True, redirect_output=False): +def start_redis(node_ip_address, num_retries=20, cleanup=True, redirect_output=False): """Start a Redis server. Args: @@ -171,7 +188,8 @@ def start_redis(num_retries=20, cleanup=True, redirect_output=False): # Configure Redis to not run in protected mode so that processes on other # hosts can connect to it. TODO(rkn): Do this in a more secure way. redis_client.config_set("protected-mode", "no") - return port + redis_address = address(node_ip_address, port) + return redis_address def start_global_scheduler(redis_address, cleanup=True, redirect_output=False): """Start a global scheduler process. @@ -212,7 +230,7 @@ def start_local_scheduler(redis_address, node_ip_address, plasma_store_name, pla all_processes.append(p) return local_scheduler_name -def start_objstore(node_ip_address, redis_address, cleanup=True, redirect_output=False): +def start_objstore(node_ip_address, redis_address, cleanup=True, redirect_output=False, objstore_memory=None): """This method starts an object store process. Args: @@ -228,24 +246,26 @@ def start_objstore(node_ip_address, redis_address, cleanup=True, redirect_output A tuple of the Plasma store socket name, the Plasma manager socket name, and the plasma manager port. """ - # Compute a fraction of the system memory for the Plasma store to use. - system_memory = psutil.virtual_memory().total - if sys.platform == "linux" or sys.platform == "linux2": - # On linux we use /dev/shm, its size is half the size of the physical - # memory. To not overflow it, we set the plasma memory limit to 0.4 times - # the size of the physical memory. - plasma_store_memory = int(system_memory * 0.4) - else: - plasma_store_memory = int(system_memory * 0.75) + if objstore_memory is None: + # Compute a fraction of the system memory for the Plasma store to use. + system_memory = psutil.virtual_memory().total + if sys.platform == "linux" or sys.platform == "linux2": + # On linux we use /dev/shm, its size is half the size of the physical + # memory. To not overflow it, we set the plasma memory limit to 0.4 times + # the size of the physical memory. + objstore_memory = int(system_memory * 0.4) + else: + objstore_memory = int(system_memory * 0.75) # Start the Plasma store. - plasma_store_name, p1 = plasma.start_plasma_store(plasma_store_memory=plasma_store_memory, use_profiler=RUN_PLASMA_STORE_PROFILER, redirect_output=redirect_output) + plasma_store_name, p1 = plasma.start_plasma_store(plasma_store_memory=objstore_memory, use_profiler=RUN_PLASMA_STORE_PROFILER, redirect_output=redirect_output) # Start the plasma manager. plasma_manager_name, p2, plasma_manager_port = plasma.start_plasma_manager(plasma_store_name, redis_address, node_ip_address=node_ip_address, run_profiler=RUN_PLASMA_MANAGER_PROFILER, redirect_output=redirect_output) if cleanup: all_processes.append(p1) all_processes.append(p2) - return plasma_store_name, plasma_manager_name, plasma_manager_port + return ObjectStoreAddress(plasma_store_name, plasma_manager_name, + plasma_manager_port) def start_worker(node_ip_address, object_store_name, object_store_manager_name, local_scheduler_name, redis_address, worker_path, cleanup=True, redirect_output=False): """This method starts a worker process. @@ -300,18 +320,28 @@ def start_webui(redis_port, cleanup=True, redirect_output=False): if cleanup: all_processes.append(p) -def start_ray_node(node_ip_address, redis_address, num_workers=0, num_local_schedulers=1, worker_path=None, cleanup=True, redirect_output=False): - """Start the Ray processes for a single node. - - This assumes that the Ray processes on some master node have already been - started. +def start_ray_processes(address_info=None, + node_ip_address="127.0.0.1", + num_workers=0, + num_local_schedulers=1, + worker_path=None, + cleanup=True, + redirect_output=False, + include_global_scheduler=False, + include_webui=True): + """Helper method to start Ray processes. Args: + address_info (dict): A dictionary with address information for processes + that have already been started. If provided, address_info will be + modified to include processes that are newly started. node_ip_address (str): The IP address of this node. - redis_address (str): The address of the Redis server. num_workers (int): The number of workers to start. - num_local_schedulers (int): The number of local schedulers to start. This is - also the number of plasma stores and plasma managers to start. + num_local_schedulers (int): The total number of local schedulers required. + This is also the total number of object stores required. This method will + start new instances of local schedulers and object stores until there are + num_local_schedulers existing instances of each, including ones already + registered with the given address_info. worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here will be @@ -319,46 +349,113 @@ def start_ray_node(node_ip_address, redis_address, num_workers=0, num_local_sche method exits. redirect_output (bool): True if stdout and stderr should be redirected to /dev/null. + include_global_scheduler (bool): If include_global_scheduler is True, then + start a global scheduler process. + include_webui (bool): If include_webui is True, then start a Web UI + process. + + Returns: + A dictionary of the address information for the processes that were + started. """ + if address_info is None: + address_info = {} + address_info["node_ip_address"] = node_ip_address + if worker_path is None: worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py") - object_store_names = [] - object_store_manager_names = [] - local_scheduler_names = [] - for _ in range(num_local_schedulers): + + # Start Redis if there isn't already an instance running. TODO(rkn): We are + # suppressing the output of Redis because on Linux it prints a bunch of + # warning messages when it starts up. Instead of suppressing the output, we + # should address the warnings. + redis_address = address_info.get("redis_address") + if redis_address is None: + redis_address = start_redis(node_ip_address, cleanup=cleanup, + redirect_output=redirect_output) + address_info["redis_address"] = redis_address + time.sleep(0.1) + redis_port = get_port(redis_address) + + # Start the global scheduler, if necessary. + if include_global_scheduler: + start_global_scheduler(redis_address, cleanup=cleanup, + redirect_output=redirect_output) + + # Initialize with existing services. + if "object_store_addresses" not in address_info: + address_info["object_store_addresses"] = [] + object_store_addresses = address_info["object_store_addresses"] + if "local_scheduler_socket_names" not in address_info: + address_info["local_scheduler_socket_names"] = [] + local_scheduler_socket_names = address_info["local_scheduler_socket_names"] + + # Start any object stores that do not yet exist. + for _ in range(num_local_schedulers - len(object_store_addresses)): # Start Plasma. - object_store_name, object_store_manager_name, object_store_manager_port = start_objstore(node_ip_address, redis_address, cleanup=cleanup, redirect_output=redirect_output) - object_store_names.append(object_store_name) - object_store_manager_names.append(object_store_manager_name) + object_store_address = start_objstore(node_ip_address, redis_address, + cleanup=cleanup, + redirect_output=redirect_output) + object_store_addresses.append(object_store_address) time.sleep(0.1) + + # Start any local schedulers that do not yet exist. + for i in range(len(local_scheduler_socket_names), num_local_schedulers): + # Connect the local scheduler to the object store at the same index. + object_store_address = object_store_addresses[i] + plasma_address = "{}:{}".format(node_ip_address, + object_store_address.manager_port) # Start the local scheduler. - plasma_address = "{}:{}".format(node_ip_address, object_store_manager_port) - local_scheduler_name = start_local_scheduler(redis_address, node_ip_address, object_store_name, object_store_manager_name, plasma_address=plasma_address, cleanup=cleanup, redirect_output=redirect_output) - local_scheduler_names.append(local_scheduler_name) + local_scheduler_name = start_local_scheduler(redis_address, + node_ip_address, + object_store_address.name, + object_store_address.manager_name, + plasma_address=plasma_address, + cleanup=cleanup, + redirect_output=redirect_output) + local_scheduler_socket_names.append(local_scheduler_name) time.sleep(0.1) - # Aggregate the address information together. - address_info = {"node_ip_address": node_ip_address, - "object_store_names": object_store_names, - "object_store_manager_names": object_store_manager_names, - "local_scheduler_names": local_scheduler_names} + + # Make sure that we have exactly num_local_schedulers instances of object + # stores and local schedulers. + assert len(object_store_addresses) == num_local_schedulers + assert len(local_scheduler_socket_names) == num_local_schedulers + # Start the workers. for i in range(num_workers): - start_worker(address_info["node_ip_address"], - address_info["object_store_names"][i % num_local_schedulers], - address_info["object_store_manager_names"][i % num_local_schedulers], - address_info["local_scheduler_names"][i % num_local_schedulers], + object_store_address = object_store_addresses[i % num_local_schedulers] + local_scheduler_name = local_scheduler_socket_names[i % num_local_schedulers] + start_worker(node_ip_address, + object_store_address.name, + object_store_address.manager_name, + local_scheduler_name, redis_address, worker_path, cleanup=cleanup, redirect_output=redirect_output) + + # Start the web UI, if necessary. + if include_webui: + start_webui(redis_port, cleanup=cleanup, redirect_output=redirect_output) + # Return the addresses of the relevant processes. return address_info -def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, num_local_schedulers=1, worker_path=None, cleanup=True, redirect_output=False): - """Start Ray in local mode. +def start_ray_node(node_ip_address, + redis_address, + num_workers=0, + num_local_schedulers=1, + worker_path=None, + cleanup=True, + redirect_output=False): + """Start the Ray processes for a single node. + + This assumes that the Ray processes on some master node have already been + started. Args: node_ip_address (str): The IP address of this node. + redis_address (str): The address of the Redis server. num_workers (int): The number of workers to start. num_local_schedulers (int): The number of local schedulers to start. This is also the number of plasma stores and plasma managers to start. @@ -371,49 +468,60 @@ def start_ray_local(node_ip_address="127.0.0.1", num_workers=0, num_local_schedu /dev/null. Returns: - This returns a dictionary of the address information for the processes that - were started. + A dictionary of the address information for the processes that were + started. """ - if worker_path is None: - worker_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), "workers/default_worker.py") - # Start Redis. TODO(rkn): We are suppressing the output of Redis because on - # Linux it prints a bunch of warning messages when it starts up. Instead of - # suppressing the output, we should address the warnings. - redis_port = start_redis(cleanup=cleanup, redirect_output=True) - redis_address = address(node_ip_address, redis_port) - time.sleep(0.1) - # Start the global scheduler. - start_global_scheduler(redis_address, cleanup=cleanup, redirect_output=redirect_output) - object_store_names = [] - object_store_manager_names = [] - local_scheduler_names = [] - for _ in range(num_local_schedulers): - # Start Plasma. - object_store_name, object_store_manager_name, object_store_manager_port = start_objstore(node_ip_address, redis_address, cleanup=cleanup, redirect_output=redirect_output) - object_store_names.append(object_store_name) - object_store_manager_names.append(object_store_manager_name) - time.sleep(0.1) - # Start the local scheduler. - plasma_address = "{}:{}".format(node_ip_address, object_store_manager_port) - local_scheduler_name = start_local_scheduler(redis_address, node_ip_address, object_store_name, object_store_manager_name, plasma_address=plasma_address, cleanup=cleanup, redirect_output=redirect_output) - local_scheduler_names.append(local_scheduler_name) - time.sleep(0.1) - # Aggregate the address information together. - address_info = {"node_ip_address": node_ip_address, - "redis_address": redis_address, - "object_store_names": object_store_names, - "object_store_manager_names": object_store_manager_names, - "local_scheduler_names": local_scheduler_names} - # Start the workers. - for i in range(num_workers): - start_worker(address_info["node_ip_address"], - address_info["object_store_names"][i % num_local_schedulers], - address_info["object_store_manager_names"][i % num_local_schedulers], - address_info["local_scheduler_names"][i % num_local_schedulers], - redis_address, - worker_path, - cleanup=cleanup, - redirect_output=redirect_output) - # Return the addresses of the relevant processes. - start_webui(redis_port, cleanup=cleanup, redirect_output=redirect_output) - return address_info + address_info = { + "redis_address": redis_address, + } + return start_ray_processes(address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + worker_path=worker_path, + cleanup=cleanup, + redirect_output=redirect_output, + include_webui=False) + +def start_ray_local(address_info=None, + node_ip_address="127.0.0.1", + num_workers=0, + num_local_schedulers=1, + worker_path=None, + cleanup=True, + redirect_output=False, + include_webui=True): + """Start Ray in local mode. + + Args: + address_info (dict): A dictionary with address information for processes + that have already been started. If provided, address_info will be + modified to include processes that are newly started. + node_ip_address (str): The IP address of this node. + num_workers (int): The number of workers to start. + num_local_schedulers (int): The total number of local schedulers required. + This is also the total number of object stores required. This method will + start new instances of local schedulers and object stores until there are + at least num_local_schedulers existing instances of each, including ones + already registered with the given address_info. + worker_path (str): The path of the source code that will be run by the + worker. + cleanup (bool): If cleanup is true, then the processes started here will be + killed by services.cleanup() when the Python process that called this + method exits. + redirect_output (bool): True if stdout and stderr should be redirected to + /dev/null. + + Returns: + A dictionary of the address information for the processes that were + started. + """ + return start_ray_processes(address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers, + worker_path=worker_path, + cleanup=cleanup, + redirect_output=redirect_output, + include_global_scheduler=True, + include_webui=True) diff --git a/lib/python/ray/worker.py b/lib/python/ray/worker.py index 5493c58c6600b..6bd3f89f24cab 100644 --- a/lib/python/ray/worker.py +++ b/lib/python/ray/worker.py @@ -651,13 +651,27 @@ def get_address_info_from_redis_helper(redis_address, node_ip_address): elif info[b"client_type"].decode("ascii") == "photon": local_schedulers.append(info) # Make sure that we got at one plasma manager and local scheduler. - assert len(plasma_managers) == 1 - assert len(local_schedulers) == 1 + assert len(plasma_managers) >= 1 + assert len(local_schedulers) >= 1 + # Build the address information. + object_store_addresses = [] + for manager in plasma_managers: + address = manager[b"address"].decode("ascii") + port = services.get_port(address) + object_store_addresses.append( + services.ObjectStoreAddress( + name=manager[b"store_socket_name"].decode("ascii"), + manager_name=manager[b"manager_socket_name"].decode("ascii"), + manager_port=port + ) + ) + scheduler_names = [scheduler[b"local_scheduler_socket_name"].decode("ascii") + for scheduler in local_schedulers] client_info = {"node_ip_address": node_ip_address, "redis_address": redis_address, - "store_socket_name": plasma_managers[0][b"store_socket_name"].decode("ascii"), - "manager_socket_name": plasma_managers[0][b"manager_socket_name"].decode("ascii"), - "local_scheduler_socket_name": local_schedulers[0][b"local_scheduler_socket_name"].decode("ascii")} + "object_store_addresses": object_store_addresses, + "local_scheduler_socket_names": scheduler_names, + } return client_info def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5): @@ -673,21 +687,25 @@ def get_address_info_from_redis(redis_address, node_ip_address, num_retries=5): time.sleep(1) counter += 1 -def init(node_ip_address=None, redis_address=None, start_ray_local=False, object_id_seed=None, num_workers=None, num_local_schedulers=None, driver_mode=SCRIPT_MODE): - """Either connect to an existing Ray cluster or start one and connect to it. +def _init(address_info=None, start_ray_local=False, object_id_seed=None, + num_workers=None, num_local_schedulers=None, + driver_mode=SCRIPT_MODE): + """Helper method to connect to an existing Ray cluster or start a new one. This method handles two cases. Either a Ray cluster already exists and we just attach this driver to it, or we start all of the processes associated with a Ray cluster and attach to the newly started cluster. Args: - node_ip_address (str): The IP address of the node that we are on. - redis_address (str): The address of the Redis server to connect to. This - should only be provided if start_ray_local is False. - start_ray_local (bool): If True then this will start Redis, a global - scheduler, a local scheduler, a plasma store, a plasma manager, and some - workers. It will also kill these processes when Python exits. If False, - this will attach to an existing Ray cluster. + address_info (dict): A dictionary with address information for processes in + a partially-started Ray cluster. If start_ray_local=True, any processes + not in this dictionary will be started. If provided, address_info will be + modified to include processes that are newly started. + start_ray_local (bool): If True then this will start any processes not + already in address_info, including Redis, a global scheduler, local + scheduler(s), object store(s), and worker(s). It will also kill these + processes when Python exits. If False, this will attach to an existing + Ray cluster. object_id_seed (int): Used to seed the deterministic generation of object IDs. The same value can be used across multiple runs of the same job in order to generate the object IDs in a consistent manner. However, the same @@ -709,28 +727,42 @@ def init(node_ip_address=None, redis_address=None, start_ray_local=False, object check_main_thread() if driver_mode not in [SCRIPT_MODE, PYTHON_MODE, SILENT_MODE]: raise Exception("Driver_mode must be in [ray.SCRIPT_MODE, ray.PYTHON_MODE, ray.SILENT_MODE].") + + # Get addresses of existing services. + if address_info is None: + address_info = {} + else: + assert isinstance(address_info, dict) + node_ip_address = address_info.get("node_ip_address") + redis_address = address_info.get("redis_address") + + # Start any services that do not yet exist. if driver_mode == PYTHON_MODE: # If starting Ray in PYTHON_MODE, don't start any other processes. - info = {} + pass elif start_ray_local: # In this case, we launch a scheduler, a new object store, and some workers, - # and we connect to them. - if redis_address is not None: - raise Exception("If start_ray_local=True, then redis_address cannot be provided because ray.init will start a new Redis server.") + # and we connect to them. We do not launch any processes that are already + # registered in address_info. # Use the address 127.0.0.1 in local mode. node_ip_address = "127.0.0.1" if node_ip_address is None else node_ip_address # Use 1 worker if num_workers is not provided. num_workers = 1 if num_workers is None else num_workers - # Use 1 local scheduler if num_local_schedulers is not provided. - num_local_schedulers = 1 if num_local_schedulers is None else num_local_schedulers + # Use 1 local scheduler if num_local_schedulers is not provided. If + # existing local schedulers are provided, use that count as + # num_local_schedulers. + local_schedulers = address_info.get("local_scheduler_socket_names", []) + if num_local_schedulers is None: + if len(local_schedulers) > 0: + num_local_schedulers = len(local_schedulers) + else: + num_local_schedulers = 1 # Start the scheduler, object store, and some workers. These will be killed # by the call to cleanup(), which happens when the Python script exits. - address_info = services.start_ray_local(node_ip_address=node_ip_address, num_workers=num_workers, num_local_schedulers=num_local_schedulers) - info = {"node_ip_address": node_ip_address, - "redis_address": address_info["redis_address"], - "store_socket_name": address_info["object_store_names"][0], - "manager_socket_name": address_info["object_store_manager_names"][0], - "local_scheduler_socket_name": address_info["local_scheduler_names"][0]} + address_info = services.start_ray_local(address_info=address_info, + node_ip_address=node_ip_address, + num_workers=num_workers, + num_local_schedulers=num_local_schedulers) else: if redis_address is None: raise Exception("If start_ray_local=False, then redis_address must be provided.") @@ -742,12 +774,64 @@ def init(node_ip_address=None, redis_address=None, start_ray_local=False, object if node_ip_address is None: node_ip_address = services.get_node_ip_address(redis_address) # Get the address info of the processes to connect to from Redis. - info = get_address_info_from_redis(redis_address, node_ip_address) - # Connect this driver to Redis, the object store, and the local scheduler. The - # corresponing call to disconnect will happen in the call to cleanup() when - # the Python script exits. - connect(info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker) - return info + address_info = get_address_info_from_redis(redis_address, node_ip_address) + + # Connect this driver to Redis, the object store, and the local scheduler. + # Choose the first object store and local scheduler if there are multiple. + # The corresponding call to disconnect will happen in the call to cleanup() + # when the Python script exits. + if driver_mode == PYTHON_MODE: + driver_address_info = {} + else: + driver_address_info = { + "node_ip_address": node_ip_address, + "redis_address": address_info["redis_address"], + "store_socket_name": address_info["object_store_addresses"][0].name, + "manager_socket_name": address_info["object_store_addresses"][0].manager_name, + "local_scheduler_socket_name": address_info["local_scheduler_socket_names"][0], + } + connect(driver_address_info, object_id_seed=object_id_seed, mode=driver_mode, worker=global_worker) + return address_info + +def init(node_ip_address=None, redis_address=None, start_ray_local=False, + object_id_seed=None, num_workers=None, driver_mode=SCRIPT_MODE): + """Either connect to an existing Ray cluster or start one and connect to it. + + This method handles two cases. Either a Ray cluster already exists and we + just attach this driver to it, or we start all of the processes associated + with a Ray cluster and attach to the newly started cluster. + + Args: + node_ip_address (str): The IP address of the node that we are on. + redis_address (str): The address of the Redis server to connect to. This + should only be provided if start_ray_local is False. + start_ray_local (bool): If True then this will start Redis, a global + scheduler, a local scheduler, a plasma store, a plasma manager, and some + workers. It will also kill these processes when Python exits. If False, + this will attach to an existing Ray cluster. + object_id_seed (int): Used to seed the deterministic generation of object + IDs. The same value can be used across multiple runs of the same job in + order to generate the object IDs in a consistent manner. However, the same + ID should not be used for different jobs. + num_workers (int): The number of workers to start. This is only provided if + start_ray_local is True. + driver_mode (bool): The mode in which to start the driver. This should be + one of ray.SCRIPT_MODE, ray.PYTHON_MODE, and ray.SILENT_MODE. + + Returns: + Address information about the started processes. + + Raises: + Exception: An exception is raised if an inappropriate combination of + arguments is passed in. + """ + info = { + "node_ip_address": node_ip_address, + "redis_address": redis_address, + } + return _init(address_info=info, + start_ray_local=start_ray_local, num_workers=num_workers, + driver_mode=driver_mode) def cleanup(worker=global_worker): """Disconnect the driver, and terminate any processes started in init. diff --git a/test/array_test.py b/test/array_test.py index c1c75c6eaaa21..30605d213c77c 100644 --- a/test/array_test.py +++ b/test/array_test.py @@ -66,7 +66,7 @@ def testAssemble(self): def testMethods(self): for module in [ra.core, ra.random, ra.linalg, da.core, da.random, da.linalg]: reload(module) - ray.init(start_ray_local=True, num_workers=10, num_local_schedulers=2) + ray.worker._init(start_ray_local=True, num_workers=10, num_local_schedulers=2) x = da.zeros.remote([9, 25, 51], "float") assert_equal(ray.get(da.assemble.remote(x)), np.zeros([9, 25, 51])) diff --git a/test/stress_tests.py b/test/stress_tests.py index 85c31b50c194b..06e9da6cb640b 100644 --- a/test/stress_tests.py +++ b/test/stress_tests.py @@ -13,7 +13,8 @@ def testSubmittingTasks(self): for num_local_schedulers in [1, 4]: for num_workers_per_scheduler in [4]: num_workers = num_local_schedulers * num_workers_per_scheduler - ray.init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(x): @@ -38,7 +39,8 @@ def testDependencies(self): for num_local_schedulers in [1, 4]: for num_workers_per_scheduler in [4]: num_workers = num_local_schedulers * num_workers_per_scheduler - ray.init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(x): @@ -82,7 +84,8 @@ def testWait(self): for num_local_schedulers in [1, 4]: for num_workers_per_scheduler in [4]: num_workers = num_local_schedulers * num_workers_per_scheduler - ray.init(start_ray_local=True, num_workers=num_workers, num_local_schedulers=num_local_schedulers) + ray.worker._init(start_ray_local=True, num_workers=num_workers, + num_local_schedulers=num_local_schedulers) @ray.remote def f(x):