Skip to content

Commit e467748

Browse files
authored
[CPP_RPC] allow user supplied work dir (#7670)
* [CPP_RPC] allow user supplied work dir * clang format
1 parent 431a7d6 commit e467748

File tree

5 files changed

+44
-27
lines changed

5 files changed

+44
-27
lines changed

apps/cpp_rpc/main.cc

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ static const string kUsage =
5555
"--tracker - The RPC tracker address in host:port format e.g. 10.1.1.2:9190 Default=\"\"\n"
5656
"--key - The key used to identify the device type in tracker. Default=\"\"\n"
5757
"--custom-addr - Custom IP Address to Report to RPC Tracker. Default=\"\"\n"
58+
"--work-dir - Custom work directory. Default=\"\"\n"
5859
"--silent - Whether to run in silent mode. Default=False\n"
5960
"\n"
6061
" Example\n"
@@ -70,6 +71,7 @@ static const string kUsage =
7071
* \arg tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
7172
* \arg key The key used to identify the device type in tracker. Default=""
7273
* \arg custom_addr Custom IP Address to Report to RPC Tracker. Default=""
74+
* \arg work_dir Custom work directory. Default=""
7375
* \arg silent Whether run in silent mode. Default=False
7476
*/
7577
struct RpcServerArgs {
@@ -79,6 +81,7 @@ struct RpcServerArgs {
7981
string tracker;
8082
string key;
8183
string custom_addr;
84+
string work_dir;
8285
bool silent = false;
8386
#if defined(WIN32)
8487
std::string mmap_path;
@@ -96,6 +99,7 @@ void PrintArgs(const RpcServerArgs& args) {
9699
LOG(INFO) << "tracker = " << args.tracker;
97100
LOG(INFO) << "key = " << args.key;
98101
LOG(INFO) << "custom_addr = " << args.custom_addr;
102+
LOG(INFO) << "work_dir = " << args.work_dir;
99103
LOG(INFO) << "silent = " << ((args.silent) ? ("True") : ("False"));
100104
}
101105

@@ -238,6 +242,10 @@ void ParseCmdArgs(int argc, char* argv[], struct RpcServerArgs& args) {
238242
dmlc::InitLogging("--minloglevel=0");
239243
}
240244
#endif
245+
const string work_dir = GetCmdOption(argc, argv, "--work-dir=");
246+
if (!work_dir.empty()) {
247+
args.work_dir = work_dir;
248+
}
241249
}
242250

243251
/*!
@@ -274,7 +282,7 @@ int RpcServer(int argc, char* argv[]) {
274282
#endif
275283

276284
RPCServerCreate(args.host, args.port, args.port_end, args.tracker, args.key, args.custom_addr,
277-
args.silent);
285+
args.work_dir, args.silent);
278286
return 0;
279287
}
280288

apps/cpp_rpc/rpc_env.cc

Lines changed: 20 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
3939
#include <iostream>
4040
#include <string>
4141
#include <vector>
42-
4342
#include "../../src/support/utils.h"
4443
#include "rpc_env.h"
4544

@@ -85,25 +84,31 @@ void CleanDir(const std::string& dirname);
8584
*/
8685
std::string BuildSharedLibrary(std::string file_in);
8786

88-
RPCEnv::RPCEnv() {
87+
RPCEnv::RPCEnv(const std::string& wd) {
88+
if (wd != "") {
89+
base_ = wd + "/.cache";
90+
mkdir(wd.c_str(), 0777);
91+
mkdir(base_.c_str(), 0777);
92+
} else {
8993
#if defined(ANDROID) || defined(__ANDROID__)
90-
char cwd[PATH_MAX];
91-
auto cmdline = fopen("/proc/self/cmdline", "r");
92-
fread(cwd, 1, sizeof(cwd), cmdline);
93-
fclose(cmdline);
94-
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
94+
char cwd[PATH_MAX];
95+
auto cmdline = fopen("/proc/self/cmdline", "r");
96+
fread(cwd, 1, sizeof(cwd), cmdline);
97+
fclose(cmdline);
98+
base_ = "/data/data/" + std::string(cwd) + "/cache/rpc";
9599
#elif !defined(_WIN32)
96-
char cwd[PATH_MAX];
97-
if (getcwd(cwd, sizeof(cwd))) {
98-
base_ = std::string(cwd) + "/rpc";
99-
} else {
100-
base_ = "./rpc";
101-
}
100+
char cwd[PATH_MAX];
101+
if (getcwd(cwd, sizeof(cwd))) {
102+
base_ = std::string(cwd) + "/rpc";
103+
} else {
104+
base_ = "./rpc";
105+
}
102106
#else
103-
base_ = "./rpc";
107+
base_ = "./rpc";
104108
#endif
109+
mkdir(base_.c_str(), 0777);
110+
}
105111

106-
mkdir(base_.c_str(), 0777);
107112
TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([this](TVMArgs args, TVMRetValue* rv) {
108113
*rv = this->GetPath(args[0]);
109114
});

apps/cpp_rpc/rpc_env.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ struct RPCEnv {
3939
/*!
4040
* \brief Constructor Init The RPC Environment initialize function
4141
*/
42-
RPCEnv();
42+
RPCEnv(const std::string& word_dir = "");
4343
/*!
4444
* \brief GetPath To get the workpath from packed function
4545
* \param name The file name

apps/cpp_rpc/rpc_server.cc

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -98,14 +98,15 @@ class RPCServer {
9898
* \brief Constructor.
9999
*/
100100
RPCServer(std::string host, int port, int port_end, std::string tracker_addr, std::string key,
101-
std::string custom_addr)
101+
std::string custom_addr, std::string work_dir)
102102
: host_(std::move(host)),
103103
port_(port),
104104
my_port_(0),
105105
port_end_(port_end),
106106
tracker_addr_(std::move(tracker_addr)),
107107
key_(std::move(key)),
108-
custom_addr_(std::move(custom_addr)) {}
108+
custom_addr_(std::move(custom_addr)),
109+
work_dir_(std::move(work_dir)) {}
109110

110111
/*!
111112
* \brief Destructor.
@@ -174,7 +175,7 @@ class RPCServer {
174175
const pid_t worker_pid = fork();
175176
if (worker_pid == 0) {
176177
// Worker process
177-
ServerLoopProc(conn, addr);
178+
ServerLoopProc(conn, addr, work_dir_);
178179
_exit(0);
179180
}
180181

@@ -201,7 +202,7 @@ class RPCServer {
201202
} else {
202203
auto pid = fork();
203204
if (pid == 0) {
204-
ServerLoopProc(conn, addr);
205+
ServerLoopProc(conn, addr, work_dir_);
205206
exit(0);
206207
}
207208
// Wait for the result
@@ -308,9 +309,10 @@ class RPCServer {
308309
* \param sock The socket information
309310
* \param addr The socket address information
310311
*/
311-
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
312+
static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr,
313+
std::string work_dir) {
312314
// Server loop
313-
const auto env = RPCEnv();
315+
const auto env = RPCEnv(work_dir);
314316
RPCServerLoop(int(sock.sockfd));
315317
LOG(INFO) << "Finish serving " << addr.AsString();
316318
env.CleanUp();
@@ -339,6 +341,7 @@ class RPCServer {
339341
std::string tracker_addr_;
340342
std::string key_;
341343
std::string custom_addr_;
344+
std::string work_dir_;
342345
support::TCPSocket listen_sock_;
343346
support::TCPSocket tracker_sock_;
344347
};
@@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
370373
* silent mode. Default=True
371374
*/
372375
void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
373-
std::string key, std::string custom_addr, bool silent) {
376+
std::string key, std::string custom_addr, std::string work_dir, bool silent) {
374377
if (silent) {
375378
// Only errors and fatal is logged
376379
dmlc::InitLogging("--minloglevel=2");
377380
}
378381
// Start the rpc server
379382
RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key),
380-
std::move(custom_addr));
383+
std::move(custom_addr), std::move(work_dir));
381384
rpc.Start();
382385
}
383386

384387
TVM_REGISTER_GLOBAL("rpc.ServerCreate").set_body([](TVMArgs args, TVMRetValue* rv) {
385-
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6]);
388+
RPCServerCreate(args[0], args[1], args[2], args[3], args[4], args[5], args[6], args[7]);
386389
});
387390
} // namespace runtime
388391
} // namespace tvm

apps/cpp_rpc/rpc_server.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,11 +48,12 @@ void ServerLoopFromChild(SOCKET socket);
4848
* \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
4949
* \param key The key used to identify the device type in tracker. Default=""
5050
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
51+
* \param work_dir Custom work directory. Default=""
5152
* \param silent Whether run in silent mode. Default=True
5253
*/
5354
void RPCServerCreate(std::string host = "", int port = 9090, int port_end = 9099,
5455
std::string tracker_addr = "", std::string key = "",
55-
std::string custom_addr = "", bool silent = true);
56+
std::string custom_addr = "", std::string work_dir = "", bool silent = true);
5657
} // namespace runtime
5758
} // namespace tvm
5859
#endif // TVM_APPS_CPP_RPC_SERVER_H_

0 commit comments

Comments
 (0)