Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import io.ray.api.id.ObjectId;
import io.ray.runtime.generated.Gcs.ErrorType;
import io.ray.runtime.serializer.Serializer;
import java.nio.ByteBuffer;
import java.util.Arrays;
import org.apache.commons.lang3.tuple.Pair;

Expand Down Expand Up @@ -45,6 +46,9 @@ public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objec
if (meta != null && meta.length > 0) {
// If meta is not null, deserialize the object from meta.
if (Arrays.equals(meta, OBJECT_METADATA_TYPE_RAW)) {
if (objectType == ByteBuffer.class) {
return ByteBuffer.wrap(data);
}
return data;
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(meta, OBJECT_METADATA_TYPE_JAVA)) {
Expand Down Expand Up @@ -81,6 +85,17 @@ public static NativeRayObject serialize(Object object) {
// If the object is a byte array, skip serializing it and use a special metadata to
// indicate it's raw binary. So that this object can also be read by Python.
return new NativeRayObject((byte[]) object, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof ByteBuffer) {
// Serialize ByteBuffer to raw bytes.
ByteBuffer buffer = (ByteBuffer) object;
byte[] bytes;
if (buffer.hasArray()) {
bytes = buffer.array();
} else {
bytes = new byte[buffer.remaining()];
buffer.get(bytes);
}
return new NativeRayObject(bytes, OBJECT_METADATA_TYPE_RAW);
} else if (object instanceof RayTaskException) {
byte[] serializedBytes = Serializer.encode(object).getLeft();
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
package io.ray.runtime.task;

import com.google.common.base.Preconditions;
import io.ray.api.ObjectRef;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import io.ray.runtime.RayRuntimeInternal;
import io.ray.runtime.generated.Common.Language;
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
Expand Down Expand Up @@ -68,12 +70,19 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
}

/**
* Convert list of NativeRayObject to real function arguments.
* Convert list of NativeRayObject/ByteBuffer to real function arguments.
*/
public static Object[] unwrap(List<NativeRayObject> args, Class<?>[] types) {
public static Object[] unwrap(List<Object> args, Class<?>[] types) {
Object[] realArgs = new Object[args.size()];
for (int i = 0; i < args.size(); i++) {
realArgs[i] = ObjectSerializer.deserialize(args.get(i), null, types[i]);
Object arg = args.get(i);
Preconditions.checkState(arg instanceof ByteBuffer || arg instanceof NativeRayObject);
if (arg instanceof ByteBuffer) {
Preconditions.checkState(types[i] == ByteBuffer.class);
realArgs[i] = arg;
} else {
realArgs[i] = ObjectSerializer.deserialize((NativeRayObject) arg, null, types[i]);
}
}
return realArgs;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -311,8 +311,9 @@ private void executeTask(TaskSpec taskSpec) {
? ((LocalModeTaskExecutor.LocalActorContext) actorContext).getWorkerId()
: UniqueId.randomId();
((LocalModeWorkerContext) runtime.getWorkerContext()).setCurrentWorkerId(workerId);
List<NativeRayObject> returnObjects = taskExecutor
.execute(getJavaFunctionDescriptor(taskSpec).toList(), args);
List<String> rayFunctionInfo = getJavaFunctionDescriptor(taskSpec).toList();
taskExecutor.checkByteBufferArguments(rayFunctionInfo);
List<NativeRayObject> returnObjects = taskExecutor.execute(rayFunctionInfo, args);
if (taskSpec.getType() == TaskType.ACTOR_CREATION_TASK) {
// Update actor context map ASAP in case objectStore.putRaw triggered the next actor task
// on this actor.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import io.ray.runtime.object.NativeRayObject;
import io.ray.runtime.object.ObjectSerializer;
import java.lang.reflect.InvocationTargetException;
import java.nio.ByteBuffer;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.ConcurrentHashMap;
Expand All @@ -30,6 +31,8 @@ public abstract class TaskExecutor<T extends TaskExecutor.ActorContext> {

private final ConcurrentHashMap<UniqueId, T> actorContextMap = new ConcurrentHashMap<>();

private final ThreadLocal<RayFunction> localRayFunction = new ThreadLocal<>();

static class ActorContext {

/**
Expand Down Expand Up @@ -61,10 +64,34 @@ void setActorContext(T actorContext) {
this.actorContextMap.put(runtime.getWorkerContext().getCurrentWorkerId(), actorContext);
}

private RayFunction getRayFunction(List<String> rayFunctionInfo) {
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
return runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
}

/**
* The return value indicates which parameters are ByteBuffer.
*/
protected boolean[] checkByteBufferArguments(List<String> rayFunctionInfo) {
localRayFunction.set(null);
try {
localRayFunction.set(getRayFunction(rayFunctionInfo));
} catch (Throwable e) {
// Ignore the exception.
return null;
}
Class<?>[] types = localRayFunction.get().executable.getParameterTypes();
boolean[] results = new boolean[types.length];
for (int i = 0; i < types.length; i++) {
results[i] = types[i] == ByteBuffer.class;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

should also check DirectByteBuffer.class?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

DirectByteBuffer is not a class. Instead, ByteBuffer has a method called isDirect.

}
return results;
}

protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
List<NativeRayObject> argsBytes) {
List<Object> argsBytes) {
runtime.setIsContextSet(true);
JobId jobId = runtime.getWorkerContext().getCurrentJobId();
TaskType taskType = runtime.getWorkerContext().getCurrentTaskType();
TaskId taskId = runtime.getWorkerContext().getCurrentTaskId();
LOGGER.debug("Executing task {}", taskId);
Expand All @@ -80,11 +107,14 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo,

List<NativeRayObject> returnObjects = new ArrayList<>();
ClassLoader oldLoader = Thread.currentThread().getContextClassLoader();
JavaFunctionDescriptor functionDescriptor = parseFunctionDescriptor(rayFunctionInfo);
RayFunction rayFunction = null;
RayFunction rayFunction = localRayFunction.get();
try {
// Find the executable object.
rayFunction = runtime.getFunctionManager().getFunction(jobId, functionDescriptor);
if (rayFunction == null) {
// Failed to get RayFunction in checkByteBufferArguments. Redo here to throw
// the exception again.
rayFunction = getRayFunction(rayFunctionInfo);
}
Thread.currentThread().setContextClassLoader(rayFunction.classLoader);
runtime.getWorkerContext().setCurrentClassLoader(rayFunction.classLoader);

Expand Down Expand Up @@ -132,7 +162,7 @@ protected List<NativeRayObject> execute(List<String> rayFunctionInfo,
LOGGER.error("Error executing task " + taskId, e);
if (taskType != TaskType.ACTOR_CREATION_TASK) {
boolean hasReturn = rayFunction != null && rayFunction.hasReturn();
boolean isCrossLanguage = functionDescriptor.signature.equals("");
boolean isCrossLanguage = parseFunctionDescriptor(rayFunctionInfo).signature.equals("");
if (hasReturn || isCrossLanguage) {
returnObjects.add(ObjectSerializer
.serialize(new RayTaskException("Error executing task " + taskId, e)));
Expand Down
11 changes: 11 additions & 0 deletions java/test/src/main/java/io/ray/test/RayCallTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import com.google.common.collect.ImmutableMap;
import io.ray.api.Ray;
import io.ray.api.id.ObjectId;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.List;
import java.util.Map;
import org.testng.Assert;
Expand Down Expand Up @@ -63,6 +65,10 @@ private static void testNoReturn(ObjectId objectId) {
TestUtils.getRuntime().getObjectStore().put(1, objectId);
}

private static ByteBuffer testByteBuffer(ByteBuffer buffer) {
return buffer;
}

/**
* Test calling and returning different types.
*/
Expand All @@ -82,6 +88,11 @@ public void testType() {
Assert.assertEquals(map, Ray.task(RayCallTest::testMap, map).remote().get());
TestUtils.LargeObject largeObject = new TestUtils.LargeObject();
Assert.assertNotNull(Ray.task(RayCallTest::testLargeObject, largeObject).remote().get());
ByteBuffer buffer1 = ByteBuffer.wrap("foo".getBytes(StandardCharsets.UTF_8));
ByteBuffer buffer2 = Ray.task(RayCallTest::testByteBuffer, buffer1).remote().get();
byte[] bytes = new byte[buffer2.remaining()];
buffer2.get(bytes);
Assert.assertEquals("foo", new String(bytes, StandardCharsets.UTF_8));

// TODO(edoakes): this test doesn't work now that we've switched to direct call
// mode. To make it work, we need to implement the same protocol for resolving
Expand Down
8 changes: 6 additions & 2 deletions src/ray/core_worker/core_worker.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1666,8 +1666,12 @@ Status CoreWorker::GetAndPinArgsForExecutor(const TaskSpecification &task,
metadata = std::make_shared<LocalMemoryBuffer>(
const_cast<uint8_t *>(task.ArgMetadata(i)), task.ArgMetadataSize(i));
}
args->at(i) = std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i),
/*copy_data*/ true);
// NOTE: this is a workaround to avoid an extra copy for Java workers.
// Python workers need this copy to pass test case
// test_inline_arg_memory_corruption.
bool copy_data = options_.language == Language::PYTHON;
args->at(i) =
std::make_shared<RayObject>(data, metadata, task.ArgInlinedIds(i), copy_data);
arg_reference_ids->at(i) = ObjectID::Nil();
// The task borrows all ObjectIDs that were serialized in the inlined
// arguments. The task will receive references to these IDs, so it is
Expand Down
38 changes: 36 additions & 2 deletions src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,35 @@ inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
return ray::gcs::GcsClientOptions(ip, port, password, /*is_test_client=*/false);
}

jobject ToJavaArgs(JNIEnv *env, jbooleanArray java_check_results,
const std::vector<std::shared_ptr<ray::RayObject>> &args) {
if (java_check_results == nullptr) {
// If `java_check_results` is null, it means that `checkByteBufferArguments`
// failed. In this case, just return null here. The args won't be used anyway.
return nullptr;
} else {
jboolean *check_results = env->GetBooleanArrayElements(java_check_results, nullptr);
size_t i = 0;
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
env, args,
[check_results, &i](JNIEnv *env,
const std::shared_ptr<ray::RayObject> &native_object) {
if (*(check_results + (i++))) {
// If the type of this argument is ByteBuffer, we create a
// DirectByteBuffer here To avoid data copy.
// TODO: Check native_object->GetMetadata() == "RAW"
jobject obj = env->NewDirectByteBuffer(native_object->GetData()->Data(),
native_object->GetData()->Size());
RAY_CHECK(obj);
return obj;
}
return NativeRayObjectToJavaNativeRayObject(env, native_object);
});
env->ReleaseBooleanArrayElements(java_check_results, check_results, JNI_ABORT);
return args_array_list;
}
}

#ifdef __cplusplus
extern "C" {
#endif
Expand Down Expand Up @@ -100,8 +129,12 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(

// convert args
// TODO (kfstorm): Avoid copying binary data from Java to C++
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
env, args, NativeRayObjectToJavaNativeRayObject);
jbooleanArray java_check_results =
static_cast<jbooleanArray>(env->CallObjectMethod(
java_task_executor, java_task_executor_parse_function_arguments,
ray_function_array_list));
RAY_CHECK_JAVA_EXCEPTION(env);
jobject args_array_list = ToJavaArgs(env, java_check_results, args);

// invoke Java method
jobject java_return_objects =
Expand All @@ -120,6 +153,7 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
}
}

env->DeleteLocalRef(java_check_results);
env->DeleteLocalRef(java_return_objects);
env->DeleteLocalRef(args_array_list);
return ray::Status::OK();
Expand Down
3 changes: 3 additions & 0 deletions src/ray/core_worker/lib/java/jni_init.cc
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@ jfieldID java_native_ray_object_data;
jfieldID java_native_ray_object_metadata;

jclass java_task_executor_class;
jmethodID java_task_executor_parse_function_arguments;
jmethodID java_task_executor_execute;

JavaVM *jvm;
Expand Down Expand Up @@ -205,6 +206,8 @@ jint JNI_OnLoad(JavaVM *vm, void *reserved) {
env->GetFieldID(java_native_ray_object_class, "metadata", "[B");

java_task_executor_class = LoadClass(env, "io/ray/runtime/task/TaskExecutor");
java_task_executor_parse_function_arguments = env->GetMethodID(
java_task_executor_class, "checkByteBufferArguments", "(Ljava/util/List;)[Z");
java_task_executor_execute =
env->GetMethodID(java_task_executor_class, "execute",
"(Ljava/util/List;Ljava/util/List;)Ljava/util/List;");
Expand Down
2 changes: 2 additions & 0 deletions src/ray/core_worker/lib/java/jni_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,8 @@ extern jfieldID java_native_ray_object_metadata;

/// TaskExecutor class
extern jclass java_task_executor_class;
/// checkByteBufferArguments method of TaskExecutor class
extern jmethodID java_task_executor_parse_function_arguments;
/// execute method of TaskExecutor class
extern jmethodID java_task_executor_execute;

Expand Down