Skip to content

Commit

Permalink
[Java] Add runtime context (ray-project#4194)
Browse files Browse the repository at this point in the history
  • Loading branch information
jovany-wang authored and raulchen committed Mar 5, 2019
1 parent c73d508 commit a116b7f
Show file tree
Hide file tree
Showing 12 changed files with 238 additions and 2 deletions.
7 changes: 7 additions & 0 deletions java/api/src/main/java/org/ray/api/Ray.java
Original file line number Diff line number Diff line change
Expand Up @@ -120,4 +120,11 @@ public static <T> WaitResult<T> wait(List<RayObject<T>> waitList) {
public static RayRuntime internal() {
return runtime;
}

/**
* Get the runtime context.
*/
public static RuntimeContext getRuntimeContext() {
return runtime.getRuntimeContext();
}
}
46 changes: 46 additions & 0 deletions java/api/src/main/java/org/ray/api/RuntimeContext.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package org.ray.api;

import org.ray.api.id.UniqueId;

/**
* A class used for getting information of Ray runtime.
*/
public interface RuntimeContext {

/**
* Get the current Driver ID.
*
* If called in a driver, this returns the driver ID. If called in a worker, this returns the ID
* of the associated driver.
*/
UniqueId getCurrentDriverId();

/**
* Get the current actor ID.
*
* Note, this can only be called in actors.
*/
UniqueId getCurrentActorId();

/**
* Returns true if the current actor was reconstructed, false if it's created for the first time.
*
* Note, this method should only be called from an actor creation task.
*/
boolean wasCurrentActorReconstructed();

/**
* Get the raylet socket name.
*/
String getRayletSocketName();

/**
* Get the object store socket name.
*/
String getObjectStoreSocketName();

/**
* Return true if Ray is running in single-process mode, false if Ray is running in cluster mode.
*/
boolean isSingleProcess();
}
3 changes: 3 additions & 0 deletions java/api/src/main/java/org/ray/api/runtime/RayRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import java.util.List;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.function.RayFunc;
import org.ray.api.id.UniqueId;
Expand Down Expand Up @@ -93,4 +94,6 @@ public interface RayRuntime {
*/
<T> RayActor<T> createActor(RayFunc actorFactoryFunc, Object[] args,
ActorCreationOptions options);

RuntimeContext getRuntimeContext();
}
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import java.util.stream.Collectors;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.RuntimeContext;
import org.ray.api.WaitResult;
import org.ray.api.exception.RayException;
import org.ray.api.function.RayFunc;
Expand Down Expand Up @@ -61,13 +62,15 @@ public abstract class AbstractRayRuntime implements RayRuntime {
protected RayletClient rayletClient;
protected ObjectStoreProxy objectStoreProxy;
protected FunctionManager functionManager;
protected RuntimeContext runtimeContext;

public AbstractRayRuntime(RayConfig rayConfig) {
this.rayConfig = rayConfig;
functionManager = new FunctionManager(rayConfig.driverResourcePath);
worker = new Worker(this);
workerContext = new WorkerContext(rayConfig.workerMode,
rayConfig.driverId, rayConfig.runMode);
runtimeContext = new RuntimeContextImpl(this);
}

/**
Expand Down Expand Up @@ -346,4 +349,9 @@ public FunctionManager getFunctionManager() {
public RayConfig getRayConfig() {
return rayConfig;
}

public RuntimeContext getRuntimeContext() {
return runtimeContext;
}

}
18 changes: 18 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,22 @@ List<Checkpoint> getCheckpointsForActor(UniqueId actorId) {
checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp));
return checkpoints;
}


/**
* Query whether the actor exists in Gcs.
*/
boolean actorExistsInGcs(UniqueId actorId) {
byte[] key = ArrayUtils.addAll("ACTOR".getBytes(), actorId.getBytes());

// TODO(qwang): refactor this with `GlobalState` after this issue
// getting finished. https://github.com/ray-project/ray/issues/3933
for (RedisClient client : redisClients) {
if (client.exists(key)) {
return true;
}
}

return false;
}
}
56 changes: 56 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
package org.ray.runtime;

import com.google.common.base.Preconditions;
import org.ray.api.RuntimeContext;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;

public class RuntimeContextImpl implements RuntimeContext {

private AbstractRayRuntime runtime;

public RuntimeContextImpl(AbstractRayRuntime runtime) {
this.runtime = runtime;
}

@Override
public UniqueId getCurrentDriverId() {
return runtime.getWorkerContext().getCurrentDriverId();
}

@Override
public UniqueId getCurrentActorId() {
Preconditions.checkState(runtime.rayConfig.workerMode == WorkerMode.WORKER);
return runtime.getWorker().getCurrentActorId();
}

@Override
public boolean wasCurrentActorReconstructed() {
TaskSpec currentTask = runtime.getWorkerContext().getCurrentTask();
Preconditions.checkState(currentTask != null && currentTask.isActorCreationTask(),
"This method can only be called from an actor creation task.");
if (isSingleProcess()) {
return false;
}

return ((RayNativeRuntime) runtime).actorExistsInGcs(getCurrentActorId());
}

@Override
public String getRayletSocketName() {
return runtime.getRayConfig().rayletSocketName;
}

@Override
public String getObjectStoreSocketName() {
return runtime.getRayConfig().objectStoreSocketName;
}

@Override
public boolean isSingleProcess() {
return RunMode.SINGLE_PROCESS == runtime.getRayConfig().runMode;
}

}
12 changes: 10 additions & 2 deletions java/runtime/src/main/java/org/ray/runtime/Worker.java
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,10 @@ public Worker(AbstractRayRuntime runtime) {
this.runtime = runtime;
}

public UniqueId getCurrentActorId() {
return currentActorId;
}

public void loop() {
while (true) {
LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName());
Expand All @@ -86,6 +90,11 @@ public void execute(TaskSpec spec) {
// Set context
runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader);
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);

if (spec.isActorCreationTask()) {
currentActorId = returnId;
}

// Get local actor object and arguments.
Object actor = null;
if (spec.isActorTask()) {
Expand All @@ -94,6 +103,7 @@ public void execute(TaskSpec spec) {
throw actorCreationException;
}
actor = currentActor;

}
Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader);
// Execute the task.
Expand All @@ -112,7 +122,6 @@ public void execute(TaskSpec spec) {
} else {
maybeLoadCheckpoint(result, returnId);
currentActor = result;
currentActorId = returnId;
}
LOGGER.info("Finished executing task {}", spec.taskId);
} catch (Exception e) {
Expand All @@ -121,7 +130,6 @@ public void execute(TaskSpec spec) {
runtime.put(returnId, new RayTaskException("Error executing task " + spec, e));
} else {
actorCreationException = e;
currentActorId = returnId;
}
} finally {
Thread.currentThread().setContextClassLoader(oldLoader);
Expand Down
10 changes: 10 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ public class WorkerContext {
*/
private ThreadLocal<Integer> taskIndex;

private ThreadLocal<TaskSpec> currentTask;

private UniqueId currentDriverId;

private ClassLoader currentClassLoader;
Expand All @@ -46,6 +48,7 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode)
putIndex = ThreadLocal.withInitial(() -> 0);
currentTaskId = ThreadLocal.withInitial(UniqueId::randomId);
this.runMode = runMode;
currentTask = ThreadLocal.withInitial(() -> null);
currentClassLoader = null;
if (workerMode == WorkerMode.DRIVER) {
workerId = driverId;
Expand Down Expand Up @@ -83,6 +86,7 @@ public void setCurrentTask(TaskSpec task, ClassLoader classLoader) {
this.currentDriverId = task.driverId;
taskIndex.set(0);
putIndex.set(0);
this.currentTask.set(task);
currentClassLoader = classLoader;
}

Expand Down Expand Up @@ -124,4 +128,10 @@ public ClassLoader getCurrentClassLoader() {
return currentClassLoader;
}

/**
* Get the current task.
*/
public TaskSpec getCurrentTask() {
return this.currentTask.get();
}
}
10 changes: 10 additions & 0 deletions java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,14 @@ public List<String> lrange(String key, long start, long end) {
return jedis.lrange(key, start, end);
}
}

/**
* Whether the key exists in Redis.
*/
public boolean exists(byte[] key) {
try (Jedis jedis = jedisPool.getResource()) {
return jedis.exists(key);
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,10 @@ private String buildWorkerCommandRaylet() {
cmd.add("-Dray.logging.file.path=" + logFile);
}

// socket names
cmd.add("-Dray.raylet.socket-name=" + rayConfig.rayletSocketName);
cmd.add("-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName);

// Config overwrite
cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress());

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,16 @@ public static class Counter {

protected int value = 0;

private boolean wasCurrentActorReconstructed = false;

public Counter() {
wasCurrentActorReconstructed = Ray.getRuntimeContext().wasCurrentActorReconstructed();
}

public boolean wasCurrentActorReconstructed() {
return wasCurrentActorReconstructed;
}

public int increase() {
value += 1;
return value;
Expand All @@ -48,6 +58,8 @@ public void testActorReconstruction() throws InterruptedException, IOException {
Ray.call(Counter::increase, actor).get();
}

Assert.assertFalse(Ray.call(Counter::wasCurrentActorReconstructed, actor).get());

// Kill the actor process.
int pid = Ray.call(Counter::getPid, actor).get();
Runtime.getRuntime().exec("kill -9 " + pid);
Expand All @@ -58,6 +70,8 @@ public void testActorReconstruction() throws InterruptedException, IOException {
int value = Ray.call(Counter::increase, actor).get();
Assert.assertEquals(value, 4);

Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get());

// Kill the actor process again.
pid = Ray.call(Counter::getPid, actor).get();
Runtime.getRuntime().exec("kill -9 " + pid);
Expand Down
52 changes: 52 additions & 0 deletions java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package org.ray.api.test;

import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.annotation.RayRemote;
import org.ray.api.id.UniqueId;
import org.testng.Assert;
import org.testng.annotations.Test;

public class RuntimeContextTest extends BaseTest {

private static UniqueId DRIVER_ID =
UniqueId.fromHexString("0011223344556677889900112233445566778899");
private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket";
private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket";

@Override
public void beforeInitRay() {
System.setProperty("ray.driver.id", DRIVER_ID.toString());
System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME);
System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME);
}

@Test
public void testRuntimeContextInDriver() {
Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
}

@RayRemote
public static class RuntimeContextTester {

public String testRuntimeContext(UniqueId actorId) {
Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId());
Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId());
Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName());
Assert.assertEquals(OBJECT_STORE_SOCKET_NAME,
Ray.getRuntimeContext().getObjectStoreSocketName());
return "ok";
}
}

@Test
public void testRuntimeContextInActor() {
RayActor<RuntimeContextTester> actor = Ray.createActor(RuntimeContextTester::new);
Assert.assertEquals("ok",
Ray.call(RuntimeContextTester::testRuntimeContext, actor, actor.getId()).get());
}

}

0 comments on commit a116b7f

Please sign in to comment.