Skip to content

Commit 9216cf8

Browse files
mehrdadhtrevor-m
authored andcommitted
[RPC] microtvm: fix RPC large transfer size issue (apache#7838)
* fix rpc for microtvm * apply feedbacks * bundle deploy fix * fix func registry size * mv constant * fix copyfromremote * address comments and fix error * change rpc default max size * Trigger Build * add checks * Trigger Build * fix ICHECK
1 parent d993927 commit 9216cf8

File tree

7 files changed

+137
-26
lines changed

7 files changed

+137
-26
lines changed

apps/bundle_deploy/crt_config/crt_config.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,4 +45,7 @@
4545
/*! Size of the global function registry, in bytes. */
4646
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 200
4747

48+
/*! Maximum packet size, in bytes, including the length header. */
49+
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 512
50+
4851
#endif // TVM_RUNTIME_CRT_CONFIG_H_

src/runtime/crt/common/crt_runtime_api.c

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -307,8 +307,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
307307
}
308308

309309
int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
310-
return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
311-
out);
310+
tvm_crt_error_t to_return =
311+
FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
312+
// For compatibility with the C++ runtime equivalent, in src/runtime/registry.cc.
313+
if (to_return == kTvmErrorFunctionNameNotFound) {
314+
*out = NULL;
315+
to_return = kTvmErrorNoError;
316+
}
317+
return to_return;
312318
}
313319

314320
int TVMModGetFunction(TVMModuleHandle mod, const char* func_name, int query_imports,
@@ -352,7 +358,6 @@ int ModuleGetFunction(TVMValue* args, int* type_codes, int num_args, TVMValue* r
352358
if (to_return == kTvmErrorFunctionNameNotFound) {
353359
to_return = kTvmErrorNoError;
354360
}
355-
356361
return to_return;
357362
}
358363

@@ -381,6 +386,17 @@ int TVMFuncFree(TVMFunctionHandle func) {
381386

382387
int RPCTimeEvaluator(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_val,
383388
int* ret_type_code);
389+
390+
// Sends CRT max packet size.
391+
int RPCGetCRTMaxPacketSize(TVMValue* args, int* type_codes, int num_args, TVMValue* ret_value,
392+
int* ret_type_codes) {
393+
// 11 bytes is for microtvm overhead:
394+
// packet start(2), length(4), session header(3), crc(2)
395+
ret_value[0].v_int64 = TVM_CRT_MAX_PACKET_SIZE_BYTES - 11;
396+
ret_type_codes[0] = kTVMArgInt;
397+
return 0;
398+
}
399+
384400
tvm_crt_error_t TVMInitializeRuntime() {
385401
int idx = 0;
386402
tvm_crt_error_t error = kTvmErrorNoError;
@@ -421,6 +437,10 @@ tvm_crt_error_t TVMInitializeRuntime() {
421437
error = TVMFuncRegisterGlobal("runtime.RPCTimeEvaluator", &RPCTimeEvaluator, 0);
422438
}
423439

440+
if (error == kTvmErrorNoError) {
441+
error = TVMFuncRegisterGlobal("tvm.rpc.server.GetCRTMaxPacketSize", &RPCGetCRTMaxPacketSize, 0);
442+
}
443+
424444
if (error != kTvmErrorNoError) {
425445
TVMPlatformMemoryFree(registry_backing_memory, dev);
426446
TVMPlatformMemoryFree(func_registry_memory, dev);

src/runtime/crt/host/crt_config.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,10 +43,10 @@
4343
#define TVM_CRT_MAX_REGISTERED_MODULES 2
4444

4545
/*! Size of the global function registry, in bytes. */
46-
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
46+
#define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 512
4747

4848
/*! Maximum packet size, in bytes, including the length header. */
49-
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
49+
#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024
5050

5151
/*! \brief Maximum length of a PackedFunc function name. */
5252
#define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30

src/runtime/crt/host/main.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,9 +136,12 @@ int main(int argc, char** argv) {
136136
"failed to register GraphExecutor TVMModule");
137137
#endif
138138

139-
if (TVMFuncRegisterGlobal("tvm.testing.reset_server", (TVMFunctionHandle)&testonly_reset_server,
140-
0)) {
141-
fprintf(stderr, "utvm runtime: internal error registering global packedfunc; exiting\n");
139+
int error = TVMFuncRegisterGlobal("tvm.testing.reset_server",
140+
(TVMFunctionHandle)&testonly_reset_server, 0);
141+
if (error) {
142+
fprintf(stderr,
143+
"utvm runtime: internal error (error#: %x) registering global packedfunc; exiting\n",
144+
error);
142145
return 2;
143146
}
144147

src/runtime/minrpc/rpc_reference.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ namespace runtime {
3030
/*! \brief The current RPC procotol version. */
3131
constexpr const char* kRPCProtocolVer = "0.8.0";
3232

33+
// When tvm.rpc.server.GetCRTMaxPacketSize global function is not registered.
34+
const uint64_t kRPCMaxTransferSizeBytesDefault = UINT64_MAX;
35+
3336
/*! \brief The RPC code */
3437
enum class RPCCode : int {
3538
kNone,

src/runtime/rpc/rpc_endpoint.cc

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
330330
}
331331

332332
/*!
333-
* \brief Recive incoming packed seq from the stream.
333+
* \brief Receive incoming packed seq from the stream.
334334
* \return The received argments.
335335
* \note The TVMArgs is available until we switchstate.
336336
*/
@@ -369,7 +369,6 @@ class RPCEndpoint::EventHandler : public dmlc::Stream {
369369
*/
370370
void HandleReturn(RPCCode code, RPCSession::FEncodeReturn setreturn) {
371371
TVMArgs args = RecvPackedSeq();
372-
373372
if (code == RPCCode::kException) {
374373
// switch to the state before sending exception.
375374
this->SwitchToState(kRecvPacketNumBytes);
@@ -802,14 +801,13 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
802801
std::lock_guard<std::mutex> lock(mutex_);
803802
RPCCode code = RPCCode::kCopyToRemote;
804803

805-
uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
806-
ICHECK_EQ(nbytes, num_data_bytes);
804+
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
805+
ICHECK_LE(to->byte_offset + nbytes, tensor_total_size_bytes)
806+
<< "CopyToRemote: overflow in tensor size: (byte_offset=" << to->byte_offset
807+
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";
807808

808-
uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
809-
uint64_t shape_bytes = to->ndim * sizeof(int64_t);
810-
uint64_t packet_nbytes = sizeof(code) + sizeof(to_data) + sizeof(to->device) + sizeof(to->ndim) +
811-
sizeof(to->dtype) + sizeof(to->byte_offset) + shape_bytes +
812-
sizeof(nbytes) + num_data_bytes;
809+
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(to, code, nbytes);
810+
uint64_t packet_nbytes = overhead + nbytes;
813811

814812
handler_->Write(packet_nbytes);
815813
handler_->Write(code);
@@ -823,14 +821,13 @@ void RPCEndpoint::CopyFromRemote(DLTensor* from, void* to_bytes, uint64_t nbytes
823821
std::lock_guard<std::mutex> lock(mutex_);
824822
RPCCode code = RPCCode::kCopyFromRemote;
825823

826-
uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*from));
827-
CHECK_EQ(nbytes, num_data_bytes);
824+
uint64_t tensor_total_size_bytes = static_cast<uint64_t>(GetDataSize(*from));
825+
ICHECK_LE(from->byte_offset + nbytes, tensor_total_size_bytes)
826+
<< "CopyFromRemote: overflow in tensor size: (byte_offset=" << from->byte_offset
827+
<< ", nbytes=" << nbytes << ", tensor_total_size=" << tensor_total_size_bytes << ")";
828828

829-
uint64_t from_data = reinterpret_cast<uint64_t>(from->data);
830-
uint64_t shape_bytes = from->ndim * sizeof(int64_t);
831-
uint64_t packet_nbytes = sizeof(code) + sizeof(from_data) + sizeof(from->device) +
832-
sizeof(from->ndim) + sizeof(from->dtype) + sizeof(from->byte_offset) +
833-
shape_bytes + sizeof(nbytes);
829+
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(from, code, nbytes);
830+
uint64_t packet_nbytes = overhead;
834831

835832
handler_->Write(packet_nbytes);
836833
handler_->Write(code);
@@ -1009,11 +1006,55 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
10091006
}
10101007

10111008
void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
1012-
endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
1009+
RPCCode code = RPCCode::kCopyToRemote;
1010+
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_to, code, nbytes);
1011+
uint64_t rpc_max_size = GetRPCMaxTransferSize();
1012+
ICHECK_GT(rpc_max_size, overhead) << "CopyToRemote: Invalid block size!";
1013+
const uint64_t block_size = rpc_max_size - overhead;
1014+
uint64_t block_count = 0;
1015+
const uint64_t num_blocks = nbytes / block_size;
1016+
void* from_bytes;
1017+
1018+
for (block_count = 0; block_count < num_blocks; block_count++) {
1019+
remote_to->byte_offset = block_count * block_size;
1020+
from_bytes = reinterpret_cast<void*>(
1021+
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
1022+
endpoint_->CopyToRemote(from_bytes, remote_to, block_size);
1023+
}
1024+
1025+
const uint64_t remainder_bytes = nbytes % block_size;
1026+
if (remainder_bytes != 0) {
1027+
remote_to->byte_offset = block_count * block_size;
1028+
from_bytes = reinterpret_cast<void*>(
1029+
(reinterpret_cast<uint8_t*>(local_from_bytes) + block_count * block_size));
1030+
endpoint_->CopyToRemote(from_bytes, remote_to, remainder_bytes);
1031+
}
10131032
}
10141033

10151034
void CopyFromRemote(DLTensor* remote_from, void* local_to_bytes, uint64_t nbytes) final {
1016-
endpoint_->CopyFromRemote(remote_from, local_to_bytes, nbytes);
1035+
RPCCode code = RPCCode::kCopyFromRemote;
1036+
uint64_t overhead = RemoteCopyCalculatePacketOverheadSize(remote_from, code, nbytes);
1037+
uint64_t rpc_max_size = GetRPCMaxTransferSize();
1038+
ICHECK_GT(rpc_max_size, overhead) << "CopyFromRemote: Invalid block size!";
1039+
const uint64_t block_size = rpc_max_size - overhead;
1040+
uint64_t block_count = 0;
1041+
const uint64_t num_blocks = nbytes / block_size;
1042+
void* to_bytes;
1043+
1044+
for (block_count = 0; block_count < num_blocks; block_count++) {
1045+
remote_from->byte_offset = block_count * block_size;
1046+
to_bytes = reinterpret_cast<void*>(
1047+
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
1048+
endpoint_->CopyFromRemote(remote_from, to_bytes, block_size);
1049+
}
1050+
1051+
const uint64_t remainder_bytes = nbytes % block_size;
1052+
if (remainder_bytes != 0) {
1053+
remote_from->byte_offset = block_count * block_size;
1054+
to_bytes = reinterpret_cast<void*>(
1055+
(reinterpret_cast<uint8_t*>(local_to_bytes) + block_count * block_size));
1056+
endpoint_->CopyFromRemote(remote_from, to_bytes, remainder_bytes);
1057+
}
10171058
}
10181059

10191060
void FreeHandle(void* handle, int type_code) final {
@@ -1082,12 +1123,43 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
10821123
bool IsLocalSession() const final { return false; }
10831124

10841125
private:
1126+
uint64_t GetRPCMaxTransferSize() {
1127+
if (rpc_chunk_max_size_bytes_ > 0) {
1128+
return (uint64_t)rpc_chunk_max_size_bytes_;
1129+
}
1130+
1131+
PackedFuncHandle rpc_func = GetFunction("tvm.rpc.server.GetCRTMaxPacketSize");
1132+
if (rpc_func == nullptr) {
1133+
rpc_chunk_max_size_bytes_ = (int64_t)kRPCMaxTransferSizeBytesDefault;
1134+
} else {
1135+
CallFunc(rpc_func, nullptr, nullptr, 0, [this](TVMArgs args) {
1136+
// Use args[1] as return value, args[0] is tcode
1137+
// Look at RPCWrappedFunc in src/runtime/rpc/rpc_module.cc
1138+
rpc_chunk_max_size_bytes_ = (int64_t)args[1];
1139+
ICHECK_GT(rpc_chunk_max_size_bytes_, 0)
1140+
<< "RPC max transfer size is <= 0! (remote value = " << rpc_chunk_max_size_bytes_
1141+
<< ")";
1142+
});
1143+
}
1144+
return (uint64_t)rpc_chunk_max_size_bytes_;
1145+
}
1146+
10851147
std::shared_ptr<RPCEndpoint> endpoint_;
1148+
int64_t rpc_chunk_max_size_bytes_ = -1;
10861149
};
10871150

10881151
std::shared_ptr<RPCSession> CreateClientSession(std::shared_ptr<RPCEndpoint> endpoint) {
10891152
return std::make_shared<RPCClientSession>(endpoint);
10901153
}
10911154

1155+
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes) {
1156+
uint64_t shape_bytes = tensor->ndim * sizeof(int64_t);
1157+
uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<uint8_t*>(tensor->data));
1158+
uint64_t overhead = sizeof(code) + sizeof(to_data) + sizeof(tensor->device) +
1159+
sizeof(tensor->ndim) + sizeof(tensor->dtype) + sizeof(tensor->byte_offset) +
1160+
shape_bytes + sizeof(nbytes);
1161+
return overhead;
1162+
}
1163+
10921164
} // namespace runtime
10931165
} // namespace tvm

src/runtime/rpc/rpc_endpoint.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,16 @@ template <typename... Args>
204204
inline TVMRetValue RPCEndpoint::SysCallRemote(RPCCode code, Args&&... args) {
205205
return syscall_remote_(static_cast<int>(code), std::forward<Args>(args)...);
206206
}
207+
208+
/*!
209+
* \brief Calculates overhead size of a CopyToRemote packet.
210+
* \param to DLTensor to copy.
211+
* \param code RPCCode for this transfer.
212+
* \param nbytes Number of bytes to transfer.
213+
* \return The remote-copy packet overhead size.
214+
*/
215+
uint64_t RemoteCopyCalculatePacketOverheadSize(DLTensor* tensor, RPCCode code, uint64_t nbytes);
216+
207217
} // namespace runtime
208218
} // namespace tvm
209219
#endif // TVM_RUNTIME_RPC_RPC_ENDPOINT_H_

0 commit comments

Comments
 (0)