Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,7 @@ public <T> T get(UniqueId objectId) throws RayException {
@Override
public <T> List<T> get(List<UniqueId> objectIds) {
boolean wasBlocked = false;
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
UniqueId taskId = workerContext.getCurrentTask().taskId;
UniqueId taskId = workerContext.getCurrentThreadTaskId();

try {
int numObjectIds = objectIds.size();
Expand Down Expand Up @@ -218,10 +216,8 @@ private List<List<UniqueId>> splitIntoBatches(List<UniqueId> objectIds, int batc

@Override
public <T> WaitResult<T> wait(List<RayObject<T>> waitList, int numReturns, int timeoutMs) {
// TODO(swang): If we are not on the main thread, then we should generate a
// random task ID to pass to the backend.
return rayletClient.wait(waitList, numReturns, timeoutMs,
workerContext.getCurrentTask().taskId);
return rayletClient.wait(waitList, numReturns,
timeoutMs, workerContext.getCurrentThreadTaskId());
}

@Override
Expand All @@ -237,9 +233,12 @@ public RayObject call(RayFunc func, RayActor actor, Object[] args) {
throw new IllegalArgumentException("Unsupported actor type: " + actor.getClass().getName());
}
RayActorImpl actorImpl = (RayActorImpl)actor;
TaskSpec spec = createTaskSpec(func, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
TaskSpec spec;
synchronized (actor) {
spec = createTaskSpec(func, actorImpl, args, false, null);
spec.getExecutionDependencies().add(((RayActorImpl) actor).getTaskCursor());
actorImpl.setTaskCursor(spec.returnIds[1]);
}
rayletClient.submitTask(spec);
return new RayObjectImpl(spec.returnIds[0]);
}
Expand Down Expand Up @@ -342,4 +341,8 @@ public RayletClient getRayletClient() {
/**
 * Returns the function manager held by this runtime.
 */
public FunctionManager getFunctionManager() {
  return this.functionManager;
}

/**
 * Returns the configuration object held by this runtime.
 */
public RayConfig getRayConfig() {
  return this.rayConfig;
}
}
10 changes: 8 additions & 2 deletions java/runtime/src/main/java/org/ray/runtime/RayDevRuntime.java
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,21 @@ public RayDevRuntime(RayConfig rayConfig) {
super(rayConfig);
}

private MockObjectStore store;

@Override
public void start() {
MockObjectStore store = new MockObjectStore(this);
objectStoreProxy = new ObjectStoreProxy(this, store);
store = new MockObjectStore(this);
objectStoreProxy = new ObjectStoreProxy(this, null);
rayletClient = new MockRayletClient(this, store);
}

/**
 * Shuts down the dev runtime. This implementation intentionally performs no work.
 */
@Override
public void shutdown() {
// nothing to do
}

/**
 * Returns the mock object store held in this runtime's {@code store} field.
 */
public MockObjectStore getObjectStore() {
  return this.store;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ public void start() throws Exception {
}
kvStore = new RedisClient(rayConfig.getRedisAddress());

ObjectStoreLink store = new PlasmaClient(rayConfig.objectStoreSocketName, "", 0);
objectStoreProxy = new ObjectStoreProxy(this, store);
objectStoreProxy = new ObjectStoreProxy(this, rayConfig.objectStoreSocketName);

rayletClient = new RayletClientImpl(
rayConfig.rayletSocketName,
Expand Down
54 changes: 48 additions & 6 deletions java/runtime/src/main/java/org/ray/runtime/WorkerContext.java
Original file line number Diff line number Diff line change
@@ -1,11 +1,17 @@
package org.ray.runtime;

import com.google.common.base.Preconditions;
import java.util.HashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.ray.api.id.UniqueId;
import org.ray.runtime.config.WorkerMode;
import org.ray.runtime.task.TaskSpec;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class WorkerContext {
private static final Logger LOGGER = LoggerFactory.getLogger(WorkerContext.class);

/**
* Worker id.
Expand All @@ -25,19 +31,53 @@ public class WorkerContext {
/**
* How many puts have been done by current task.
*/
private int currentTaskPutCount;
private AtomicInteger currentTaskPutCount;

/**
* How many calls have been done by current task.
*/
private int currentTaskCallCount;
private AtomicInteger currentTaskCallCount;

/**
* The ID of main thread which created the worker context.
*/
private long mainThreadId;
/**
* If the multi-threading warning message has been logged.
*/
private AtomicBoolean multiThreadingWarned;

public WorkerContext(WorkerMode workerMode, UniqueId driverId) {
workerId = workerMode == WorkerMode.DRIVER ? driverId : UniqueId.randomId();
currentTaskPutCount = 0;
currentTaskCallCount = 0;
currentTaskPutCount = new AtomicInteger(0);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should these also be reset to 0 in setCurrentTask?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yep, @jovan-wong can you fix this?

currentTaskCallCount = new AtomicInteger(0);
currentClassLoader = null;
currentTask = createDummyTask(workerMode, driverId);
mainThreadId = Thread.currentThread().getId();
multiThreadingWarned = new AtomicBoolean(false);
}

/**
 * Resolves the task ID to use for the calling thread.
 *
 * <p>When invoked on the thread whose ID was recorded as the main thread, the ID of the
 * current task is returned. On any other thread a freshly generated random ID is returned
 * instead, and a warning about possible deadlock is logged the first time this happens.
 *
 * @return a non-nil task ID for the calling thread
 * @throws IllegalStateException if the resolved task ID is nil
 */
public UniqueId getCurrentThreadTaskId() {
  final boolean onMainThread = Thread.currentThread().getId() == mainThreadId;
  final UniqueId resolvedId = onMainThread ? currentTask.taskId : UniqueId.randomId();

  // compareAndSet guarantees that only one caller ever emits the warning.
  if (!onMainThread && multiThreadingWarned.compareAndSet(false, true)) {
    LOGGER.warn("Calling Ray.get or Ray.wait in a separate thread " +
        "may lead to deadlock if the main thread blocks on this " +
        "thread and there are not enough resources to execute " +
        "more tasks");
  }

  Preconditions.checkState(!resolvedId.isNil());
  return resolvedId;
}

public void setWorkerId(UniqueId workerId) {
Expand All @@ -49,11 +89,11 @@ public TaskSpec getCurrentTask() {
}

public int nextPutIndex() {
return ++currentTaskPutCount;
return currentTaskPutCount.incrementAndGet();
}

public int nextCallIndex() {
return ++currentTaskCallCount;
return currentTaskCallCount.incrementAndGet();
}

public UniqueId getCurrentWorkerId() {
Expand All @@ -66,6 +106,8 @@ public ClassLoader getCurrentClassLoader() {

/**
 * Installs the given task as the current task and resets the per-task
 * call and put counters to zero.
 *
 * @param currentTask the task spec to record as the current task
 */
public void setCurrentTask(TaskSpec currentTask) {
this.currentTask = currentTask;
// Call/put indices are counted per task, so start over from zero.
currentTaskCallCount.set(0);
currentTaskPutCount.set(0);
}

public void setCurrentClassLoader(ClassLoader currentClassLoader) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@
import java.util.ArrayList;
import java.util.List;
import org.apache.arrow.plasma.ObjectStoreLink;
import org.apache.arrow.plasma.PlasmaClient;
import org.apache.commons.lang3.tuple.Pair;
import org.ray.api.exception.RayException;
import org.ray.api.id.UniqueId;
import org.ray.runtime.AbstractRayRuntime;
import org.ray.runtime.RayDevRuntime;
import org.ray.runtime.config.RunMode;
import org.ray.runtime.util.Serializer;
import org.ray.runtime.util.UniqueIdUtil;

Expand All @@ -19,11 +22,18 @@ public class ObjectStoreProxy {
private static final int GET_TIMEOUT_MS = 1000;

private final AbstractRayRuntime runtime;
private final ObjectStoreLink store;

public ObjectStoreProxy(AbstractRayRuntime runtime, ObjectStoreLink store) {
private static ThreadLocal<ObjectStoreLink> objectStore;

public ObjectStoreProxy(AbstractRayRuntime runtime, String storeSocketName) {
this.runtime = runtime;
this.store = store;
objectStore = ThreadLocal.withInitial(() -> {
if (runtime.getRayConfig().runMode == RunMode.CLUSTER) {
return new PlasmaClient(storeSocketName, "", 0);
} else {
return ((RayDevRuntime) runtime).getObjectStore();
}
});
}

public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)
Expand All @@ -33,10 +43,10 @@ public <T> Pair<T, GetStatus> get(UniqueId objectId, boolean isMetadata)

public <T> Pair<T, GetStatus> get(UniqueId id, int timeoutMs, boolean isMetadata)
throws RayException {
byte[] obj = store.get(id.getBytes(), timeoutMs, isMetadata);
byte[] obj = objectStore.get().get(id.getBytes(), timeoutMs, isMetadata);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
store.release(id.getBytes());
objectStore.get().release(id.getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
Expand All @@ -53,13 +63,13 @@ public <T> List<Pair<T, GetStatus>> get(List<UniqueId> objectIds, boolean isMeta

public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boolean isMetadata)
throws RayException {
List<byte[]> objs = store.get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<byte[]> objs = objectStore.get().get(UniqueIdUtil.getIdBytes(ids), timeoutMs, isMetadata);
List<Pair<T, GetStatus>> ret = new ArrayList<>();
for (int i = 0; i < objs.size(); i++) {
byte[] obj = objs.get(i);
if (obj != null) {
T t = Serializer.decode(obj, runtime.getWorkerContext().getCurrentClassLoader());
store.release(ids.get(i).getBytes());
objectStore.get().release(ids.get(i).getBytes());
if (t instanceof RayException) {
throw (RayException) t;
}
Expand All @@ -72,11 +82,11 @@ public <T> List<Pair<T, GetStatus>> get(List<UniqueId> ids, int timeoutMs, boole
}

public void put(UniqueId id, Object obj, Object metadata) {
store.put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
objectStore.get().put(id.getBytes(), Serializer.encode(obj), Serializer.encode(metadata));
}

public void putSerialized(UniqueId id, byte[] obj, byte[] metadata) {
store.put(id.getBytes(), obj, metadata);
objectStore.get().put(id.getBytes(), obj, metadata);
}

public enum GetStatus {
Expand Down
112 changes: 112 additions & 0 deletions java/test/src/main/java/org/ray/api/test/MultiThreadingTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package org.ray.api.test;

import com.google.common.collect.ImmutableList;
import java.util.ArrayList;
import java.util.List;
import java.util.Random;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import org.junit.Assert;
import org.junit.Test;
import org.ray.api.Ray;
import org.ray.api.RayActor;
import org.ray.api.RayObject;
import org.ray.api.WaitResult;
import org.ray.api.annotation.RayRemote;


/**
 * Exercises the Ray API from several threads at once: remote function calls, actor calls,
 * put/get and wait. The same scenario is run directly in the driver and, wrapped in a remote
 * call, inside a worker.
 */
public class MultiThreadingTest extends BaseTest {

  /** Number of times each thread repeats a test case. */
  private static final int LOOP_COUNTER = 1000;
  /** Number of threads that run each test case concurrently. */
  private static final int NUM_THREADS = 20;

  /** Remote function that returns its argument unchanged. */
  @RayRemote
  public static Integer echo(int num) {
    return num;
  }

  /** Actor whose only method returns its argument unchanged. */
  @RayRemote
  public static class Echo {

    @RayRemote
    public Integer echo(int num) {
      return num;
    }
  }

  /**
   * Runs every multi-threading scenario in sequence.
   *
   * @return the string {@code "ok"} once all scenarios have completed without failure
   */
  public static String testMultiThreading() {
    Random random = new Random();
    // Test calling normal functions.
    runTestCaseInMultipleThreads(() -> {
      int arg = random.nextInt();
      RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, arg);
      Assert.assertEquals(arg, (int) obj.get());
    }, LOOP_COUNTER);

    // Test calling actors.
    RayActor<Echo> echoActor = Ray.createActor(Echo::new);
    runTestCaseInMultipleThreads(() -> {
      int arg = random.nextInt();
      RayObject<Integer> obj = Ray.call(Echo::echo, echoActor, arg);
      Assert.assertEquals(arg, (int) obj.get());
    }, LOOP_COUNTER);

    // Test put and get.
    runTestCaseInMultipleThreads(() -> {
      int arg = random.nextInt();
      RayObject<Integer> obj = Ray.put(arg);
      Assert.assertEquals(arg, (int) Ray.get(obj.getId()));
    }, LOOP_COUNTER);

    // Test wait for one object in multi threads.
    RayObject<Integer> obj = Ray.call(MultiThreadingTest::echo, 100);
    runTestCaseInMultipleThreads(() -> {
      WaitResult<Integer> result = Ray.wait(ImmutableList.of(obj), 1, 1000);
      Assert.assertEquals(1, result.getReady().size());
    }, 1);

    return "ok";
  }

  @Test
  public void testInDriver() {
    testMultiThreading();
  }

  @Test
  public void testInWorker() {
    RayObject<String> obj = Ray.call(MultiThreadingTest::testMultiThreading);
    Assert.assertEquals("ok", obj.get());
  }

  /**
   * Runs {@code testCase} on {@link #NUM_THREADS} threads concurrently and waits for all of
   * them to finish.
   *
   * @param testCase the body to execute; a failure in any thread fails the whole run
   * @param numRepeats how many times each thread executes {@code testCase}
   * @throws RuntimeException if any thread fails or the wait for its result is interrupted
   */
  private static void runTestCaseInMultipleThreads(Runnable testCase, int numRepeats) {
    ExecutorService service = Executors.newFixedThreadPool(NUM_THREADS);

    try {
      List<Future<String>> futures = new ArrayList<>();
      for (int i = 0; i < NUM_THREADS; i++) {
        Callable<String> task = () -> {
          for (int j = 0; j < numRepeats; j++) {
            TimeUnit.MILLISECONDS.sleep(1);
            testCase.run();
          }
          return "ok";
        };
        futures.add(service.submit(task));
      }
      for (Future<String> future : futures) {
        try {
          // JUnit's contract is assertEquals(expected, actual).
          Assert.assertEquals("ok", future.get());
        } catch (Exception e) {
          if (e instanceof InterruptedException) {
            // Preserve the interrupt status for callers further up the stack.
            Thread.currentThread().interrupt();
          }
          throw new RuntimeException("Test case failed.", e);
        }
      }
    } finally {
      service.shutdown();
    }
  }

}