From a116b7f6468ce04c5e6aa91ac0205c2176a6ed39 Mon Sep 17 00:00:00 2001 From: Wang Qing Date: Tue, 5 Mar 2019 20:25:29 +0800 Subject: [PATCH] [Java] Add runtime context (#4194) --- java/api/src/main/java/org/ray/api/Ray.java | 7 +++ .../main/java/org/ray/api/RuntimeContext.java | 46 +++++++++++++++ .../java/org/ray/api/runtime/RayRuntime.java | 3 + .../org/ray/runtime/AbstractRayRuntime.java | 8 +++ .../org/ray/runtime/RayNativeRuntime.java | 18 ++++++ .../org/ray/runtime/RuntimeContextImpl.java | 56 +++++++++++++++++++ .../src/main/java/org/ray/runtime/Worker.java | 12 +++- .../java/org/ray/runtime/WorkerContext.java | 10 ++++ .../java/org/ray/runtime/gcs/RedisClient.java | 10 ++++ .../org/ray/runtime/runner/RunManager.java | 4 ++ .../ray/api/test/ActorReconstructionTest.java | 14 +++++ .../org/ray/api/test/RuntimeContextTest.java | 52 +++++++++++++++++ 12 files changed, 238 insertions(+), 2 deletions(-) create mode 100644 java/api/src/main/java/org/ray/api/RuntimeContext.java create mode 100644 java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java create mode 100644 java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java diff --git a/java/api/src/main/java/org/ray/api/Ray.java b/java/api/src/main/java/org/ray/api/Ray.java index 7e252274ef735..2e660e543c722 100644 --- a/java/api/src/main/java/org/ray/api/Ray.java +++ b/java/api/src/main/java/org/ray/api/Ray.java @@ -120,4 +120,11 @@ public static WaitResult wait(List> waitList) { public static RayRuntime internal() { return runtime; } + + /** + * Get the runtime context. + */ + public static RuntimeContext getRuntimeContext() { + return runtime.getRuntimeContext(); + } } diff --git a/java/api/src/main/java/org/ray/api/RuntimeContext.java b/java/api/src/main/java/org/ray/api/RuntimeContext.java new file mode 100644 index 0000000000000..45c17f36b433e --- /dev/null +++ b/java/api/src/main/java/org/ray/api/RuntimeContext.java @@ -0,0 +1,46 @@ +package org.ray.api; + +import org.ray.api.id.UniqueId; + +/** + * A class used for getting information of Ray runtime. + */ +public interface RuntimeContext { + + /** + * Get the current Driver ID. + * + * If called in a driver, this returns the driver ID. If called in a worker, this returns the ID + * of the associated driver. + */ + UniqueId getCurrentDriverId(); + + /** + * Get the current actor ID. + * + * Note, this can only be called in actors. + */ + UniqueId getCurrentActorId(); + + /** + * Returns true if the current actor was reconstructed, false if it's created for the first time. + * + * Note, this method should only be called from an actor creation task. + */ + boolean wasCurrentActorReconstructed(); + + /** + * Get the raylet socket name. + */ + String getRayletSocketName(); + + /** + * Get the object store socket name. + */ + String getObjectStoreSocketName(); + + /** + * Return true if Ray is running in single-process mode, false if Ray is running in cluster mode. + */ + boolean isSingleProcess(); +} diff --git a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java index 48ebd6fce01ae..905bf1f14a6bd 100644 --- a/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java +++ b/java/api/src/main/java/org/ray/api/runtime/RayRuntime.java @@ -3,6 +3,7 @@ import java.util.List; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RuntimeContext; import org.ray.api.WaitResult; import org.ray.api.function.RayFunc; import org.ray.api.id.UniqueId; @@ -93,4 +94,6 @@ public interface RayRuntime { */ RayActor createActor(RayFunc actorFactoryFunc, Object[] args, ActorCreationOptions options); + + RuntimeContext getRuntimeContext(); } diff --git a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java index 12dd1f759ea9b..2411b92675964 100644 --- a/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java @@ -10,6 +10,7 @@ import java.util.stream.Collectors; import org.ray.api.RayActor; import org.ray.api.RayObject; +import org.ray.api.RuntimeContext; import org.ray.api.WaitResult; import org.ray.api.exception.RayException; import org.ray.api.function.RayFunc; @@ -61,6 +62,7 @@ public abstract class AbstractRayRuntime implements RayRuntime { protected RayletClient rayletClient; protected ObjectStoreProxy objectStoreProxy; protected FunctionManager functionManager; + protected RuntimeContext runtimeContext; public AbstractRayRuntime(RayConfig rayConfig) { this.rayConfig = rayConfig; @@ -68,6 +70,7 @@ public AbstractRayRuntime(RayConfig rayConfig) { worker = new Worker(this); workerContext = new WorkerContext(rayConfig.workerMode, rayConfig.driverId, rayConfig.runMode); + runtimeContext = new RuntimeContextImpl(this); } /** @@ -346,4 +349,9 @@ public FunctionManager getFunctionManager() { public RayConfig getRayConfig() { return rayConfig; } + + public RuntimeContext getRuntimeContext() { + return runtimeContext; + } + } diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 70ad03af4916c..4c070ed88d7bb 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -173,4 +173,22 @@ List getCheckpointsForActor(UniqueId actorId) { checkpoints.sort((x, y) -> Long.compare(y.timestamp, x.timestamp)); return checkpoints; } + + + /** + * Query whether the actor exists in Gcs. + */ + boolean actorExistsInGcs(UniqueId actorId) { + byte[] key = ArrayUtils.addAll("ACTOR".getBytes(), actorId.getBytes()); + + // TODO(qwang): refactor this with `GlobalState` after this issue + // getting finished. https://github.com/ray-project/ray/issues/3933 + for (RedisClient client : redisClients) { + if (client.exists(key)) { + return true; + } + } + + return false; + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java new file mode 100644 index 0000000000000..f0780cc2d8cdf --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/RuntimeContextImpl.java @@ -0,0 +1,56 @@ +package org.ray.runtime; + +import com.google.common.base.Preconditions; +import org.ray.api.RuntimeContext; +import org.ray.api.id.UniqueId; +import org.ray.runtime.config.RunMode; +import org.ray.runtime.config.WorkerMode; +import org.ray.runtime.task.TaskSpec; + +public class RuntimeContextImpl implements RuntimeContext { + + private AbstractRayRuntime runtime; + + public RuntimeContextImpl(AbstractRayRuntime runtime) { + this.runtime = runtime; + } + + @Override + public UniqueId getCurrentDriverId() { + return runtime.getWorkerContext().getCurrentDriverId(); + } + + @Override + public UniqueId getCurrentActorId() { + Preconditions.checkState(runtime.rayConfig.workerMode == WorkerMode.WORKER); + return runtime.getWorker().getCurrentActorId(); + } + + @Override + public boolean wasCurrentActorReconstructed() { + TaskSpec currentTask = runtime.getWorkerContext().getCurrentTask(); + Preconditions.checkState(currentTask != null && currentTask.isActorCreationTask(), + "This method can only be called from an actor creation task."); + if (isSingleProcess()) { + return false; + } + + return ((RayNativeRuntime) runtime).actorExistsInGcs(getCurrentActorId()); + } + + @Override + public String getRayletSocketName() { + return runtime.getRayConfig().rayletSocketName; + } + + @Override + public String getObjectStoreSocketName() { + return runtime.getRayConfig().objectStoreSocketName; + } + + @Override + public boolean isSingleProcess() { + return RunMode.SINGLE_PROCESS == runtime.getRayConfig().runMode; + } + +} diff --git a/java/runtime/src/main/java/org/ray/runtime/Worker.java b/java/runtime/src/main/java/org/ray/runtime/Worker.java index 79ef9010772b0..e6a069efce765 100644 --- a/java/runtime/src/main/java/org/ray/runtime/Worker.java +++ b/java/runtime/src/main/java/org/ray/runtime/Worker.java @@ -63,6 +63,10 @@ public Worker(AbstractRayRuntime runtime) { this.runtime = runtime; } + public UniqueId getCurrentActorId() { + return currentActorId; + } + public void loop() { while (true) { LOGGER.info("Fetching new task in thread {}.", Thread.currentThread().getName()); @@ -86,6 +90,11 @@ public void execute(TaskSpec spec) { // Set context runtime.getWorkerContext().setCurrentTask(spec, rayFunction.classLoader); Thread.currentThread().setContextClassLoader(rayFunction.classLoader); + + if (spec.isActorCreationTask()) { + currentActorId = returnId; + } + // Get local actor object and arguments. Object actor = null; if (spec.isActorTask()) { @@ -94,6 +103,7 @@ public void execute(TaskSpec spec) { throw actorCreationException; } actor = currentActor; + } Object[] args = ArgumentsBuilder.unwrap(spec, rayFunction.classLoader); // Execute the task. @@ -112,7 +122,6 @@ public void execute(TaskSpec spec) { } else { maybeLoadCheckpoint(result, returnId); currentActor = result; - currentActorId = returnId; } LOGGER.info("Finished executing task {}", spec.taskId); } catch (Exception e) { @@ -121,7 +130,6 @@ public void execute(TaskSpec spec) { runtime.put(returnId, new RayTaskException("Error executing task " + spec, e)); } else { actorCreationException = e; - currentActorId = returnId; } } finally { Thread.currentThread().setContextClassLoader(oldLoader); diff --git a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java index 07a5640ee27b2..57f23cf31b195 100644 --- a/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java +++ b/java/runtime/src/main/java/org/ray/runtime/WorkerContext.java @@ -26,6 +26,8 @@ public class WorkerContext { */ private ThreadLocal taskIndex; + private ThreadLocal currentTask; + private UniqueId currentDriverId; private ClassLoader currentClassLoader; @@ -46,6 +48,7 @@ public WorkerContext(WorkerMode workerMode, UniqueId driverId, RunMode runMode) putIndex = ThreadLocal.withInitial(() -> 0); currentTaskId = ThreadLocal.withInitial(UniqueId::randomId); this.runMode = runMode; + currentTask = ThreadLocal.withInitial(() -> null); currentClassLoader = null; if (workerMode == WorkerMode.DRIVER) { workerId = driverId; @@ -83,6 +86,7 @@ public void setCurrentTask(TaskSpec task, ClassLoader classLoader) { this.currentDriverId = task.driverId; taskIndex.set(0); putIndex.set(0); + this.currentTask.set(task); currentClassLoader = classLoader; } @@ -124,4 +128,10 @@ public ClassLoader getCurrentClassLoader() { return currentClassLoader; } + /** + * Get the current task. + */ + public TaskSpec getCurrentTask() { + return this.currentTask.get(); + } } diff --git a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java index 94f189785a7bb..62e82a9eca11b 100644 --- a/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java +++ b/java/runtime/src/main/java/org/ray/runtime/gcs/RedisClient.java @@ -85,4 +85,14 @@ public List lrange(String key, long start, long end) { return jedis.lrange(key, start, end); } } + + /** + * Whether the key exists in Redis. + */ + public boolean exists(byte[] key) { + try (Jedis jedis = jedisPool.getResource()) { + return jedis.exists(key); + } + } + } diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java index 347ec3388b947..f0f1df8befa9f 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/RunManager.java @@ -268,6 +268,10 @@ private String buildWorkerCommandRaylet() { cmd.add("-Dray.logging.file.path=" + logFile); } + // socket names + cmd.add("-Dray.raylet.socket-name=" + rayConfig.rayletSocketName); + cmd.add("-Dray.object-store.socket-name=" + rayConfig.objectStoreSocketName); + // Config overwrite cmd.add("-Dray.redis.address=" + rayConfig.getRedisAddress()); diff --git a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java index 516da9ced302c..12d7d1a8a9313 100644 --- a/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java +++ b/java/test/src/main/java/org/ray/api/test/ActorReconstructionTest.java @@ -24,6 +24,16 @@ public static class Counter { protected int value = 0; + private boolean wasCurrentActorReconstructed = false; + + public Counter() { + wasCurrentActorReconstructed = Ray.getRuntimeContext().wasCurrentActorReconstructed(); + } + + public boolean wasCurrentActorReconstructed() { + return wasCurrentActorReconstructed; + } + public int increase() { value += 1; return value; @@ -48,6 +58,8 @@ public void testActorReconstruction() throws InterruptedException, IOException { Ray.call(Counter::increase, actor).get(); } + Assert.assertFalse(Ray.call(Counter::wasCurrentActorReconstructed, actor).get()); + // Kill the actor process. int pid = Ray.call(Counter::getPid, actor).get(); Runtime.getRuntime().exec("kill -9 " + pid); @@ -58,6 +70,8 @@ public void testActorReconstruction() throws InterruptedException, IOException { int value = Ray.call(Counter::increase, actor).get(); Assert.assertEquals(value, 4); + Assert.assertTrue(Ray.call(Counter::wasCurrentActorReconstructed, actor).get()); + // Kill the actor process again. pid = Ray.call(Counter::getPid, actor).get(); Runtime.getRuntime().exec("kill -9 " + pid); diff --git a/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java new file mode 100644 index 0000000000000..b6fdca32f170a --- /dev/null +++ b/java/test/src/main/java/org/ray/api/test/RuntimeContextTest.java @@ -0,0 +1,52 @@ +package org.ray.api.test; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.annotation.RayRemote; +import org.ray.api.id.UniqueId; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class RuntimeContextTest extends BaseTest { + + private static UniqueId DRIVER_ID = + UniqueId.fromHexString("0011223344556677889900112233445566778899"); + private static String RAYLET_SOCKET_NAME = "/tmp/ray/test/raylet_socket"; + private static String OBJECT_STORE_SOCKET_NAME = "/tmp/ray/test/object_store_socket"; + + @Override + public void beforeInitRay() { + System.setProperty("ray.driver.id", DRIVER_ID.toString()); + System.setProperty("ray.raylet.socket-name", RAYLET_SOCKET_NAME); + System.setProperty("ray.object-store.socket-name", OBJECT_STORE_SOCKET_NAME); + } + + @Test + public void testRuntimeContextInDriver() { + Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId()); + Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName()); + Assert.assertEquals(OBJECT_STORE_SOCKET_NAME, + Ray.getRuntimeContext().getObjectStoreSocketName()); + } + + @RayRemote + public static class RuntimeContextTester { + + public String testRuntimeContext(UniqueId actorId) { + Assert.assertEquals(DRIVER_ID, Ray.getRuntimeContext().getCurrentDriverId()); + Assert.assertEquals(actorId, Ray.getRuntimeContext().getCurrentActorId()); + Assert.assertEquals(RAYLET_SOCKET_NAME, Ray.getRuntimeContext().getRayletSocketName()); + Assert.assertEquals(OBJECT_STORE_SOCKET_NAME, + Ray.getRuntimeContext().getObjectStoreSocketName()); + return "ok"; + } + } + + @Test + public void testRuntimeContextInActor() { + RayActor actor = Ray.createActor(RuntimeContextTester::new); + Assert.assertEquals("ok", + Ray.call(RuntimeContextTester::testRuntimeContext, actor, actor.getId()).get()); + } + +}