Skip to content

Commit

Permalink
[xlang] Cross language serialization for ActorHandle (#10335)
Browse files Browse the repository at this point in the history
  • Loading branch information
fyrestone authored Sep 2, 2020
1 parent 65f17f2 commit b04222d
Show file tree
Hide file tree
Showing 8 changed files with 48 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import com.google.protobuf.InvalidProtocolBufferException;
import io.ray.api.id.ObjectId;
import io.ray.runtime.actor.NativeActorHandle;
import io.ray.runtime.exception.RayActorException;
import io.ray.runtime.exception.RayTaskException;
import io.ray.runtime.exception.RayWorkerException;
Expand Down Expand Up @@ -35,6 +36,10 @@ public class ObjectSerializer {
public static final byte[] OBJECT_METADATA_TYPE_JAVA = "JAVA".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_PYTHON = "PYTHON".getBytes();
public static final byte[] OBJECT_METADATA_TYPE_RAW = "RAW".getBytes();
// A constant used as object metadata to indicate the object is an actor handle.
// This value should be synchronized with the Python definition in ray_constants.py
// TODO(fyrestone): Serialize the ActorHandle via the custom type feature of XLANG.
public static final byte[] OBJECT_METADATA_TYPE_ACTOR_HANDLE = "ACTOR_HANDLE".getBytes();

// When an outer object is being serialized, the nested ObjectRefs are all
// serialized and the writeExternal method of the nested ObjectRefs are
Expand Down Expand Up @@ -86,6 +91,9 @@ public static Object deserialize(NativeRayObject nativeRayObject, ObjectId objec
"Can't deserialize RayTaskException object: " + objectId
.toString());
}
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_ACTOR_HANDLE)) {
byte[] serialized = Serializer.decode(data, byte[].class);
return NativeActorHandle.fromBytes(serialized);
} else if (Arrays.equals(meta, OBJECT_METADATA_TYPE_PYTHON)) {
throw new IllegalArgumentException("Can't deserialize Python object: " + objectId
.toString());
Expand Down Expand Up @@ -129,6 +137,13 @@ public static NativeRayObject serialize(Object object) {
// Only OBJECT_METADATA_TYPE_RAW is raw bytes,
// any other type should be the MessagePack serialized bytes.
return new NativeRayObject(serializedBytes, TASK_EXECUTION_EXCEPTION_META);
} else if (object instanceof NativeActorHandle) {
NativeActorHandle actorHandle = (NativeActorHandle)object;
byte[] serializedBytes = Serializer.encode(actorHandle.toBytes()).getLeft();
// serializedBytes is MessagePack serialized bytes
// Only OBJECT_METADATA_TYPE_RAW is raw bytes,
// any other type should be the MessagePack serialized bytes.
return new NativeRayObject(serializedBytes, OBJECT_METADATA_TYPE_ACTOR_HANDLE);
} else {
try {
Pair<byte[], Boolean> serialized = Serializer.encode(object);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ public static List<FunctionArg> wrap(Object[] args, Language language) {
if (language != Language.JAVA) {
boolean isCrossData =
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_CROSS_LANGUAGE) ||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW);
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_RAW) ||
Arrays.equals(value.metadata, ObjectSerializer.OBJECT_METADATA_TYPE_ACTOR_HANDLE);
if (!isCrossData) {
throw new IllegalArgumentException(String.format("Can't transfer %s data to %s",
Arrays.toString(value.metadata), language.getValueDescriptor().getName()));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -167,23 +167,21 @@ public void testPassActorHandleFromJavaToPython() {
// Create a java actor, and pass actor handle to python.
ActorHandle<TestActor> javaActor = Ray.actor(TestActor::new, "1".getBytes()).remote();
Preconditions.checkState(javaActor instanceof NativeActorHandle);
byte[] actorHandleBytes = ((NativeActorHandle) javaActor).toBytes();
ObjectRef<byte[]> res = Ray.task(
PyFunction.of(PYTHON_MODULE,
"py_func_call_java_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
javaActor).remote();
Assert.assertEquals(res.get(), "12".getBytes());
// Create a python actor, and pass actor handle to python.
PyActorHandle pyActor = Ray.actor(
PyActorClass.of(PYTHON_MODULE, "Counter"), "1".getBytes()).remote();
Preconditions.checkState(pyActor instanceof NativeActorHandle);
actorHandleBytes = ((NativeActorHandle) pyActor).toBytes();
res = Ray.task(
PyFunction.of(PYTHON_MODULE,
"py_func_call_python_actor_from_handle",
byte[].class),
actorHandleBytes).remote();
pyActor).remote();
Assert.assertEquals(res.get(), "3".getBytes());
}

Expand Down Expand Up @@ -301,9 +299,8 @@ public static int[] returnInputIntArray(int[] l) {
return l;
}

public static byte[] callPythonActorHandle(byte[] value) {
public static byte[] callPythonActorHandle(PyActorHandle actor) {
// This function will be called from test_cross_language_invocation.py
NativePyActorHandle actor = (NativePyActorHandle) NativeActorHandle.fromBytes(value);
ObjectRef<byte[]> res = actor.task(
PyActorMethod.of("increase", byte[].class),
"1".getBytes()).remote();
Expand Down
10 changes: 3 additions & 7 deletions java/test/src/main/resources/test_cross_language_invocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,13 @@ def py_func_call_java_actor(value):


@ray.remote
def py_func_call_java_actor_from_handle(value):
assert isinstance(value, bytes)
actor_handle = ray.actor.ActorHandle._deserialization_helper(value)
def py_func_call_java_actor_from_handle(actor_handle):
r = actor_handle.concat.remote(b"2")
return ray.get(r)


@ray.remote
def py_func_call_python_actor_from_handle(value):
assert isinstance(value, bytes)
actor_handle = ray.actor.ActorHandle._deserialization_helper(value)
def py_func_call_python_actor_from_handle(actor_handle):
r = actor_handle.increase.remote(2)
return ray.get(r)

Expand All @@ -79,7 +75,7 @@ def py_func_pass_python_actor_handle():
counter = Counter.remote(2)
f = ray.java_function("io.ray.test.CrossLanguageInvocationTest",
"callPythonActorHandle")
r = f.remote(counter._serialization_helper()[0])
r = f.remote(counter)
return ray.get(r)


Expand Down
3 changes: 2 additions & 1 deletion python/ray/_raylet.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,8 @@ cdef prepare_args(
if language != Language.PYTHON:
if metadata not in [
ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE,
ray_constants.OBJECT_METADATA_TYPE_RAW]:
ray_constants.OBJECT_METADATA_TYPE_RAW,
ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE]:
raise Exception("Can't transfer {} data to {}".format(
metadata, language))
size = serialized_arg.total_bytes
Expand Down
5 changes: 2 additions & 3 deletions python/ray/includes/serialization.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -436,15 +436,14 @@ cdef class MessagePackSerializedObject(SerializedObject):
const uint8_t *msgpack_header_ptr
const uint8_t *msgpack_data_ptr

def __init__(self, metadata, msgpack_data,
def __init__(self, metadata, msgpack_data, contained_object_refs,
SerializedObject nest_serialized_object=None):
if nest_serialized_object:
contained_object_refs = (
contained_object_refs.extend(
nest_serialized_object.contained_object_refs
)
total_bytes = nest_serialized_object.total_bytes
else:
contained_object_refs = []
total_bytes = 0
super(MessagePackSerializedObject, self).__init__(
metadata,
Expand Down
7 changes: 7 additions & 0 deletions python/ray/ray_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,6 +193,13 @@ def to_memory_units(memory_bytes, round_up):
# A constant used as object metadata to indicate the object is raw bytes.
OBJECT_METADATA_TYPE_RAW = b"RAW"

# A constant used as object metadata to indicate the object is an actor handle.
# This value should be synchronized with the Java definition in
# ObjectSerializer.java
# TODO(fyrestone): Serialize the ActorHandle via the custom type feature
# of XLANG.
OBJECT_METADATA_TYPE_ACTOR_HANDLE = b"ACTOR_HANDLE"

AUTOSCALER_RESOURCE_REQUEST_CHANNEL = b"autoscaler_resource_request"

# The default password to prevent redis port scanning attack.
Expand Down
14 changes: 14 additions & 0 deletions python/ray/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,9 @@ def _deserialize_object(self, data, metadata, object_ref):
if data is None:
return b""
return data.to_pybytes()
elif metadata == ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE:
obj = self._deserialize_msgpack_data(data, metadata)
return actor_handle_deserializer(obj)
# Otherwise, return an exception object based on
# the error type.
try:
Expand Down Expand Up @@ -349,10 +352,20 @@ def _serialize_to_pickle5(self, metadata, value):
def _serialize_to_msgpack(self, value):
# Only RayTaskError is possible to be serialized here. We don't
# need to deal with other exception types here.
contained_object_refs = []

if isinstance(value, RayTaskError):
metadata = str(
ErrorType.Value("TASK_EXECUTION_EXCEPTION")).encode("ascii")
value = value.to_bytes()
elif isinstance(value, ray.actor.ActorHandle):
# TODO(fyresone): ActorHandle should be serialized via the
# custom type feature of cross-language.
serialized, actor_handle_id = value._serialization_helper()
contained_object_refs.append(actor_handle_id)
# Update ref counting for the actor handle
metadata = ray_constants.OBJECT_METADATA_TYPE_ACTOR_HANDLE
value = serialized
else:
metadata = ray_constants.OBJECT_METADATA_TYPE_CROSS_LANGUAGE

Expand All @@ -373,6 +386,7 @@ def _python_serializer(o):
pickle5_serialized_object = None

return MessagePackSerializedObject(metadata, msgpack_data,
contained_object_refs,
pickle5_serialized_object)

def serialize(self, value):
Expand Down

0 comments on commit b04222d

Please sign in to comment.