Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable redis password in Java worker #3943

Merged
merged 5 commits into from
Feb 12, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public void start() throws Exception {
manager = new RunManager(rayConfig);
manager.startRayProcesses(true);
}
redisClient = new RedisClient(rayConfig.getRedisAddress());
redisClient = new RedisClient(rayConfig.getRedisAddress(), rayConfig.redisPassword);

// TODO(qwang): Get object_store_socket_name and raylet_socket_name from Redis.
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ public class RayConfig {
private Integer redisPort;
public final int headRedisPort;
public final int numberRedisShards;
public final String headRedisPassword;
public final String redisPassword;

public final String objectStoreSocketName;
public final Long objectStoreSize;
Expand Down Expand Up @@ -157,6 +159,8 @@ public RayConfig(Config config) {
}
headRedisPort = config.getInt("ray.redis.head-port");
numberRedisShards = config.getInt("ray.redis.shard-number");
headRedisPassword = config.getString("ray.redis.head-password");
redisPassword = config.getString("ray.redis.password");

// object store configurations
objectStoreSocketName = config.getString("ray.object-store.socket-name");
Expand Down
16 changes: 14 additions & 2 deletions java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package org.ray.runtime.gcs;

import java.util.Map;

import org.ray.runtime.util.StringUtil;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
Expand All @@ -15,6 +17,10 @@ public class RedisClient {
private JedisPool jedisPool;

public RedisClient(String redisAddress) {
this(redisAddress, null);
}

public RedisClient(String redisAddress, String password) {
String[] ipAndPort = redisAddress.split(":");
if (ipAndPort.length != 2) {
throw new IllegalArgumentException("The argument redisAddress " +
Expand All @@ -23,8 +29,14 @@ public RedisClient(String redisAddress) {

JedisPoolConfig jedisPoolConfig = new JedisPoolConfig();
jedisPoolConfig.setMaxTotal(JEDIS_POOL_SIZE);
jedisPool = new JedisPool(jedisPoolConfig, ipAndPort[0],
Integer.parseInt(ipAndPort[1]), 30000);

if (StringUtil.isNullOrEmpty(password)) {
jedisPool = new JedisPool(jedisPoolConfig,
ipAndPort[0], Integer.parseInt(ipAndPort[1]), 30000);
} else {
jedisPool = new JedisPool(jedisPoolConfig, ipAndPort[0],
Integer.parseInt(ipAndPort[1]), 30000, password);
}
}

public Long set(final String key, final String value, final String field) {
Expand Down
38 changes: 33 additions & 5 deletions java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.io.File;
import java.io.IOException;
import java.time.LocalDateTime;
Expand All @@ -16,6 +17,7 @@
import org.ray.runtime.config.RayConfig;
import org.ray.runtime.util.FileUtil;
import org.ray.runtime.util.ResourceUtil;
import org.ray.runtime.util.StringUtil;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
Expand Down Expand Up @@ -146,24 +148,29 @@ public void startRayProcesses(boolean isHead) {

private void startRedisServer() {
// start primary redis
String primary = startRedisInstance(rayConfig.nodeIp, rayConfig.headRedisPort, null);
String primary = startRedisInstance(rayConfig.nodeIp,
rayConfig.headRedisPort, rayConfig.headRedisPassword, null);
rayConfig.setRedisAddress(primary);
try (Jedis client = new Jedis("127.0.0.1", rayConfig.headRedisPort)) {
if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) {
client.auth(rayConfig.headRedisPassword);
}
client.set("UseRaylet", "1");
// Register the number of Redis shards in the primary shard, so that clients
// know how many redis shards to expect under RedisShards.
client.set("NumRedisShards", Integer.toString(rayConfig.numberRedisShards));

// start redis shards
for (int i = 0; i < rayConfig.numberRedisShards; i++) {
String shard = startRedisInstance(rayConfig.nodeIp, rayConfig.headRedisPort + i + 1, i);
String shard = startRedisInstance(rayConfig.nodeIp,
rayConfig.headRedisPort + i + 1, rayConfig.headRedisPassword, i);
client.rpush("RedisShards", shard);
}
}
}

private String startRedisInstance(String ip, int port, Integer shard) {
List<String> command = ImmutableList.of(
private String startRedisInstance(String ip, int port, String password, Integer shard) {
List<String> command = Lists.newArrayList(
rayConfig.redisServerExecutablePath,
"--protected-mode",
"no",
Expand All @@ -174,10 +181,20 @@ private String startRedisInstance(String ip, int port, Integer shard) {
"--loadmodule",
rayConfig.redisModulePath
);

if (!StringUtil.isNullOrEmpty(password)) {
command.add("--requirepass ");
command.add(password);
}

String name = shard == null ? "redis" : "redis-" + shard;
startProcess(command, null, name);

try (Jedis client = new Jedis("127.0.0.1", port)) {
if (!StringUtil.isNullOrEmpty(password)) {
client.auth(password);
}

// Configure Redis to only generate notifications for the export keys.
client.configSet("notify-keyspace-events", "Kl");
// Put a time stamp in Redis to indicate when it was started.
Expand All @@ -192,6 +209,11 @@ private void startRaylet() {
int maximumStartupConcurrency = Math.max(1,
Math.min(rayConfig.resources.getOrDefault("CPU", 0.0).intValue(), hardwareConcurrency));

String redisPasswordOption = "";
if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) {
redisPasswordOption = rayConfig.headRedisPassword;
}

// See `src/ray/raylet/main.cc` for the meaning of each parameter.
List<String> command = ImmutableList.of(
rayConfig.rayletExecutablePath,
Expand All @@ -207,7 +229,8 @@ private void startRaylet() {
ResourceUtil.getResourcesStringFromMap(rayConfig.resources),
String.join(",", rayConfig.rayletConfigParameters), // The internal config list.
buildPythonWorkerCommand(), // python worker command
buildWorkerCommandRaylet() // java worker command
buildWorkerCommandRaylet(), // java worker command
redisPasswordOption
);

startProcess(command, null, "raylet");
Expand Down Expand Up @@ -248,6 +271,11 @@ private String buildWorkerCommandRaylet() {
// Config overwrite
cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress());

// redis password
if (!StringUtil.isNullOrEmpty(rayConfig.headRedisPassword)) {
cmd.add("-Dray.redis.password=" + rayConfig.headRedisPassword);
}

cmd.addAll(rayConfig.jvmParameters);

// Main class
Expand Down
4 changes: 4 additions & 0 deletions java/runtime/src/main/resources/ray.default.conf
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,10 @@ ray {
address: ""
// If `redis.server` isn't provided, which port we should use to start redis server.
head-port: 6379
// The password used to start the redis server on the head node.
head-password: ""
// The password used to connect to the redis server.
password:""
// If `redis.server` isn't provided, how many Redis shards we should start in addition to the
// primary Redis shard. The ports of these shards will be `head-port + 1`, `head-port + 2`, etc.
shard-number: 1
Expand Down
9 changes: 9 additions & 0 deletions java/test/src/main/java/org/ray/api/test/BaseTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ public class BaseTest {
public void setUp() {
System.setProperty("ray.home", "../..");
System.setProperty("ray.resources", "CPU:4,RES-A:4");
beforeInitRay();
Ray.init();
}

Expand All @@ -20,6 +21,7 @@ public void tearDown() {
// We could not enable this until `systemInfo` enabled.
//File rayletSocketFIle = new File(Ray.systemInfo().rayletSocketName());
Ray.shutdown();
afterShutdownRay();

//remove raylet socket file
//rayletSocketFIle.delete();
Expand All @@ -29,4 +31,11 @@ public void tearDown() {
System.clearProperty("ray.resources");
}

protected void beforeInitRay() {

}

protected void afterShutdownRay() {

}
}
34 changes: 34 additions & 0 deletions java/test/src/main/java/org/ray/api/test/RedisPasswordTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
package org.ray.api.test;

import org.ray.api.Ray;
import org.ray.api.RayObject;
import org.ray.api.annotation.RayRemote;
import org.testng.Assert;
import org.testng.annotations.Test;

public class RedisPasswordTest extends BaseTest {

@Override
public void beforeInitRay() {
System.setProperty("ray.redis.head-password", "12345678");
System.setProperty("ray.redis.password", "12345678");
}

@Override
public void afterShutdownRay() {
System.clearProperty("ray.redis.head-password");
System.clearProperty("ray.redis.password");
}

@RayRemote
public static String echo(String str) {
return str;
}

@Test
public void testRedisPassword() {
RayObject<String> obj = Ray.call(RedisPasswordTest::echo, "hello");
Assert.assertEquals("hello", obj.get());
}

}
21 changes: 17 additions & 4 deletions python/ray/services.py
Original file line number Diff line number Diff line change
Expand Up @@ -1029,7 +1029,12 @@ def start_raylet(redis_address,
java_worker_options = (java_worker_options
or DEFAULT_JAVA_WORKER_OPTIONS)
java_worker_command = build_java_worker_command(
java_worker_options, redis_address, plasma_store_name, raylet_name)
java_worker_options,
redis_address,
plasma_store_name,
raylet_name,
redis_password,
)
else:
java_worker_command = ""

Expand Down Expand Up @@ -1086,8 +1091,13 @@ def start_raylet(redis_address,
return process_info


def build_java_worker_command(java_worker_options, redis_address,
plasma_store_name, raylet_name):
def build_java_worker_command(
java_worker_options,
redis_address,
plasma_store_name,
raylet_name,
redis_password,
):
"""This method assembles the command used to start a Java worker.

Args:
Expand All @@ -1096,7 +1106,7 @@ def build_java_worker_command(java_worker_options, redis_address,
plasma_store_name (str): The name of the plasma store socket to connect
to.
raylet_name (str): The name of the raylet socket to create.

redis_password (str): The password of connect to redis.
Returns:
The command string for starting Java worker.
"""
Expand All @@ -1113,6 +1123,9 @@ def build_java_worker_command(java_worker_options, redis_address,
if raylet_name is not None:
command += "-Dray.raylet.socket-name={} ".format(raylet_name)

if redis_password is not None:
command += ("-Dray.redis-password=%s", redis_password)

command += "-Dray.home={} ".format(RAY_HOME)
command += "-Dray.log-dir={} ".format(get_logs_dir_path())
command += "org.ray.runtime.runner.worker.DefaultWorker"
Expand Down