
Commit 8646e0d

H-Huang authored and pytorchmergebot committed
[Dynamic RPC] Allow existing ranks to communicate with newly joined ranks
Pull Request resolved: pytorch#74035
Approved by: https://github.com/mrshenli
1 parent 690bc1c commit 8646e0d

File tree

6 files changed: +309 -77 lines changed


torch/_C/_distributed_rpc.pyi

Lines changed: 5 additions & 0 deletions
@@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Tuple, Union, overload
 from datetime import timedelta
 import enum
 import torch
+from torch.types import Device
 from . import Future
 from ._autograd import ProfilerConfig, ProfilerState, ProfilerEvent
 from ._distributed_c10d import ProcessGroup, Store
@@ -43,6 +44,8 @@ class RpcAgent:
     def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
     def get_debug_info(self) -> Dict[str, str]: ...
     def get_metrics(self) -> Dict[str, str]: ...
+    def _update_group_membership(self, worker_info: WorkerInfo, my_devices: List[torch.device], reverse_device_map: Dict[str, Dict[torch.device, torch.device]], is_join: bool): ...
+    def _get_backend_options(self): ...

 class PyRRef:
     def __init__(self, value: Any, type_hint: Any = None): ...
@@ -100,6 +103,8 @@ class TensorPipeAgent(RpcAgent):
     def get_worker_info(self, id: int) -> WorkerInfo: ...
     def get_worker_infos(self) -> List[WorkerInfo]: ...
     def _get_device_map(self, dst: WorkerInfo) -> Dict[torch.device, torch.device]: ...
+    def _update_group_membership(self): ...
+    def _get_backend_options(self): ...

 def _is_current_rpc_agent_set() -> bool: ...
 def _get_current_rpc_agent()-> RpcAgent: ...

torch/csrc/distributed/rpc/init.cpp

Lines changed: 8 additions & 0 deletions
@@ -632,6 +632,14 @@ PyObject* rpc_init(PyObject* _unused, PyObject* noargs) {
               "_get_device_map",
               (DeviceMap(TensorPipeAgent::*)(const WorkerInfo& dst) const) &
                   TensorPipeAgent::getDeviceMap,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "_get_backend_options",
+              &TensorPipeAgent::getBackendOptions,
+              py::call_guard<py::gil_scoped_release>())
+          .def(
+              "_update_group_membership",
+              &TensorPipeAgent::updateGroupMembership,
               py::call_guard<py::gil_scoped_release>());

 #endif // USE_TENSORPIPE
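All three bindings use py::call_guard<py::gil_scoped_release>, so the Python GIL is dropped while the C++ call runs; this matters for agent methods that may block on a mutex or on network I/O. A minimal sketch of the pattern, using a hypothetical Agent class rather than the real TensorPipeAgent:

#include <pybind11/pybind11.h>
namespace py = pybind11;

struct Agent {
  // Stand-in for a method that may block (e.g. on a mutex or the network).
  void updateMembership() {}
};

PYBIND11_MODULE(example, m) {
  py::class_<Agent>(m, "Agent")
      .def(py::init<>())
      .def(
          "_update_membership",
          &Agent::updateMembership,
          // Release the GIL for the duration of the C++ call.
          py::call_guard<py::gil_scoped_release>());
}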

torch/csrc/distributed/rpc/tensorpipe_agent.cpp

Lines changed: 98 additions & 26 deletions
@@ -561,7 +561,11 @@ void TensorPipeAgent::pipeRead(
     return;
   }

-  std::vector<c10::Stream> streams = getStreamsFromPoolForDevices(devices_);
+  std::vector<c10::Stream> streams;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    streams = getStreamsFromPoolForDevices(devices_);
+  }
   tensorpipe::Allocation tpAllocation;
   TensorpipeReadBuffers tpBuffers;
   std::tie(tpAllocation, tpBuffers) =
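The GroupMembershipLockGuard used above is defined later in this diff (see tensorpipe_agent.h below); it takes the mutex only when a flag says so. A standalone sketch of that conditional-locking shape, with neutral names since this is an illustration rather than the real class:

#include <mutex>

// Locks the mutex only when shouldLock is true; the destructor
// unlocks only if the constructor locked.
struct ConditionalLockGuard {
  ConditionalLockGuard(std::mutex& m, bool shouldLock)
      : m_(m), shouldLock_(shouldLock) {
    if (shouldLock_) {
      m_.lock();
    }
  }
  ~ConditionalLockGuard() {
    if (shouldLock_) {
      m_.unlock();
    }
  }
  ConditionalLockGuard(const ConditionalLockGuard&) = delete;

 private:
  std::mutex& m_;
  bool shouldLock_;
};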
@@ -641,24 +645,26 @@ void TensorPipeAgent::sendCompletedResponseMessage(

   for (const auto& tensor : responseMessage->tensors()) {
     const auto device = tensor.device();
-    if (!device.is_cpu() &&
-        std::find(devices_.begin(), devices_.end(), device) ==
-            devices_.end()) {
-      std::ostringstream oss;
-      std::copy(
-          devices_.begin(),
-          devices_.end(),
-          std::ostream_iterator<c10::Device>(oss, ", "));
-      responseMessage = createExceptionResponse(
-          c10::str(
-              "RPC detected that a user-function output tensor on device ",
-              device,
-              ". This device is not one of the input tensor devices: ",
-              oss.str(),
-              "which is not yet supported. Please file a feature request "
-              "issue in PyTorch GitHub repo."),
-          messageId);
-      break;
+    if (!device.is_cpu()) {
+      GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+      if (std::find(devices_.begin(), devices_.end(), device) ==
+          devices_.end()) {
+        std::ostringstream oss;
+        std::copy(
+            devices_.begin(),
+            devices_.end(),
+            std::ostream_iterator<c10::Device>(oss, ", "));
+        responseMessage = createExceptionResponse(
+            c10::str(
+                "RPC detected that a user-function output tensor on device ",
+                device,
+                ". This device is not one of the input tensor devices: ",
+                oss.str(),
+                "which is not yet supported. Please file a feature request "
+                "issue in PyTorch GitHub repo."),
+            messageId);
+        break;
+      }
     }
   }

@@ -821,7 +827,12 @@ c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
   }
   ClientPipe& clientPipe = it->second;

-  auto futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
+  std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
+      futureResponseMessage;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
+  }
   uint64_t messageId = nextMessageID_++;
   requestMessage->setId(messageId);

@@ -881,7 +892,11 @@ c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
   VLOG(1) << "RPC agent for " << workerInfo_.name_ << " is sending request #"
           << messageId << " to " << clientPipe.pipe_->getRemoteName();

-  std::vector<c10::Stream> streams = getStreamsFromPoolForDevices(devices_);
+  std::vector<c10::Stream> streams;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    streams = getStreamsFromPoolForDevices(devices_);
+  }
   makeStreamsWaitOnOthers(
       streams,
       getCurrentStreamsForDevices(
@@ -1133,14 +1148,22 @@ void TensorPipeAgent::shutdownImpl() {

 const WorkerInfo& TensorPipeAgent::getWorkerInfo(
     const std::string& workerName) const {
-  const auto& it = workerNameToInfo_.find(workerName);
+  std::unordered_map<std::string, WorkerInfo>::const_iterator it;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    it = workerNameToInfo_.find(workerName);
+  }
   TORCH_CHECK(
       it != workerNameToInfo_.end(), "Unknown destination worker ", workerName);
   return it->second;
 }

 const WorkerInfo& TensorPipeAgent::getWorkerInfo(worker_id_t workerId) const {
-  const auto& it = workerIdToInfo_.find(workerId);
+  std::unordered_map<worker_id_t, WorkerInfo>::const_iterator it;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    it = workerIdToInfo_.find(workerId);
+  }
   TORCH_CHECK(
       it != workerIdToInfo_.end(), "Unknown destination worker ", workerId);
   return it->second;
@@ -1156,12 +1179,53 @@ std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {

 const std::string& TensorPipeAgent::findWorkerURL(
     const WorkerInfo& worker) const {
-  const auto it = workerNameToURL_.find(worker.name_);
+  std::unordered_map<std::string, std::string>::const_iterator it;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    it = workerNameToURL_.find(worker.name_);
+  }
   TORCH_CHECK(
       it != workerNameToURL_.end(), "Unknown worker name: ", worker.name_);
   return it->second;
 }

+void TensorPipeAgent::updateGroupMembership(
+    const WorkerInfo& workerInfo,
+    const std::vector<c10::Device> devices,
+    const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
+    bool isJoin = true) {
+  std::string name = workerInfo.name_;
+  worker_id_t id = workerInfo.id_;
+  // Rank with workerInfo is joining the group, update internal mappings
+  if (isJoin) {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    workerIdToInfo_.emplace(id, workerInfo);
+    workerNameToInfo_.emplace(name, workerInfo);
+
+    // TODO: we should get nodeAddrStr in the joining process, then pass in as
+    // an argument rather than getting from store each time
+    auto nodeAddrData = nameToAddressStore_.get(name);
+    auto nodeAddrStr =
+        std::string((const char*)nodeAddrData.data(), nodeAddrData.size());
+    workerNameToURL_.insert({name, nodeAddrStr});
+
+    for (const auto& it : reverseDeviceMaps) {
+      if (reverseDeviceMaps_.find(it.first) == reverseDeviceMaps_.end()) {
+        reverseDeviceMaps_[it.first] = it.second;
+      }
+    }
+    // TODO: clean up mutex for devices_ usage
+    // Add devices that have not been added yet
+    for (const auto& it : devices) {
+      if (std::find(devices_.begin(), devices_.end(), it) == devices_.end()) {
+        devices_.push_back(it);
+      }
+    }
+  }
+  // TODO: Rank with workerInfo is leaving, update internal mappings
+  else {
+  }
+}
 std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics() {
   std::unordered_map<std::string, std::string> metrics;
   metrics[kThreadPoolSize] = c10::to_string(threadPool_.size());
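Taken together, a join updates four pieces of agent state under the group-membership lock: the id-to-info and name-to-info maps, the name-to-URL map (fetched from the store), and the merged reverse device maps plus the local device list. A hypothetical call site, assuming a WorkerInfo(name, id) constructor and an agent pointer; this is not code from the commit:

// Hypothetical example: tell an existing agent about a newly joined rank.
WorkerInfo newWorker("worker2", /*id=*/2);
agent->updateGroupMembership(
    newWorker,
    /*devices=*/{c10::Device(c10::kCUDA, 0)},
    /*reverseDeviceMaps=*/{},
    /*isJoin=*/true);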
@@ -1289,8 +1353,11 @@ void TensorPipeAgent::markFutureWithError(
 std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote(
     const std::string& remoteName,
     const Message& message) const {
-  const auto& deviceMaps =
-      message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
+  std::unordered_map<std::string, DeviceMap> deviceMaps;
+  {
+    GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
+    deviceMaps = message.isRequest() ? opts_.deviceMaps : reverseDeviceMaps_;
+  }

   const auto errStr = c10::str(
       "TensorPipe RPC backend only supports CPU tensors by default, please "
@@ -1324,7 +1391,12 @@ DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
   return it->second;
 }

+TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions() const {
+  return opts_;
+}
+
 const std::vector<c10::Device>& TensorPipeAgent::getDevices() const {
+  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
   return devices_;
 }

torch/csrc/distributed/rpc/tensorpipe_agent.h

Lines changed: 36 additions & 2 deletions
@@ -192,11 +192,18 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
   const WorkerInfo& getWorkerInfo(const std::string& workerName) const override;
   const WorkerInfo& getWorkerInfo(worker_id_t workerId) const override;
   std::vector<WorkerInfo> getWorkerInfos() const override;
+  void updateGroupMembership(
+      const WorkerInfo& workerInfo,
+      const std::vector<c10::Device> devices,
+      const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
+      bool isJoin);

   std::unordered_map<std::string, std::string> getMetrics() override;

   void addGilWaitTime(const std::chrono::microseconds gilWaitTime) override;

+  TensorPipeRpcBackendOptions getBackendOptions() const;
+
   DeviceMap getDeviceMap(const WorkerInfo& dest) const override;

   const std::vector<c10::Device>& getDevices() const override;
@@ -311,11 +318,13 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
   };

   const TensorPipeRpcBackendOptions opts_;
-  const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
+  // For dynamic RPC, the reverse device maps are updated whenever a new rank
+  // joins or leaves the group
+  std::unordered_map<std::string, DeviceMap> reverseDeviceMaps_;
   // Local devices used by this agent. If application didn't specify this
   // field, it will be initialized using corresponding local devices in
   // opts_.deviceMaps and reverseDeviceMaps_;
-  const std::vector<c10::Device> devices_;
+  std::vector<c10::Device> devices_;

   ThreadPool threadPool_;
   std::shared_ptr<tensorpipe::Context> context_;
@@ -414,6 +423,31 @@ class TORCH_API TensorPipeAgent : public RpcAgent {
   // Mutex to guard timeSeriesMetrics_
   std::mutex metricsMutex_;

+  // Custom lock guard used to check if the RPC group is dynamic and lock the
+  // mutex if so
+  struct GroupMembershipLockGuard {
+    GroupMembershipLockGuard(std::mutex& mutex, bool isStaticGroup)
+        : ref_(mutex), isStaticGroup_(isStaticGroup) {
+      if (isStaticGroup_) {
+        ref_.lock();
+      }
+    }
+
+    ~GroupMembershipLockGuard() {
+      if (isStaticGroup_) {
+        ref_.unlock();
+      }
+    }
+
+   private:
+    GroupMembershipLockGuard(const GroupMembershipLockGuard&);
+    std::mutex& ref_;
+    bool isStaticGroup_;
+  };
+  // Mutex to guard access to group membership data
+  // e.g. updates to (workerIdToInfo_, workerNameToInfo_, workerNameToURL_)
+  mutable std::mutex groupMembershipMutex_;
+
   // Map to Track Network Data
   NetworkDataDict networkData_;
   // Mutex to guard networkData_
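Note the private, unimplemented copy constructor: it makes the guard non-copyable, the pre-C++11 idiom for what `= delete` expresses today. Call sites throughout tensorpipe_agent.cpp use the guard in a narrow block so the mutex is held only while the shared state is read; a sketch of that usage, mirroring the call sites above:

std::vector<c10::Device> devicesSnapshot;
{
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  devicesSnapshot = devices_;  // copy shared state while (conditionally) locked
}
// work with devicesSnapshot without holding the mutex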
