Skip to content

Commit c59b506

Browse files
jovany-wang authored and raulchen committed
[Java] Support calling Ray APIs from multiple threads (#3646)
1 parent 0b682d0 commit c59b506

File tree

6 files changed

+201
-29
lines changed

6 files changed

+201
-29
lines changed

java/runtime/src/main/java/org/ray/runtime/AbstractRayRuntime.java

Lines changed: 13 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -108,9 +108,7 @@ public <T> T get(UniqueId objectId) throws RayException {
108108
@Override
109109
public <T> List<T> get(List<UniqueId> objectIds) {
110110
boolean wasBlocked = false;
111-
// TODO(swang): If we are not on the main thread, then we should generate a
112-
// random task ID to pass to the backend.
113-
UniqueId taskId = workerContext.getCurrentTask().taskId;
111+
UniqueId taskId = workerContext.getCurrentThreadTaskId();
114112

115113
try {
116114
int numObjectIds = objectIds.size();
@@ -218,10 +216,8 @@ private List<List<UniqueId>> splitIntoBatches(List<UniqueId> objectIds, int batc
218216

219217
@Override
220218
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
221-
// TODO(swang): If we are not on the main thread, then we should generate a
222-
// random task ID to pass to the backend.
223-
return rayletClient.wait(waitList, numReturns, timeoutMs,
224-
workerContext.getCurrentTask().taskId);
219+
return rayletClient.wait(waitList, numReturns,
220+
timeoutMs, workerContext.getCurrentThreadTaskId());
225221
}
226222

227223
@Override
@@ -237,9 +233,12 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) {
237233
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
238234
}
239235
RayActorImpl actorImpl = (RayActorImpl)actor;
240-
TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null);
241-
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
242-
actorImpl.setTaskCursor(spec.returnIds[1]);
236+
TaskSpec spec;
237+
synchronized (actor) {
238+
spec = createTaskSpec(func, actorImpl, args, false, null);
239+
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
240+
actorImpl.setTaskCursor(spec.returnIds[1]);
241+
}
243242
rayletClient.submitTask(spec);
244243
return new RayObjectImpl(spec.returnIds[0]);
245244
}
@@ -342,4 +341,8 @@ public RayletClient getRayletClient() {
342341
public FunctionManager getFunctionManager() {
343342
return functionManager;
344343
}
344+
345+
public RayConfig getRayConfig() {
346+
return rayConfig;
347+
}
345348
}

java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,21 @@ public RayDevRuntime(RayConfig rayConfig) {
1111
super(rayConfig);
1212
}
1313

14+
private MockObjectStore store;
15+
1416
@Override
1517
public void start() {
16-
MockObjectStore store = new MockObjectStore(this);
17-
objectStoreProxy = new ObjectStoreProxy(this, store);
18+
store = new MockObjectStore(this);
19+
objectStoreProxy = new ObjectStoreProxy(this, null);
1820
rayletClient = new MockRayletClient(this, store);
1921
}
2022

2123
@Override
2224
public void shutdown() {
2325
// nothing to do
2426
}
27+
28+
public MockObjectStore getObjectStore() {
29+
return store;
30+
}
2531
}

java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -74,8 +74,7 @@ public void start() throws Exception {
7474
}
7575
kvStore = new RedisClient(rayConfig.getRedisAddress());
7676

77-
ObjectStoreLink store = new PlasmaClient(rayConfig.objectStoreSocketName, "", 0);
78-
objectStoreProxy = new ObjectStoreProxy(this, store);
77+
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);
7978

8079
rayletClient = new RayletClientImpl(
8180
rayConfig.rayletSocketName,

java/runtime/src/main/java/org/ray/runtime/WorkerContext.java

Lines changed: 48 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,17 @@
11
package org.ray.runtime;
22

3+
import com.google.common.base.Preconditions;
34
import java.util.HashMap;
5+
import java.util.concurrent.atomic.AtomicBoolean;
6+
import java.util.concurrent.atomic.AtomicInteger;
47
import org.ray.api.id.UniqueId;
58
import org.ray.runtime.config.WorkerMode;
69
import org.ray.runtime.task.TaskSpec;
10+
import org.slf4j.Logger;
11+
import org.slf4j.LoggerFactory;
712

813
public class WorkerContext {
14+
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);
915

1016
/**
1117
* Worker id.
@@ -25,19 +31,53 @@ public class WorkerContext {
2531
/**
2632
* How many puts have been done by current task.
2733
*/
28-
private int currentTaskPutCount;
34+
private AtomicInteger currentTaskPutCount;
2935

3036
/**
3137
* How many calls have been done by current task.
3238
*/
33-
private int currentTaskCallCount;
39+
private AtomicInteger currentTaskCallCount;
40+
41+
/**
42+
* The ID of main thread which created the worker context.
43+
*/
44+
private long mainThreadId;
45+
/**
46+
* If the multi-threading warning message has been logged.
47+
*/
48+
private AtomicBoolean multiThreadingWarned;
3449

3550
public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
3651
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
37-
currentTaskPutCount = 0;
38-
currentTaskCallCount = 0;
52+
currentTaskPutCount = new AtomicInteger(0);
53+
currentTaskCallCount = new AtomicInteger(0);
3954
currentClassLoader = null;
4055
currentTask = createDummyTask(workerMode, driverId);
56+
mainThreadId = Thread.currentThread().getId();
57+
multiThreadingWarned = new AtomicBoolean(false);
58+
}
59+
60+
/**
61+
* Get the current thread's task ID.
62+
* This returns the assigned task ID if called on the main thread, else a
63+
* random task ID.
64+
*/
65+
public UniqueId getCurrentThreadTaskId() {
66+
UniqueId taskId;
67+
if (Thread.currentThread().getId() == mainThreadId) {
68+
taskId = currentTask.taskId;
69+
} else {
70+
taskId = UniqueId.randomId();
71+
if (multiThreadingWarned.compareAndSet(false, true)) {
72+
LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
73+
"may lead to deadlock if the main thread blocks on this " +
74+
"thread and there are not enough resources to execute " +
75+
"more tasks");
76+
}
77+
}
78+
79+
Preconditions.checkState(!taskId.isNil());
80+
return taskId;
4181
}
4282

4383
public void setWorkerId(UniqueId workerId) {
@@ -49,11 +89,11 @@ public TaskSpec getCurrentTask() {
4989
}
5090

5191
public int nextPutIndex() {
52-
return ++currentTaskPutCount;
92+
return currentTaskPutCount.incrementAndGet();
5393
}
5494

5595
public int nextCallIndex() {
56-
return ++currentTaskCallCount;
96+
return currentTaskCallCount.incrementAndGet();
5797
}
5898

5999
public UniqueId getCurrentWorkerId() {
@@ -66,6 +106,8 @@ public ClassLoader getCurrentClassLoader() {
66106

67107
public void setCurrentTask(TaskSpec currentTask) {
68108
this.currentTask = currentTask;
109+
currentTaskCallCount.set(0);
110+
currentTaskPutCount.set(0);
69111
}
70112

71113
public void setCurrentClassLoader(ClassLoader currentClassLoader) {

java/runtime/src/main/java/org/ray/runtime/objectstore/ObjectStoreProxy.java

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,13 @@
33
import java.util.ArrayList;
44
import java.util.List;
55
import org.apache.arrow.plasma.ObjectStoreLink;
6+
import org.apache.arrow.plasma.PlasmaClient;
67
import org.apache.commons.lang3.tuple.Pair;
78
import org.ray.api.exception.RayException;
89
import org.ray.api.id.UniqueId;
910
import org.ray.runtime.AbstractRayRuntime;
11+
import org.ray.runtime.RayDevRuntime;
12+
import org.ray.runtime.config.RunMode;
1013
import org.ray.runtime.util.Serializer;
1114
import org.ray.runtime.util.UniqueIdUtil;
1215

@@ -19,11 +22,18 @@ public class ObjectStoreProxy {
1922
private static final int GET_TIMEOUT_MS = 1000;
2023

2124
private final AbstractRayRuntime runtime;
22-
private final ObjectStoreLink store;
2325

24-
public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) {
26+
private static ThreadLocal<ObjectStoreLink> objectStore;
27+
28+
public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
2529
this.runtime = runtime;
26-
this.store = store;
30+
objectStore = ThreadLocal.withInitial(() -> {
31+
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
32+
return new PlasmaClient(storeSocketName, "", 0);
33+
} else {
34+
return ((RayDevRuntime) runtime).getObjectStore();
35+
}
36+
});
2737
}
2838

2939
public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
@@ -33,10 +43,10 @@ public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
3343

3444
public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
3545
throws RayException {
36-
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
46+
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
3747
if (obj != null) {
3848
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
39-
store.release(id.getBytes());
49+
objectStore.get().release(id.getBytes());
4050
if (t instanceof RayException) {
4151
throw (RayException) t;
4252
}
@@ -53,13 +63,13 @@ public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMeta
5363

5464
public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
5565
throws RayException {
56-
List<byte[]> objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
66+
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
5767
List<Pair<T, GetStatus>> ret = new ArrayList<>();
5868
for (int i = 0; i < objs.size(); i++) {
5969
byte[] obj = objs.get(i);
6070
if (obj != null) {
6171
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
62-
store.release(ids.get(i).getBytes());
72+
objectStore.get().release(ids.get(i).getBytes());
6373
if (t instanceof RayException) {
6474
throw (RayException) t;
6575
}
@@ -72,11 +82,11 @@ public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boole
7282
}
7383

7484
public void put(UniqueId id, Object obj, Object metadata) {
75-
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
85+
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
7686
}
7787

7888
public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
79-
store.put(id.getBytes(), obj, metadata);
89+
objectStore.get().put(id.getBytes(), obj, metadata);
8090
}
8191

8292
public enum GetStatus {
MultiThreadingTest.java (new file, package org.ray.api.test)

Lines changed: 112 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,112 @@
1+
package org.ray.api.test;
2+
3+
import com.google.common.collect.ImmutableList;
4+
import java.util.ArrayList;
5+
import java.util.List;
6+
import java.util.Random;
7+
import java.util.concurrent.Callable;
8+
import java.util.concurrent.ExecutorService;
9+
import java.util.concurrent.Executors;
10+
import java.util.concurrent.Future;
11+
import java.util.concurrent.TimeUnit;
12+
import org.junit.Assert;
13+
import org.junit.Test;
14+
import org.ray.api.Ray;
15+
import org.ray.api.RayActor;
16+
import org.ray.api.RayObject;
17+
import org.ray.api.WaitResult;
18+
import org.ray.api.annotation.RayRemote;
19+
20+
21+
public class MultiThreadingTest extends BaseTest {
22+
23+
private static final int LOOP_COUNTER = 1000;
24+
private static final int NUM_THREADS = 20;
25+
26+
@RayRemote
27+
public static Integer echo(int num) {
28+
return num;
29+
}
30+
31+
@RayRemote
32+
public static class Echo {
33+
34+
@RayRemote
35+
public Integer echo(int num) {
36+
return num;
37+
}
38+
}
39+
40+
public static String testMultiThreading() {
41+
Random random = new Random();
42+
// Test calling normal functions.
43+
runTestCaseInMultipleThreads(() -> {
44+
int arg = random.nextInt();
45+
RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, arg);
46+
Assert.assertEquals(arg, (int) obj.get());
47+
}, LOOP_COUNTER);
48+
49+
// Test calling actors.
50+
RayActor<Echo> echoActor = Ray.createActor(Echo::new);
51+
runTestCaseInMultipleThreads(() -> {
52+
int arg = random.nextInt();
53+
RayObject<Integer> obj = Ray.call(Echo::echo, echoActor, arg);
54+
Assert.assertEquals(arg, (int) obj.get());
55+
}, LOOP_COUNTER);
56+
57+
// Test put and get.
58+
runTestCaseInMultipleThreads(() -> {
59+
int arg = random.nextInt();
60+
RayObject<Integer> obj = Ray.put(arg);
61+
Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
62+
}, LOOP_COUNTER);
63+
64+
// Test wait for one object in multi threads.
65+
RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, 100);
66+
runTestCaseInMultipleThreads(() -> {
67+
WaitResult<Integer> result = Ray.wait(ImmutableList.of(obj), 1, 1000);
68+
Assert.assertEquals(1, result.getReady().size());
69+
}, 1);
70+
71+
return "ok";
72+
}
73+
74+
@Test
75+
public void testInDriver() {
76+
testMultiThreading();
77+
}
78+
79+
@Test
80+
public void testInWorker() {
81+
RayObject<String> obj = Ray.call(MultiThreadingTest::testMultiThreading);
82+
Assert.assertEquals("ok", obj.get());
83+
}
84+
85+
private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
86+
ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);
87+
88+
try {
89+
List<Future<String>> futures = new ArrayList<>();
90+
for (int i = 0; i < NUM_THREADS; i++) {
91+
Callable<String> task = () -> {
92+
for (int j = 0; j < numRepeats; j++) {
93+
TimeUnit.MILLISECONDS.sleep(1);
94+
testCase.run();
95+
}
96+
return "ok";
97+
};
98+
futures.add(service.submit(task));
99+
}
100+
for (Future<String> future : futures) {
101+
try {
102+
Assert.assertEquals(future.get(), "ok");
103+
} catch (Exception e) {
104+
throw new RuntimeException("Test case failed.", e);
105+
}
106+
}
107+
} finally {
108+
service.shutdown();
109+
}
110+
}
111+
112+
}

0 commit comments

Comments
 (0)