Skip to content

Commit

Permalink
[RUNTIME][RPC] Enable RPCObjectRef over multi-hop RPC (#16635)
Browse files Browse the repository at this point in the history
This PR enables RPCObjectRef over multi-hop RPC.
It is necessary to rewrap the argument as RPCObjectRef
so that the intermediate validation and re-encoding logic can
follow through.
  • Loading branch information
tqchen authored Feb 24, 2024
1 parent 8194b48 commit 7e269dc
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 8 deletions.
18 changes: 15 additions & 3 deletions src/runtime/rpc/rpc_endpoint.cc
Original file line number Diff line number Diff line change
Expand Up @@ -258,8 +258,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
if (type_index == kRuntimeRPCObjectRefTypeIndex) {
uint64_t handle;
this->template Read<uint64_t>(&handle);
tcode[0] = kTVMObjectHandle;
value[0].v_handle = reinterpret_cast<void*>(handle);
// Always wrap things back in RPCObjectRef
// this is because we want to enable multi-hop RPC
// and next hop would also need to check the object index
RPCObjectRef rpc_obj(make_object<RPCObjectRefObj>(reinterpret_cast<void*>(handle), nullptr));
TVMArgsSetter(value, tcode)(0, rpc_obj);
object_arena_.push_back(rpc_obj);
} else {
LOG(FATAL) << "ValueError: Object type is not supported in Disco calling convention: "
<< Object::TypeIndex2Key(type_index) << " (type_index = " << type_index << ")";
Expand All @@ -276,6 +280,12 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
return arena_.template allocate_<T>(count);
}

/*! \brief Recycle all the memory used in the arena */
void RecycleAll() {
this->object_arena_.clear();
this->arena_.RecycleAll();
}

protected:
enum State {
kInitHeader,
Expand All @@ -296,6 +306,8 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
bool async_server_mode_{false};
// Internal arena
support::Arena arena_;
// internal arena for temp objects
std::vector<ObjectRef> object_arena_;

// State switcher
void SwitchToState(State state) {
Expand All @@ -313,7 +325,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
if (state == kRecvPacketNumBytes) {
this->RequestBytes(sizeof(uint64_t));
// recycle arena for the next session.
arena_.RecycleAll();
this->RecycleAll();
}
}

Expand Down
7 changes: 5 additions & 2 deletions src/runtime/rpc/rpc_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,13 +295,16 @@ class RPCObjectRefObj : public Object {
/*!
* \brief constructor
* \param object_handle handle that points to the remote object
* \param sess The remote session
*
* \param sess The remote session, when session is nullptr
* it indicate the object is a temp object during rpc transmission
* and we don't have to free it
*/
RPCObjectRefObj(void* object_handle, std::shared_ptr<RPCSession> sess)
: object_handle_(object_handle), sess_(sess) {}

~RPCObjectRefObj() {
if (object_handle_ != nullptr) {
if (object_handle_ != nullptr && sess_ != nullptr) {
try {
sess_->FreeHandle(object_handle_, kTVMObjectHandle);
} catch (const Error& e) {
Expand Down
19 changes: 16 additions & 3 deletions tests/python/runtime/test_runtime_rpc.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,10 +449,15 @@ def check(client, is_local):
assert get_size(shape) == 2

# start server
server = rpc.Server(key="x1")
client = rpc.connect("127.0.0.1", server.port, key="x1")

check(rpc.LocalSession(), True)
check(client, False)

def check_remote():
server = rpc.Server(key="x1")
client = rpc.connect("127.0.0.1", server.port, key="x1")
check(client, False)

check_remote()

def check_minrpc():
if tvm.get_global_func("rpc.CreatePipeClient", allow_missing=True) is None:
Expand All @@ -462,6 +467,14 @@ def check_minrpc():
minrpc_exec = temp.relpath("minrpc")
tvm.rpc.with_minrpc(cc.create_executable)(minrpc_exec, [])
check(rpc.PopenSession(minrpc_exec), False)
# minrpc on the remote
server = rpc.Server()
client = rpc.connect(
"127.0.0.1",
server.port,
session_constructor_args=["rpc.PopenSession", open(minrpc_exec, "rb").read()],
)
check(client, False)

check_minrpc()

Expand Down

0 comments on commit 7e269dc

Please sign in to comment.