@@ -278,7 +278,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278 raise NotImplementedError ("not implemented" )
279279
280280 def configure (
281- self , store_addr : str , replica_id : str , rank : int , world_size : int
281+ self ,
282+ store_addr : str ,
283+ replica_id : str ,
284+ rank : int ,
285+ world_size : int ,
286+ quorum_id : Optional [int ] = None ,
287+ group_rank : Optional [int ] = None ,
288+ group_world_size : Optional [int ] = None ,
289+ global_ranks : Optional [list [int ]] = None ,
282290 ) -> None :
283291 """
284292 This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -294,6 +302,10 @@ def configure(
294302 replica_id: the replica_id for this group
295303 rank: rank of this process
296304 world_size: world size of this process group
305+ quorum_id: current quorum's identifier
306+ group_rank: local rank within the replica group
307+ group_world_size: the number of ranks within a replica
308+ global_ranks: the global ranks part of this process group
297309 """
298310 raise NotImplementedError ("not implemented" )
299311
@@ -408,6 +420,10 @@ def __init__(
408420 self ._timeout = timeout
409421 self ._replica_id : str | None = None
410422 self ._rank : int | None = None
423+ self ._quorum_id : int | None = None
424+ self ._group_rank : int | None = None
425+ self ._group_world_size : int | None = None
426+ self ._global_ranks : list [int ] | None = None
411427
412428 self .errors_logger : logging .Logger = logging .getLogger ("torchft_errors" )
413429
@@ -419,13 +435,34 @@ def getBackendName(self) -> str:
419435 raise NotImplementedError ("not implemented" )
420436
421437 def configure (
422- self , store_addr : str , replica_id : str , rank : int , world_size : int
438+ self ,
439+ store_addr : str ,
440+ replica_id : str ,
441+ rank : int ,
442+ world_size : int ,
443+ quorum_id : Optional [int ] = None ,
444+ group_rank : Optional [int ] = None ,
445+ group_world_size : Optional [int ] = None ,
446+ global_ranks : Optional [list [int ]] = None ,
423447 ) -> None :
424448 pg = self ._pg
425449 self ._replica_id = replica_id
450+ self ._quorum_id = quorum_id
451+ self ._group_rank = group_rank
452+ self ._group_world_size = group_world_size
426453 self ._rank = rank
454+ self ._global_ranks = global_ranks
427455 if isinstance (pg , ProcessGroup ):
428- pg .configure (store_addr , replica_id , rank , world_size )
456+ pg .configure (
457+ store_addr ,
458+ replica_id ,
459+ rank ,
460+ world_size ,
461+ quorum_id ,
462+ group_rank ,
463+ group_world_size ,
464+ global_ranks ,
465+ )
429466 return
430467
431468 # abort if already initialized
@@ -443,6 +480,7 @@ def abort(self, errored: bool = True) -> None:
443480 "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444481 "replica_id" : self ._replica_id ,
445482 "rank" : self ._rank ,
483+ "quorum_id" : self ._quorum_id ,
446484 "error" : "process_group_abort" ,
447485 },
448486 )
@@ -615,6 +653,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615653 # pyre-fixme[16]: no attribute ProcessGroupGloo
616654 backend_class = BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617655 backend_class ._set_sequence_number_for_group ()
656+
657+ if self ._global_ranks :
658+ backend_class .options .global_ranks_in_group = self ._global_ranks
659+ if self ._group_rank is not None and self ._group_world_size : # group rank 0 is valid, so compare against None instead of relying on truthiness
660+ backend_class .options .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
661+
618662 pg ._register_backend (
619663 torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class
620664 )
@@ -812,7 +856,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
812856 # pyre-fixme[16]: no attribute ProcessGroupNCCL
813857 opts = BaseProcessGroupNCCL .Options ()
814858 opts .config .blocking = False
815- opts .global_ranks_in_group = list (range (world_size ))
859+ if self ._global_ranks :
860+ opts .global_ranks_in_group = self ._global_ranks
861+ if self ._group_rank is not None and self ._group_world_size : # group rank 0 is valid, so compare against None instead of relying on truthiness
862+ opts .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
816863
817864 pg = BaseProcessGroup (store , rank , world_size )
818865 pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +1026,15 @@ def __init__(self, rank: int, world: int) -> None:
9791026 self .configure_count = 0
9801027
9811028 def configure (
982- self , store_addr : str , replica_id : str , rank : int , world_size : int
1029+ self ,
1030+ store_addr : str ,
1031+ replica_id : str ,
1032+ rank : int ,
1033+ world_size : int ,
1034+ quorum_id : Optional [int ] = None ,
1035+ group_rank : Optional [int ] = None ,
1036+ group_world_size : Optional [int ] = None ,
1037+ global_ranks : Optional [list [int ]] = None ,
9831038 ) -> None :
9841039 self .configure_count += 1
9851040
@@ -1138,11 +1193,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11381193 self ._error : Optional [Exception ] = None
11391194
11401195 def configure (
1141- self , store_addr : str , replica_id : str , rank : int , world_size : int
1196+ self ,
1197+ store_addr : str ,
1198+ replica_id : str ,
1199+ rank : int ,
1200+ world_size : int ,
1201+ quorum_id : Optional [int ] = None ,
1202+ group_rank : Optional [int ] = None ,
1203+ group_world_size : Optional [int ] = None ,
1204+ global_ranks : Optional [list [int ]] = None ,
11421205 ) -> None :
11431206 self ._error = None
11441207
1145- super ().configure (store_addr , replica_id , rank , world_size )
1208+ super ().configure (
1209+ store_addr ,
1210+ replica_id ,
1211+ rank ,
1212+ world_size ,
1213+ quorum_id ,
1214+ group_rank ,
1215+ group_world_size ,
1216+ global_ranks ,
1217+ )
11461218
11471219 def report_error (self , e : Exception ) -> None :
11481220 """
@@ -1194,11 +1266,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11941266 self ._future_error : Optional [Exception ] = None
11951267
11961268 def configure (
1197- self , store_addr : str , replica_id : str , rank : int , world_size : int
1269+ self ,
1270+ store_addr : str ,
1271+ replica_id : str ,
1272+ rank : int ,
1273+ world_size : int ,
1274+ quorum_id : Optional [int ] = None ,
1275+ group_rank : Optional [int ] = None ,
1276+ group_world_size : Optional [int ] = None ,
1277+ global_ranks : Optional [list [int ]] = None ,
11981278 ) -> None :
11991279 self ._future_error = None
12001280
1201- super ().configure (store_addr , replica_id , rank , world_size )
1281+ super ().configure (
1282+ store_addr ,
1283+ replica_id ,
1284+ rank ,
1285+ world_size ,
1286+ quorum_id ,
1287+ group_rank ,
1288+ group_world_size ,
1289+ global_ranks ,
1290+ )
12021291
12031292 def report_future_error (self , e : Exception ) -> None :
12041293 """
@@ -1412,7 +1501,15 @@ def shutdown(self) -> None:
14121501 self ._p .kill ()
14131502
14141503 def configure (
1415- self , store_addr : str , replica_id : str , rank : int , world_size : int
1504+ self ,
1505+ store_addr : str ,
1506+ replica_id : str ,
1507+ rank : int ,
1508+ world_size : int ,
1509+ quorum_id : Optional [int ] = None ,
1510+ group_rank : Optional [int ] = None ,
1511+ group_world_size : Optional [int ] = None ,
1512+ global_ranks : Optional [list [int ]] = None ,
14161513 ) -> None :
14171514 self ._world_size = world_size
14181515
0 commit comments