@@ -276,7 +276,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
276276 raise NotImplementedError ("not implemented" )
277277
278278 def configure (
279- self , store_addr : str , replica_id : str , rank : int , world_size : int
279+ self ,
280+ store_addr : str ,
281+ replica_id : str ,
282+ rank : int ,
283+ world_size : int ,
284+ quorum_id : Optional [int ] = None ,
285+ group_rank : Optional [int ] = None ,
286+ group_world_size : Optional [int ] = None ,
287+ global_ranks : Optional [list [int ]] = None ,
280288 ) -> None :
281289 """
282290 This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -292,6 +300,10 @@ def configure(
292300 replica_id: the replica_id for this group
293301 rank: rank of this process
294302 world_size: world size of this process group
303+ quorum_id: current quorum's identifier
304+ group_rank: local rank within the replica group
305+ group_world_size: the number of ranks within a replica
306+ global_ranks: the global ranks part of this process group
295307 """
296308 raise NotImplementedError ("not implemented" )
297309
@@ -406,6 +418,10 @@ def __init__(
406418 self ._timeout = timeout
407419 self ._replica_id : str | None = None
408420 self ._rank : int | None = None
421+ self ._quorum_id : int | None = None
422+ self ._group_rank : int | None = None
423+ self ._group_world_size : int | None = None
424+ self ._global_ranks : list [int ] | None = None
409425
410426 self .errors_logger : logging .Logger = logging .getLogger ("torchft_errors" )
411427
@@ -417,13 +433,34 @@ def getBackendName(self) -> str:
417433 raise NotImplementedError ("not implemented" )
418434
419435 def configure (
420- self , store_addr : str , replica_id : str , rank : int , world_size : int
436+ self ,
437+ store_addr : str ,
438+ replica_id : str ,
439+ rank : int ,
440+ world_size : int ,
441+ quorum_id : Optional [int ] = None ,
442+ group_rank : Optional [int ] = None ,
443+ group_world_size : Optional [int ] = None ,
444+ global_ranks : Optional [list [int ]] = None ,
421445 ) -> None :
422446 pg = self ._pg
423447 self ._replica_id = replica_id
448+ self ._quorum_id = quorum_id
449+ self ._group_rank = group_rank
450+ self ._group_world_size = group_world_size
424451 self ._rank = rank
452+ self ._global_ranks = global_ranks
425453 if isinstance (pg , ProcessGroup ):
426- pg .configure (store_addr , replica_id , rank , world_size )
454+ pg .configure (
455+ store_addr ,
456+ replica_id ,
457+ rank ,
458+ world_size ,
459+ quorum_id ,
460+ group_rank ,
461+ group_world_size ,
462+ global_ranks ,
463+ )
427464 return
428465
429466 # abort if already initialized
@@ -441,6 +478,7 @@ def abort(self, errored: bool = True) -> None:
441478 "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
442479 "replica_id" : self ._replica_id ,
443480 "rank" : self ._rank ,
481+ "quorum_id" : self ._quorum_id ,
444482 "error" : "process_group_abort" ,
445483 },
446484 )
@@ -613,6 +651,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
613651 # pyre-fixme[16]: no attribute ProcessGroupGloo
614652 backend_class = BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
615653 backend_class ._set_sequence_number_for_group ()
654+
655+ if self ._global_ranks :
656+ backend_class .options .global_ranks_in_group = self ._global_ranks
657+ if self ._group_rank and self ._group_world_size :
658+ backend_class .options .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
659+
616660 pg ._register_backend (
617661 torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class
618662 )
@@ -810,7 +854,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
810854 # pyre-fixme[16]: no attribute ProcessGroupNCCL
811855 opts = BaseProcessGroupNCCL .Options ()
812856 opts .config .blocking = False
813- opts .global_ranks_in_group = list (range (world_size ))
857+ if self ._global_ranks :
858+ opts .global_ranks_in_group = self ._global_ranks
859+ if self ._group_rank and self ._group_world_size :
860+ opts .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
814861
815862 pg = BaseProcessGroup (store , rank , world_size )
816863 pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -977,7 +1024,15 @@ def __init__(self, rank: int, world: int) -> None:
9771024 self .configure_count = 0
9781025
9791026 def configure (
980- self , store_addr : str , replica_id : str , rank : int , world_size : int
1027+ self ,
1028+ store_addr : str ,
1029+ replica_id : str ,
1030+ rank : int ,
1031+ world_size : int ,
1032+ quorum_id : Optional [int ] = None ,
1033+ group_rank : Optional [int ] = None ,
1034+ group_world_size : Optional [int ] = None ,
1035+ global_ranks : Optional [list [int ]] = None ,
9811036 ) -> None :
9821037 self .configure_count += 1
9831038
@@ -1136,11 +1191,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11361191 self ._error : Optional [Exception ] = None
11371192
11381193 def configure (
1139- self , store_addr : str , replica_id : str , rank : int , world_size : int
1194+ self ,
1195+ store_addr : str ,
1196+ replica_id : str ,
1197+ rank : int ,
1198+ world_size : int ,
1199+ quorum_id : Optional [int ] = None ,
1200+ group_rank : Optional [int ] = None ,
1201+ group_world_size : Optional [int ] = None ,
1202+ global_ranks : Optional [list [int ]] = None ,
11401203 ) -> None :
11411204 self ._error = None
11421205
1143- super ().configure (store_addr , replica_id , rank , world_size )
1206+ super ().configure (
1207+ store_addr ,
1208+ replica_id ,
1209+ rank ,
1210+ world_size ,
1211+ quorum_id ,
1212+ group_rank ,
1213+ group_world_size ,
1214+ global_ranks ,
1215+ )
11441216
11451217 def report_error (self , e : Exception ) -> None :
11461218 """
@@ -1192,11 +1264,28 @@ def __init__(self, pg: ProcessGroup) -> None:
11921264 self ._future_error : Optional [Exception ] = None
11931265
11941266 def configure (
1195- self , store_addr : str , replica_id : str , rank : int , world_size : int
1267+ self ,
1268+ store_addr : str ,
1269+ replica_id : str ,
1270+ rank : int ,
1271+ world_size : int ,
1272+ quorum_id : Optional [int ] = None ,
1273+ group_rank : Optional [int ] = None ,
1274+ group_world_size : Optional [int ] = None ,
1275+ global_ranks : Optional [list [int ]] = None ,
11961276 ) -> None :
11971277 self ._future_error = None
11981278
1199- super ().configure (store_addr , replica_id , rank , world_size )
1279+ super ().configure (
1280+ store_addr ,
1281+ replica_id ,
1282+ rank ,
1283+ world_size ,
1284+ quorum_id ,
1285+ group_rank ,
1286+ group_world_size ,
1287+ global_ranks ,
1288+ )
12001289
12011290 def report_future_error (self , e : Exception ) -> None :
12021291 """
@@ -1410,7 +1499,15 @@ def shutdown(self) -> None:
14101499 self ._p .kill ()
14111500
14121501 def configure (
1413- self , store_addr : str , replica_id : str , rank : int , world_size : int
1502+ self ,
1503+ store_addr : str ,
1504+ replica_id : str ,
1505+ rank : int ,
1506+ world_size : int ,
1507+ quorum_id : Optional [int ] = None ,
1508+ group_rank : Optional [int ] = None ,
1509+ group_world_size : Optional [int ] = None ,
1510+ global_ranks : Optional [list [int ]] = None ,
14141511 ) -> None :
14151512 self ._world_size = world_size
14161513
0 commit comments