Skip to content

Commit

Permalink
[RUNTIME][RPC] Enable RPCObjectRef return in RPC (apache#16387)
Browse files Browse the repository at this point in the history
[Runtime] Enable RPCObjectRef return in RPC

This PR enables returning RPCObjectRef objects, similar to the disco transportation.
This allows us to do advanced remote debugging when remote vm requires
advanced object input like kv cache and shape.

To keep the implementation working with minRPC (used in some of the limited protocols) for now,
we only support RPCObjectRef and do not enable unpacking Shape and String.
  • Loading branch information
tqchen authored Jan 12, 2024
1 parent f1bf20a commit 4258c86
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 13 deletions.
4 changes: 3 additions & 1 deletion include/tvm/runtime/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,10 @@ struct TypeIndex {
kRuntimeShapeTuple = 6,
/*! \brief runtime::PackedFunc. */
kRuntimePackedFunc = 7,
/*! \brief runtime::DRef */
/*! \brief runtime::DRef for disco distributed runtime */
kRuntimeDiscoDRef = 8,
/*! \brief runtime::RPCObjectRef */
kRuntimeRPCObjectRef = 9,
// static assignments that may subject to change.
kRuntimeClosure,
kRuntimeADT,
Expand Down
15 changes: 13 additions & 2 deletions src/runtime/minrpc/minrpc_server.h
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,8 @@ class MinRPCExecute : public MinRPCExecInterface {
ret_tcode[1] = kTVMBytes;
ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); // NOLINT(*)
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle ||
rv_tcode == kTVMObjectHandle) {
ret_tcode[1] = kTVMOpaqueHandle;
ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
} else {
Expand Down Expand Up @@ -755,7 +756,17 @@ class MinRPCServer {
}

/*!
 * \brief Decode an object argument from the incoming stream.
 *
 * minRPC is restricted to the C API, so the only object handled here is
 * RPCObjectRef. It is read as a uint32 type index followed by a uint64 handle.
 *
 * \param tcode Output type code; tcode[0] is set to kTVMObjectHandle.
 * \param value Output value; value[0].v_handle receives the decoded handle.
 */
void ReadObject(int* tcode, TVMValue* value) {
  // The type index comes first; anything other than RPCObjectRef fails the check.
  uint32_t type_index;
  Read(&type_index);
  MINRPC_CHECK(type_index == kRuntimeRPCObjectRefTypeIndex);
  // The handle follows as a 64-bit integer and is passed through as an opaque pointer.
  uint64_t object_handle;
  Read(&object_handle);
  tcode[0] = kTVMObjectHandle;
  value[0].v_handle = reinterpret_cast<void*>(object_handle);
}

private:
Expand Down
8 changes: 8 additions & 0 deletions src/runtime/minrpc/rpc_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,14 @@ class Object;
/*! \brief The current RPC protocol version. */
constexpr const char* kRPCProtocolVer = "0.8.0";

/*!
 * \brief Type index that identifies an RPCObjectRef in the RPC object encoding.
 * \note This needs to be kept consistent with TypeIndex::kRuntimeRPCObjectRef in
 * runtime/object.h. It is declared explicitly here because minrpc needs to keep
 * its dependencies minimal and may only rely on the C API.
 */
constexpr const int kRuntimeRPCObjectRefTypeIndex = 9;

// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;

Expand Down
51 changes: 44 additions & 7 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -175,8 +175,11 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
for (int i = 0; i < num_args; ++i) {
int tcode = type_codes[i];
if (tcode == kTVMObjectHandle || tcode == kTVMObjectRValueRefArg) {
LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
<< args[i].AsObjectRef<ObjectRef>()->GetTypeKey() << " is not supported by RPC";
if (!args[i].IsObjectRef<RPCObjectRef>()) {
LOG(FATAL) << "ValueError: Cannot pass argument " << i << ", type "
<< args[i].AsObjectRef<ObjectRef>()->GetTypeKey()
<< " is not supported by RPC";
}
} else if (tcode == kDLDevice) {
DLDevice dev = args[i];
ICHECK(!IsRPCSessionDevice(dev)) << "InternalError: cannot pass RPC device in the channel";
Expand Down Expand Up @@ -219,14 +222,48 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
this->Write(cdata);
}

void WriteObject(void* obj) { this->ThrowError(RPCServerStatus::kUnknownTypeCode); }
uint64_t GetObjectBytes(void* obj) {
this->ThrowError(RPCServerStatus::kUnknownTypeCode);
return 0;
/*!
 * \brief Serialize an object argument onto the stream.
 *
 * Only RPCObjectRefObj is supported. It is written as a uint32 type index
 * followed by the remote handle as a uint64; any other object type is a fatal error.
 *
 * \param obj The object to serialize.
 */
void WriteObject(Object* obj) {
  // NOTE: for now all remote objects are encoded as RPCObjectRef,
  // following the same protocol as disco in case we would like to upgrade later.
  //
  // Rationale: handling only remote objects allows the same mechanism to work for minRPC,
  // which is needed for wasm and other environments that go through the C API.
  if (obj->IsInstance<RPCObjectRefObj>()) {
    auto* ref = static_cast<RPCObjectRefObj*>(obj);
    this->template Write<uint32_t>(kRuntimeRPCObjectRefTypeIndex);
    uint64_t handle = reinterpret_cast<uint64_t>(ref->object_handle());
    // Write with the same unsigned 64-bit type that ReadObject decodes the handle with.
    this->template Write<uint64_t>(handle);
  } else {
    LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: "
               << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
  }
}
/*!
 * \brief Number of bytes WriteObject emits for the given object.
 *
 * Only RPCObjectRefObj is supported; any other object type is a fatal error.
 *
 * \param obj The object that is going to be serialized.
 * \return Size of the uint32 type index plus the uint64 handle.
 */
uint64_t GetObjectBytes(Object* obj) {
  // Reject unsupported objects up front so that the function has a return
  // statement on every path that reaches its end.
  if (!obj->IsInstance<RPCObjectRefObj>()) {
    LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: "
               << obj->GetTypeKey() << " (type_index = " << obj->type_index() << ")";
  }
  // type index (uint32) + object handle (uint64)
  return sizeof(uint32_t) + sizeof(uint64_t);
}

/*!
 * \brief Decode an object argument from the stream.
 *
 * Reads a uint32 type index. When it identifies an RPCObjectRef, the uint64
 * handle that follows is read and returned as an object handle; any other
 * type index is a fatal error.
 *
 * \param tcode Output type code; tcode[0] is set to kTVMObjectHandle.
 * \param value Output value; value[0].v_handle receives the decoded handle.
 */
void ReadObject(int* tcode, TVMValue* value) {
  // NOTE: for now all remote objects are encoded as RPCObjectRef,
  // following the same protocol as disco in case we would like to upgrade later.
  //
  // Rationale: handling only remote objects allows the same mechanism to work for minRPC,
  // which is needed for wasm and other environments that go through the C API.
  uint32_t type_index;
  this->template Read<uint32_t>(&type_index);
  if (type_index == kRuntimeRPCObjectRefTypeIndex) {
    uint64_t handle;
    this->template Read<uint64_t>(&handle);
    tcode[0] = kTVMObjectHandle;
    value[0].v_handle = reinterpret_cast<void*>(handle);
  } else {
    LOG(FATAL) << "ValueError: Object type is not supported in RPC calling convention: "
               << Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")";
  }
}

void MessageDone() {
Expand Down
20 changes: 18 additions & 2 deletions src/runtime/rpc/rpc_local_session.cc
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <tvm/runtime/registry.h>

#include <memory>
#include <vector>

namespace tvm {
namespace runtime {
Expand Down Expand Up @@ -64,7 +65,8 @@ void LocalSession::EncodeReturn(TVMRetValue rv, const FEncodeReturn& encode_retu
ret_value_pack[2].v_handle = ret_value_pack[1].v_handle;
ret_tcode_pack[2] = kTVMOpaqueHandle;
encode_return(TVMArgs(ret_value_pack, ret_tcode_pack, 3));
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle ||
rv_tcode == kTVMObjectHandle) {
// MoveToCHost means rv no longer manages the object.
// return handle instead.
rv.MoveToCHost(&ret_value_pack[1], &ret_tcode_pack[1]);
Expand All @@ -88,7 +90,21 @@ void LocalSession::CallFunc(RPCSession::PackedFuncHandle func, const TVMValue* a
const FEncodeReturn& encode_return) {
PackedFuncObj* pf = static_cast<PackedFuncObj*>(func);
TVMRetValue rv;
pf->CallPacked(TVMArgs(arg_values, arg_type_codes, num_args), &rv);

// unwrap RPCObjectRef in case we are directly using it to call LocalSession
std::vector<TVMValue> values(arg_values, arg_values + num_args);
std::vector<int> type_codes(arg_type_codes, arg_type_codes + num_args);
TVMArgs args(arg_values, arg_type_codes, num_args);

for (int i = 0; i < num_args; ++i) {
if (args[i].IsObjectRef<RPCObjectRef>()) {
RPCObjectRef obj_ref = args[i];
values[i].v_handle = obj_ref->object_handle();
continue;
}
}

pf->CallPacked(TVMArgs(values.data(), type_codes.data(), args.size()), &rv);
this->EncodeReturn(std::move(rv), encode_return);
}

Expand Down
7 changes: 7 additions & 0 deletions src/runtime/rpc/rpc_module.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,6 +157,8 @@ class RPCWrappedFunc : public Object {
}
};

TVM_REGISTER_OBJECT_TYPE(RPCObjectRefObj);

// RPC that represents a remote module session.
class RPCModuleNode final : public ModuleNode {
public:
Expand Down Expand Up @@ -294,6 +296,11 @@ void RPCWrappedFunc::WrapRemoteReturnToValue(TVMArgs args, TVMRetValue* rv) cons
void* handle = args[1];
auto n = make_object<RPCModuleNode>(handle, sess_);
*rv = Module(n);
} else if (tcode == kTVMObjectHandle) {
ICHECK_EQ(args.size(), 2);
void* handle = args[1];
auto n = make_object<RPCObjectRefObj>(handle, sess_);
*rv = ObjectRef(n);
} else if (tcode == kTVMDLTensorHandle || tcode == kTVMNDArrayHandle) {
ICHECK_EQ(args.size(), 3);
DLTensor* tensor = args[1];
Expand Down
51 changes: 50 additions & 1 deletion src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ class RPCSession {

/*!
* \brief Free a remote function.
* \param handle The remote handle, can be NDArray/PackedFunc/Module
* \param handle The remote handle, can be NDArray/PackedFunc/Module/Object
* \param type_code The type code of the underlying type.
*/
virtual void FreeHandle(void* handle, int type_code) = 0;
Expand Down Expand Up @@ -287,6 +287,55 @@ struct RemoteSpace {
std::shared_ptr<RPCSession> sess;
};

/*!
* \brief Object wrapper that represents a reference to a remote object
*/
class RPCObjectRefObj : public Object {
public:
/*!
* \brief constructor
* \param object_handle handle that points to the remote object
* \param sess The remote session
*/
RPCObjectRefObj(void* object_handle, std::shared_ptr<RPCSession> sess)
: object_handle_(object_handle), sess_(sess) {}

~RPCObjectRefObj() {
if (object_handle_ != nullptr) {
try {
sess_->FreeHandle(object_handle_, kTVMObjectHandle);
} catch (const Error& e) {
// fault tolerance to remote close
}
object_handle_ = nullptr;
}
}

const std::shared_ptr<RPCSession>& sess() const { return sess_; }

void* object_handle() const { return object_handle_; }

static constexpr const uint32_t _type_index = TypeIndex::kRuntimeRPCObjectRef;
static constexpr const char* _type_key = "runtime.RPCObjectRef";
TVM_DECLARE_FINAL_OBJECT_INFO(RPCObjectRefObj, Object);

private:
// The object handle
void* object_handle_{nullptr};
// The local channel
std::shared_ptr<RPCSession> sess_;
};

/*!
 * \brief Managed reference to RPCObjectRefObj, the local handle to a remote object.
 * \sa RPCObjectRefObj
 * \note No public constructor is provided because this type is not supposed to be
 * created directly by users.
 */
class RPCObjectRef : public ObjectRef {
public:
// Declared as a mutable, non-nullable reference type over RPCObjectRefObj.
TVM_DEFINE_MUTABLE_NOTNULLABLE_OBJECT_REF_METHODS(RPCObjectRef, ObjectRef, RPCObjectRefObj);
};

/*!
* \brief Create a Global RPC module that refers to the session.
* \param sess The RPC session of the global module.
Expand Down
31 changes: 31 additions & 0 deletions tests/python/runtime/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,6 +426,7 @@ def test_rpc_return_ndarray():
ref_count = m("ref_count")
get_elem = m("get_elem")
get_arr_elem = m("get_arr_elem")

# array test
def run_arr_test():
arr = get_arr()
Expand All @@ -435,6 +436,36 @@ def run_arr_test():
run_arr_test()


@tvm.testing.requires_rpc
def test_rpc_return_remote_object():
    """A ShapeTuple built through an RPC client comes back as an RPCObjectRef
    that can be handed back to other remote functions."""

    def check(client, is_local):
        dims = (2, 3)
        shape_ctor = client.get_function("runtime.ShapeTuple")
        elem_at = client.get_function("runtime.GetShapeTupleElem")
        size_of = client.get_function("runtime.GetShapeTupleSize")
        remote_shape = shape_ctor(*dims)
        # The returned object is a remote reference, not a local ShapeTuple.
        assert remote_shape.type_key == "runtime.RPCObjectRef"
        # Passing the reference back lets the remote side read the real object.
        for index, expected in enumerate(dims):
            assert elem_at(remote_shape, index) == expected
        assert size_of(remote_shape) == len(dims)

    def check_minrpc():
        # Only runs when the pipe client is available in this build.
        if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None:
            return
        # Test minrpc server.
        workdir = utils.tempdir()
        minrpc_path = workdir.relpath("minrpc")
        tvm.rpc.with_minrpc(cc.create_executable)(minrpc_path, [])
        check(rpc.PopenSession(minrpc_path), False)

    # start server
    server = rpc.Server(key="x1")
    client = rpc.connect("127.0.0.1", server.port, key="x1")
    check(rpc.LocalSession(), True)
    check(client, False)

    check_minrpc()


@tvm.testing.requires_rpc
def test_local_func():
client = rpc.LocalSession()
Expand Down

0 comments on commit 4258c86

Please sign in to comment.