@@ -561,7 +561,11 @@ void TensorPipeAgent::pipeRead(
561
561
return ;
562
562
}
563
563
564
- std::vector<c10::Stream> streams = getStreamsFromPoolForDevices (devices_);
564
+ std::vector<c10::Stream> streams;
565
+ {
566
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
567
+ streams = getStreamsFromPoolForDevices (devices_);
568
+ }
565
569
tensorpipe::Allocation tpAllocation;
566
570
TensorpipeReadBuffers tpBuffers;
567
571
std::tie (tpAllocation, tpBuffers) =
@@ -641,24 +645,26 @@ void TensorPipeAgent::sendCompletedResponseMessage(
641
645
642
646
for (const auto & tensor : responseMessage->tensors ()) {
643
647
const auto device = tensor.device ();
644
- if (!device.is_cpu () &&
645
- std::find (devices_.begin (), devices_.end (), device) ==
646
- devices_.end ()) {
647
- std::ostringstream oss;
648
- std::copy (
649
- devices_.begin (),
650
- devices_.end (),
651
- std::ostream_iterator<c10::Device>(oss, " , " ));
652
- responseMessage = createExceptionResponse (
653
- c10::str (
654
- " RPC detected that a user-function output tensor on device " ,
655
- device,
656
- " . This device is not one of the input tensor devices: " ,
657
- oss.str (),
658
- " which is not yet supported. Please file a feature request "
659
- " issue in PyTorch GitHub repo." ),
660
- messageId);
661
- break ;
648
+ if (!device.is_cpu ()) {
649
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
650
+ if (std::find (devices_.begin (), devices_.end (), device) ==
651
+ devices_.end ()) {
652
+ std::ostringstream oss;
653
+ std::copy (
654
+ devices_.begin (),
655
+ devices_.end (),
656
+ std::ostream_iterator<c10::Device>(oss, " , " ));
657
+ responseMessage = createExceptionResponse (
658
+ c10::str (
659
+ " RPC detected that a user-function output tensor on device " ,
660
+ device,
661
+ " . This device is not one of the input tensor devices: " ,
662
+ oss.str (),
663
+ " which is not yet supported. Please file a feature request "
664
+ " issue in PyTorch GitHub repo." ),
665
+ messageId);
666
+ break ;
667
+ }
662
668
}
663
669
}
664
670
@@ -821,7 +827,12 @@ c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
821
827
}
822
828
ClientPipe& clientPipe = it->second ;
823
829
824
- auto futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
830
+ std::shared_ptr<torch::distributed::rpc::TensorPipeAgent::AtomicJitFuture>
831
+ futureResponseMessage;
832
+ {
833
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
834
+ futureResponseMessage = std::make_shared<AtomicJitFuture>(devices_);
835
+ }
825
836
uint64_t messageId = nextMessageID_++;
826
837
requestMessage->setId (messageId);
827
838
@@ -881,7 +892,11 @@ c10::intrusive_ptr<JitFuture> TensorPipeAgent::send(
881
892
VLOG (1 ) << " RPC agent for " << workerInfo_.name_ << " is sending request #"
882
893
<< messageId << " to " << clientPipe.pipe_ ->getRemoteName ();
883
894
884
- std::vector<c10::Stream> streams = getStreamsFromPoolForDevices (devices_);
895
+ std::vector<c10::Stream> streams;
896
+ {
897
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
898
+ streams = getStreamsFromPoolForDevices (devices_);
899
+ }
885
900
makeStreamsWaitOnOthers (
886
901
streams,
887
902
getCurrentStreamsForDevices (
@@ -1133,14 +1148,22 @@ void TensorPipeAgent::shutdownImpl() {
1133
1148
1134
1149
// Looks up the WorkerInfo registered under `workerName`.
//
// The membership guard is held across the lookup, the end() comparison and
// the dereference. updateGroupMembership() inserts into workerNameToInfo_
// under the same guard, and an insertion into a std::unordered_map may rehash
// and invalidate outstanding iterators, so the iterator must not be touched
// once the guard has been released. The returned reference remains usable
// after later insertions, because rehashing an unordered_map does not
// relocate its elements.
//
// TORCH_CHECK fails when no worker with that name is registered.
const WorkerInfo& TensorPipeAgent::getWorkerInfo(
    const std::string& workerName) const {
  GroupMembershipLockGuard guard(groupMembershipMutex_, isStaticGroup_);
  const auto it = workerNameToInfo_.find(workerName);
  TORCH_CHECK(
      it != workerNameToInfo_.end(), " Unknown destination worker ", workerName);
  return it->second;
}
1141
1160
1142
1161
const WorkerInfo& TensorPipeAgent::getWorkerInfo (worker_id_t workerId) const {
1143
- const auto & it = workerIdToInfo_.find (workerId);
1162
+ std::unordered_map<worker_id_t , WorkerInfo>::const_iterator it;
1163
+ {
1164
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
1165
+ it = workerIdToInfo_.find (workerId);
1166
+ }
1144
1167
TORCH_CHECK (
1145
1168
it != workerIdToInfo_.end (), " Unknown destination worker " , workerId);
1146
1169
return it->second ;
@@ -1156,12 +1179,53 @@ std::vector<WorkerInfo> TensorPipeAgent::getWorkerInfos() const {
1156
1179
1157
1180
const std::string& TensorPipeAgent::findWorkerURL (
1158
1181
const WorkerInfo& worker) const {
1159
- const auto it = workerNameToURL_.find (worker.name_ );
1182
+ std::unordered_map<std::string, std::string>::const_iterator it;
1183
+ {
1184
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
1185
+ it = workerNameToURL_.find (worker.name_ );
1186
+ }
1160
1187
TORCH_CHECK (
1161
1188
it != workerNameToURL_.end (), " Unknown worker name: " , worker.name_ );
1162
1189
return it->second ;
1163
1190
}
1164
1191
1192
+ void TensorPipeAgent::updateGroupMembership (
1193
+ const WorkerInfo& workerInfo,
1194
+ const std::vector<c10::Device> devices,
1195
+ const std::unordered_map<std::string, DeviceMap> reverseDeviceMaps,
1196
+ bool isJoin = true ) {
1197
+ std::string name = workerInfo.name_ ;
1198
+ worker_id_t id = workerInfo.id_ ;
1199
+ // Rank with workerInfo is joining the group, update internal mappings
1200
+ if (isJoin) {
1201
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
1202
+ workerIdToInfo_.emplace (id, workerInfo);
1203
+ workerNameToInfo_.emplace (name, workerInfo);
1204
+
1205
+ // TODO: we should get nodeAddrStr in the joining process, then pass in as
1206
+ // an argument rather than getting from store each time
1207
+ auto nodeAddrData = nameToAddressStore_.get (name);
1208
+ auto nodeAddrStr =
1209
+ std::string ((const char *)nodeAddrData.data (), nodeAddrData.size ());
1210
+ workerNameToURL_.insert ({name, nodeAddrStr});
1211
+
1212
+ for (const auto & it : reverseDeviceMaps) {
1213
+ if (reverseDeviceMaps_.find (it.first ) == reverseDeviceMaps_.end ()) {
1214
+ reverseDeviceMaps_[it.first ] = it.second ;
1215
+ }
1216
+ }
1217
+ // TODO: clean up mutex for devices_ usage
1218
+ // Add devices that have not been added yet
1219
+ for (const auto & it : devices) {
1220
+ if (std::find (devices_.begin (), devices_.end (), it) == devices_.end ()) {
1221
+ devices_.push_back (it);
1222
+ }
1223
+ }
1224
+ }
1225
+ // TODO: Rank with workerInfo is leaving, update internal mappings
1226
+ else {
1227
+ }
1228
+ }
1165
1229
std::unordered_map<std::string, std::string> TensorPipeAgent::getMetrics () {
1166
1230
std::unordered_map<std::string, std::string> metrics;
1167
1231
metrics[kThreadPoolSize ] = c10::to_string (threadPool_.size ());
@@ -1289,8 +1353,11 @@ void TensorPipeAgent::markFutureWithError(
1289
1353
std::vector<c10::Device> TensorPipeAgent::getDevicesForRemote (
1290
1354
const std::string& remoteName,
1291
1355
const Message& message) const {
1292
- const auto & deviceMaps =
1293
- message.isRequest () ? opts_.deviceMaps : reverseDeviceMaps_;
1356
+ std::unordered_map<std::string, DeviceMap> deviceMaps;
1357
+ {
1358
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
1359
+ deviceMaps = message.isRequest () ? opts_.deviceMaps : reverseDeviceMaps_;
1360
+ }
1294
1361
1295
1362
const auto errStr = c10::str (
1296
1363
" TensorPipe RPC backend only supports CPU tensors by default, please "
@@ -1324,7 +1391,12 @@ DeviceMap TensorPipeAgent::getDeviceMap(const WorkerInfo& dst) const {
1324
1391
return it->second ;
1325
1392
}
1326
1393
1394
+ TensorPipeRpcBackendOptions TensorPipeAgent::getBackendOptions () const {
1395
+ return opts_;
1396
+ }
1397
+
1327
1398
const std::vector<c10::Device>& TensorPipeAgent::getDevices () const {
1399
+ GroupMembershipLockGuard guard (groupMembershipMutex_, isStaticGroup_);
1328
1400
return devices_;
1329
1401
}
1330
1402
0 commit comments