@@ -98,14 +98,15 @@ class RPCServer {
98
98
* \brief Constructor.
99
99
*/
100
100
RPCServer (std::string host, int port, int port_end, std::string tracker_addr, std::string key,
101
- std::string custom_addr)
101
+ std::string custom_addr, std::string work_dir )
102
102
: host_(std::move(host)),
103
103
port_ (port),
104
104
my_port_(0 ),
105
105
port_end_(port_end),
106
106
tracker_addr_(std::move(tracker_addr)),
107
107
key_(std::move(key)),
108
- custom_addr_(std::move(custom_addr)) {}
108
+ custom_addr_(std::move(custom_addr)),
109
+ work_dir_(std::move(work_dir)) {}
109
110
110
111
/* !
111
112
* \brief Destructor.
@@ -174,7 +175,7 @@ class RPCServer {
174
175
const pid_t worker_pid = fork ();
175
176
if (worker_pid == 0 ) {
176
177
// Worker process
177
- ServerLoopProc (conn, addr);
178
+ ServerLoopProc (conn, addr, work_dir_ );
178
179
_exit (0 );
179
180
}
180
181
@@ -201,7 +202,7 @@ class RPCServer {
201
202
} else {
202
203
auto pid = fork ();
203
204
if (pid == 0 ) {
204
- ServerLoopProc (conn, addr);
205
+ ServerLoopProc (conn, addr, work_dir_ );
205
206
exit (0 );
206
207
}
207
208
// Wait for the result
@@ -308,9 +309,10 @@ class RPCServer {
308
309
* \param sock The socket information
309
310
* \param addr The socket address information
310
311
*/
311
- static void ServerLoopProc (support::TCPSocket sock, support::SockAddr addr) {
312
+ static void ServerLoopProc (support::TCPSocket sock, support::SockAddr addr,
313
+ std::string work_dir) {
312
314
// Server loop
313
- const auto env = RPCEnv ();
315
+ const auto env = RPCEnv (work_dir );
314
316
RPCServerLoop (int (sock.sockfd ));
315
317
LOG (INFO) << " Finish serving " << addr.AsString ();
316
318
env.CleanUp ();
@@ -339,6 +341,7 @@ class RPCServer {
339
341
std::string tracker_addr_;
340
342
std::string key_;
341
343
std::string custom_addr_;
344
+ std::string work_dir_;
342
345
support::TCPSocket listen_sock_;
343
346
support::TCPSocket tracker_sock_;
344
347
};
@@ -370,19 +373,19 @@ void ServerLoopFromChild(SOCKET socket) {
370
373
* silent mode. Default=True
371
374
*/
372
375
void RPCServerCreate (std::string host, int port, int port_end, std::string tracker_addr,
373
- std::string key, std::string custom_addr, bool silent) {
376
+ std::string key, std::string custom_addr, std::string work_dir, bool silent) {
374
377
if (silent) {
375
378
// Only errors and fatal is logged
376
379
dmlc::InitLogging (" --minloglevel=2" );
377
380
}
378
381
// Start the rpc server
379
382
RPCServer rpc (std::move (host), port, port_end, std::move (tracker_addr), std::move (key),
380
- std::move (custom_addr));
383
+ std::move (custom_addr), std::move (work_dir) );
381
384
rpc.Start ();
382
385
}
383
386
384
387
TVM_REGISTER_GLOBAL (" rpc.ServerCreate" ).set_body([](TVMArgs args, TVMRetValue* rv) {
385
- RPCServerCreate (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ], args[6 ]);
388
+ RPCServerCreate (args[0 ], args[1 ], args[2 ], args[3 ], args[4 ], args[5 ], args[6 ], args[ 7 ] );
386
389
});
387
390
} // namespace runtime
388
391
} // namespace tvm
0 commit comments