diff --git a/python/ray/scripts/scripts.py b/python/ray/scripts/scripts.py index 13aea1e57db7..1f2ba6c4357a 100644 --- a/python/ray/scripts/scripts.py +++ b/python/ray/scripts/scripts.py @@ -226,6 +226,7 @@ def start(node_ip_address, redis_address, redis_port, num_redis_shards, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_protected_mode=False, include_webui=(not no_ui), plasma_directory=plasma_directory, huge_pages=huge_pages, diff --git a/python/ray/services.py b/python/ray/services.py index 90380c9c89a2..2087e79e2e6a 100644 --- a/python/ray/services.py +++ b/python/ray/services.py @@ -414,6 +414,7 @@ def start_redis(node_ip_address, redirect_output=False, redirect_worker_output=False, cleanup=True, + protected_mode=False, use_credis=None): """Start the Redis global state store. @@ -466,7 +467,8 @@ def start_redis(node_ip_address, redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + protected_mode=protected_mode) else: assigned_port, _ = _start_redis_instance( node_ip_address=node_ip_address, @@ -475,6 +477,7 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, + protected_mode=protected_mode, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray module, # as the latter contains an extern declaration that the former @@ -516,7 +519,8 @@ def start_redis(node_ip_address, redis_max_clients=redis_max_clients, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, - cleanup=cleanup) + cleanup=cleanup, + protected_mode=protected_mode) else: assert num_redis_shards == 1, \ "For now, RAY_USE_NEW_GCS supports 1 shard, and credis "\ @@ -528,6 +532,7 @@ def start_redis(node_ip_address, stdout_file=redis_stdout_file, stderr_file=redis_stderr_file, cleanup=cleanup, + protected_mode=protected_mode, executable=CREDIS_EXECUTABLE, # It is important to load the credis module BEFORE the ray # module, as the latter contains an extern declaration that the @@ -553,6 +558,22 @@ def start_redis(node_ip_address, return redis_address, redis_shards +def _make_temp_redis_config(node_ip_address): + """Create a configuration file for Redis. + + Args: + node_ip_address: The IP address of this node. This should not be + 127.0.0.1. + """ + redis_config_name = "/tmp/redis_conf{}".format(random_name()) + with open(redis_config_name, 'w') as f: + # This allows redis clients on the same machine to connect using the + # node's IP address as opposed to just 127.0.0.1. This is only relevant + # when the server is in protected mode. + f.write("bind 127.0.0.1 {}".format(node_ip_address)) + return redis_config_name + + def _start_redis_instance(node_ip_address="127.0.0.1", port=None, redis_max_clients=None, @@ -560,6 +581,7 @@ def _start_redis_instance(node_ip_address="127.0.0.1", stdout_file=None, stderr_file=None, cleanup=True, + protected_mode=False, executable=REDIS_EXECUTABLE, modules=None): """Start a single Redis server. @@ -579,6 +601,10 @@ def _start_redis_instance(node_ip_address="127.0.0.1", cleanup (bool): True if using Ray in local mode. If cleanup is true, then this process will be killed by serices.cleanup() when the Python process that imported services exits. + protected_mode: True if we should start the Redis server in protected + mode. This will prevent clients on other machines from connecting + and is only used when the Redis servers are started via ray.init() + as opposed to ray start. executable (str): Full path tho the redis-server executable. modules (list of str): A list of pathnames, pointing to the redis module(s) that will be loaded in this redis server. If None, load @@ -604,6 +630,9 @@ def _start_redis_instance(node_ip_address="127.0.0.1", else: port = new_port() + if protected_mode: + redis_config_filename = _make_temp_redis_config(node_ip_address) + load_module_args = [] for module in modules: load_module_args += ["--loadmodule", module] @@ -611,8 +640,14 @@ def _start_redis_instance(node_ip_address="127.0.0.1", while counter < num_retries: if counter > 0: logger.warning("Redis failed to start, retrying now.") - command = [executable, "--port", - str(port), "--loglevel", "warning"] + load_module_args + + # Construct the command to start the Redis server. + command = [executable] + if protected_mode: + command += [redis_config_filename] + command += ( + ["--port", str(port), "--loglevel", "warning"] + load_module_args) + p = subprocess.Popen(command, stdout=stdout_file, stderr=stderr_file) time.sleep(0.1) # Check if Redis successfully started (or at least if it the executable @@ -634,9 +669,12 @@ def _start_redis_instance(node_ip_address="127.0.0.1", # Configure Redis to generate keyspace notifications. TODO(rkn): Change # this to only generate notifications for the export keys. redis_client.config_set("notify-keyspace-events", "Kl") + # Configure Redis to not run in protected mode so that processes on other # hosts can connect to it. TODO(rkn): Do this in a more secure way. - redis_client.config_set("protected-mode", "no") + if not protected_mode: + redis_client.config_set("protected-mode", "no") + # If redis_max_clients is provided, attempt to raise the number of maximum # number of Redis clients. if redis_max_clients is not None: @@ -1259,6 +1297,7 @@ def start_ray_processes(address_info=None, object_store_memory=None, num_redis_shards=1, redis_max_clients=None, + redis_protected_mode=False, worker_path=None, cleanup=True, redirect_worker_output=False, @@ -1298,6 +1337,9 @@ def start_ray_processes(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_protected_mode: True if we should start Redis in protected mode. + This will prevent clients from other machines from connecting and + is only done when Redis is started via ray.init(). worker_path (str): The path of the source code that will be run by the worker. cleanup (bool): If cleanup is true, then the processes started here @@ -1373,7 +1415,8 @@ def start_ray_processes(address_info=None, use_raylet=use_raylet, redirect_output=True, redirect_worker_output=redirect_worker_output, - cleanup=cleanup) + cleanup=cleanup, + protected_mode=redis_protected_mode) address_info["redis_address"] = redis_address time.sleep(0.1) @@ -1653,6 +1696,7 @@ def start_ray_head(address_info=None, resources=None, num_redis_shards=None, redis_max_clients=None, + redis_protected_mode=False, include_webui=True, plasma_directory=None, huge_pages=False, @@ -1698,6 +1742,9 @@ def start_ray_head(address_info=None, the primary Redis shard. redis_max_clients: If provided, attempt to configure Redis with this maxclients number. + redis_protected_mode: True if we should start Redis in protected mode. + This will prevent clients from other machines from connecting and + is only done when Redis is started via ray.init(). include_webui: True if the UI should be started and false otherwise. plasma_directory: A directory where the Plasma memory mapped files will be created. @@ -1731,6 +1778,7 @@ def start_ray_head(address_info=None, resources=resources, num_redis_shards=num_redis_shards, redis_max_clients=redis_max_clients, + redis_protected_mode=redis_protected_mode, plasma_directory=plasma_directory, huge_pages=huge_pages, autoscaling_config=autoscaling_config, diff --git a/python/ray/worker.py b/python/ray/worker.py index dcbaff440850..d05aff40cc64 100644 --- a/python/ray/worker.py +++ b/python/ray/worker.py @@ -1728,6 +1728,7 @@ def init(redis_address=None, num_custom_resource=None, num_redis_shards=None, redis_max_clients=None, + redis_protected_mode=True, plasma_directory=None, huge_pages=False, include_webui=True,