Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Java] Improve JNI performance when submitting and executing tasks #9032

Merged
merged 7 commits into from
Jul 10, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package io.ray.runtime.functionmanager;

import com.google.common.base.Objects;
import io.ray.runtime.generated.Common.Language;
import java.util.Arrays;
import java.util.List;
Expand All @@ -26,6 +27,25 @@ public String toString() {
return moduleName + "." + className + "." + functionName;
}

@Override
public boolean equals(Object o) {
if (this == o) {
return true;
}
if (o == null || getClass() != o.getClass()) {
return false;
}
PyFunctionDescriptor that = (PyFunctionDescriptor) o;
return Objects.equal(moduleName, that.moduleName) &&
Objects.equal(className, that.className) &&
Objects.equal(functionName, that.functionName);
}

@Override
public int hashCode() {
return Objects.hashCode(moduleName, className, functionName);
}

@Override
public List<String> toList() {
return Arrays.asList(moduleName, className, functionName, "" /* function hash */);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package io.ray.runtime.task;

import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import io.ray.api.BaseActorHandle;
import io.ray.api.id.ObjectId;
import io.ray.api.options.ActorCreationOptions;
Expand All @@ -18,14 +19,19 @@ public class NativeTaskSubmitter implements TaskSubmitter {
@Override
public List<ObjectId> submitTask(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
int numReturns, CallOptions options) {
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, args, numReturns, options);
List<byte[]> returnIds = nativeSubmitTask(functionDescriptor, functionDescriptor.hashCode(),
args, numReturns, options);
if (returnIds == null) {
return ImmutableList.of();
}
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}

@Override
public BaseActorHandle createActor(FunctionDescriptor functionDescriptor, List<FunctionArg> args,
ActorCreationOptions options) {
byte[] actorId = nativeCreateActor(functionDescriptor, args, options);
byte[] actorId = nativeCreateActor(functionDescriptor, functionDescriptor.hashCode(), args,
options);
return NativeActorHandle.create(actorId, functionDescriptor.getLanguage());
}

Expand All @@ -35,17 +41,21 @@ public List<ObjectId> submitActorTask(
List<FunctionArg> args, int numReturns, CallOptions options) {
Preconditions.checkState(actor instanceof NativeActorHandle);
List<byte[]> returnIds = nativeSubmitActorTask(actor.getId().getBytes(),
functionDescriptor, args, numReturns, options);
functionDescriptor, functionDescriptor.hashCode(), args, numReturns, options);
if (returnIds == null) {
return ImmutableList.of();
}
return returnIds.stream().map(ObjectId::new).collect(Collectors.toList());
}

private static native List<byte[]> nativeSubmitTask(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, int numReturns, CallOptions callOptions);
int functionDescriptorHash, List<FunctionArg> args, int numReturns, CallOptions callOptions);

private static native byte[] nativeCreateActor(FunctionDescriptor functionDescriptor,
List<FunctionArg> args, ActorCreationOptions actorCreationOptions);
int functionDescriptorHash, List<FunctionArg> args,
ActorCreationOptions actorCreationOptions);

private static native List<byte[]> nativeSubmitActorTask(byte[] actorId,
FunctionDescriptor functionDescriptor, List<FunctionArg> args, int numReturns,
CallOptions callOptions);
FunctionDescriptor functionDescriptor, int functionDescriptorHash, List<FunctionArg> args,
int numReturns, CallOptions callOptions);
}
96 changes: 83 additions & 13 deletions src/ray/common/function_descriptor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class FunctionDescriptorInterface : public MessageWrapper<rpc::FunctionDescripto

virtual size_t Hash() const = 0;

// DO NOT define operator==() or operator!=() in the base class.
// Let the derived classes define and implement.
// This is to avoid unexpected behaviors when comparing function descriptors of
// different declard types, as in this case, the base class version is invoked.

virtual std::string ToString() const = 0;

// A one-word summary of the function call site (e.g., __main__.foo).
Expand All @@ -67,6 +72,10 @@ class EmptyFunctionDescriptor : public FunctionDescriptorInterface {
return std::hash<int>()(ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET);
}

inline bool operator==(const EmptyFunctionDescriptor &other) const { return true; }

inline bool operator!=(const EmptyFunctionDescriptor &other) const { return false; }

virtual std::string ToString() const { return "{type=EmptyFunctionDescriptor}"; }
};

Expand All @@ -90,17 +99,30 @@ class JavaFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->signature());
}

inline bool operator==(const JavaFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->ClassName() == other.ClassName() &&
this->FunctionName() == other.FunctionName() &&
this->Signature() == other.Signature();
}

inline bool operator!=(const JavaFunctionDescriptor &other) const {
return !(*this == other);
}

virtual std::string ToString() const {
return "{type=JavaFunctionDescriptor, class_name=" + typed_message_->class_name() +
", function_name=" + typed_message_->function_name() +
", signature=" + typed_message_->signature() + "}";
}

std::string ClassName() const { return typed_message_->class_name(); }
const std::string &ClassName() const { return typed_message_->class_name(); }

std::string FunctionName() const { return typed_message_->function_name(); }
const std::string &FunctionName() const { return typed_message_->function_name(); }

std::string Signature() const { return typed_message_->signature(); }
const std::string &Signature() const { return typed_message_->signature(); }

private:
const rpc::JavaFunctionDescriptor *typed_message_;
Expand All @@ -127,6 +149,20 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->function_hash());
}

inline bool operator==(const PythonFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->ModuleName() == other.ModuleName() &&
this->ClassName() == other.ClassName() &&
this->FunctionName() == other.FunctionName() &&
this->FunctionHash() == other.FunctionHash();
}

inline bool operator!=(const PythonFunctionDescriptor &other) const {
return !(*this == other);
}

virtual std::string ToString() const {
return "{type=PythonFunctionDescriptor, module_name=" +
typed_message_->module_name() +
Expand All @@ -140,13 +176,13 @@ class PythonFunctionDescriptor : public FunctionDescriptorInterface {
typed_message_->function_name();
}

std::string ModuleName() const { return typed_message_->module_name(); }
const std::string &ModuleName() const { return typed_message_->module_name(); }

std::string ClassName() const { return typed_message_->class_name(); }
const std::string &ClassName() const { return typed_message_->class_name(); }

std::string FunctionName() const { return typed_message_->function_name(); }
const std::string &FunctionName() const { return typed_message_->function_name(); }

std::string FunctionHash() const { return typed_message_->function_hash(); }
const std::string &FunctionHash() const { return typed_message_->function_hash(); }

private:
const rpc::PythonFunctionDescriptor *typed_message_;
Expand All @@ -172,17 +208,30 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface {
std::hash<std::string>()(typed_message_->exec_function_offset());
}

inline bool operator==(const CppFunctionDescriptor &other) const {
if (this == &other) {
return true;
}
return this->LibName() == other.LibName() &&
this->FunctionOffset() == other.FunctionOffset() &&
this->ExecFunctionOffset() == other.ExecFunctionOffset();
}

inline bool operator!=(const CppFunctionDescriptor &other) const {
return !(*this == other);
}

virtual std::string ToString() const {
return "{type=CppFunctionDescriptor, lib_name=" + typed_message_->lib_name() +
", function_offset=" + typed_message_->function_offset() +
", exec_function_offset=" + typed_message_->exec_function_offset() + "}";
}

std::string LibName() const { return typed_message_->lib_name(); }
const std::string &LibName() const { return typed_message_->lib_name(); }

std::string FunctionOffset() const { return typed_message_->function_offset(); }
const std::string &FunctionOffset() const { return typed_message_->function_offset(); }

std::string ExecFunctionOffset() const {
const std::string &ExecFunctionOffset() const {
return typed_message_->exec_function_offset();
}

Expand All @@ -193,11 +242,32 @@ class CppFunctionDescriptor : public FunctionDescriptorInterface {
typedef std::shared_ptr<FunctionDescriptorInterface> FunctionDescriptor;

inline bool operator==(const FunctionDescriptor &left, const FunctionDescriptor &right) {
if (left.get() != nullptr && right.get() != nullptr && left->Type() == right->Type() &&
left->ToString() == right->ToString()) {
if (left.get() == right.get()) {
return true;
}
return left.get() == right.get();
if (left.get() == nullptr || right.get() == nullptr) {
return false;
}
if (left->Type() != right->Type()) {
return false;
}
switch (left->Type()) {
case ray::FunctionDescriptorType::FUNCTION_DESCRIPTOR_NOT_SET:
return static_cast<const EmptyFunctionDescriptor &>(*left) ==
static_cast<const EmptyFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kJavaFunctionDescriptor:
return static_cast<const JavaFunctionDescriptor &>(*left) ==
static_cast<const JavaFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kPythonFunctionDescriptor:
return static_cast<const PythonFunctionDescriptor &>(*left) ==
static_cast<const PythonFunctionDescriptor &>(*right);
case ray::FunctionDescriptorType::kCppFunctionDescriptor:
return static_cast<const CppFunctionDescriptor &>(*left) ==
static_cast<const CppFunctionDescriptor &>(*right);
default:
RAY_LOG(FATAL) << "Unknown function descriptor type: " << left->Type();
return false;
}
}

inline bool operator!=(const FunctionDescriptor &left, const FunctionDescriptor &right) {
Expand Down
44 changes: 33 additions & 11 deletions src/ray/core_worker/lib/java/io_ray_runtime_RayNativeRuntime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@
thread_local JNIEnv *local_env = nullptr;
jobject java_task_executor = nullptr;

/// Store Java instances of function descriptor in the cache to avoid unnessesary JNI
/// operations.
thread_local std::unordered_map<size_t,
std::vector<std::pair<ray::FunctionDescriptor, jobject>>>
executor_function_descriptor_cache;

inline ray::gcs::GcsClientOptions ToGcsClientOptions(JNIEnv *env,
jobject gcs_client_options) {
std::string ip = JavaStringToNativeString(
Expand Down Expand Up @@ -73,9 +79,24 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(

RAY_CHECK(env);
RAY_CHECK(java_task_executor);

// convert RayFunction
jobject ray_function_array_list = NativeRayFunctionDescriptorToJavaStringList(
env, ray_function.GetFunctionDescriptor());
auto function_descriptor = ray_function.GetFunctionDescriptor();
size_t fd_hash = function_descriptor->Hash();
auto &fd_vector = executor_function_descriptor_cache[fd_hash];
jobject ray_function_array_list = nullptr;
for (auto &pair : fd_vector) {
if (pair.first == function_descriptor) {
ray_function_array_list = pair.second;
break;
}
}
if (!ray_function_array_list) {
ray_function_array_list =
NativeRayFunctionDescriptorToJavaStringList(env, function_descriptor);
fd_vector.emplace_back(function_descriptor, ray_function_array_list);
}

// convert args
// TODO (kfstorm): Avoid copying binary data from Java to C++
jobject args_array_list = NativeVectorToJavaList<std::shared_ptr<ray::RayObject>>(
Expand All @@ -86,19 +107,20 @@ JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
env->CallObjectMethod(java_task_executor, java_task_executor_execute,
ray_function_array_list, args_array_list);
RAY_CHECK_JAVA_EXCEPTION(env);
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
JavaListToNativeVector<std::shared_ptr<ray::RayObject>>(
env, java_return_objects, &return_objects,
[](JNIEnv *env, jobject java_native_ray_object) {
return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object);
});
for (auto &obj : return_objects) {
results->push_back(obj);
if (!return_ids.empty()) {
std::vector<std::shared_ptr<ray::RayObject>> return_objects;
JavaListToNativeVector<std::shared_ptr<ray::RayObject>>(
env, java_return_objects, &return_objects,
[](JNIEnv *env, jobject java_native_ray_object) {
return JavaNativeRayObjectToNativeRayObject(env, java_native_ray_object);
});
for (auto &obj : return_objects) {
results->push_back(obj);
}
}

env->DeleteLocalRef(java_return_objects);
env->DeleteLocalRef(args_array_list);
env->DeleteLocalRef(ray_function_array_list);
return ray::Status::OK();
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ extern "C" {
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeInitialize
* Signature:
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;ILjava/lang/String;Ljava/util/Map;)V
* (ILjava/lang/String;ILjava/lang/String;Ljava/lang/String;Ljava/lang/String;[BLio/ray/runtime/gcs/GcsClientOptions;Ljava/lang/String;ILjava/lang/String;Ljava/util/Map;[B)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeInitialize(
JNIEnv *, jclass, jint, jstring, jint, jstring, jstring, jstring, jbyteArray, jobject,
Expand All @@ -42,7 +42,7 @@ Java_io_ray_runtime_RayNativeRuntime_nativeRunTaskExecutor(JNIEnv *, jclass, job
/*
* Class: io_ray_runtime_RayNativeRuntime
* Method: nativeShutdown
* Signature: ()V
* Signature: (Z)V
*/
JNIEXPORT void JNICALL Java_io_ray_runtime_RayNativeRuntime_nativeShutdown(JNIEnv *,
jclass);
Expand Down
Loading