@@ -278,7 +278,14 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
278278 raise NotImplementedError ("not implemented" )
279279
280280 def configure (
281- self , store_addr : str , replica_id : str , rank : int , world_size : int
281+ self ,
282+ store_addr : str ,
283+ replica_id : str ,
284+ rank : int ,
285+ world_size : int ,
286+ quorum_id : Optional [int ] = None ,
287+ group_rank : int = 0 ,
288+ group_world_size : int = 1 ,
282289 ) -> None :
283290 """
284291 This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -408,6 +415,9 @@ def __init__(
408415 self ._timeout = timeout
409416 self ._replica_id : str | None = None
410417 self ._rank : int | None = None
418+ self ._quorum_id : int | None = None
419+ self ._group_rank : int = 0
420+ self ._group_world_size : int = 1
411421
412422 self .errors_logger : logging .Logger = logging .getLogger ("torchft_errors" )
413423
@@ -419,13 +429,31 @@ def getBackendName(self) -> str:
419429 raise NotImplementedError ("not implemented" )
420430
421431 def configure (
422- self , store_addr : str , replica_id : str , rank : int , world_size : int
432+ self ,
433+ store_addr : str ,
434+ replica_id : str ,
435+ rank : int ,
436+ world_size : int ,
437+ quorum_id : Optional [int ] = None ,
438+ group_rank : int = 0 ,
439+ group_world_size : int = 1 ,
423440 ) -> None :
424441 pg = self ._pg
425442 self ._replica_id = replica_id
443+ self ._quorum_id = quorum_id
444+ self ._group_rank = group_rank
445+ self ._group_world_size = group_world_size
426446 self ._rank = rank
427447 if isinstance (pg , ProcessGroup ):
428- pg .configure (store_addr , replica_id , rank , world_size )
448+ pg .configure (
449+ store_addr ,
450+ replica_id ,
451+ rank ,
452+ world_size ,
453+ quorum_id ,
454+ group_rank ,
455+ group_world_size ,
456+ )
429457 return
430458
431459 # abort if already initialized
@@ -443,6 +471,7 @@ def abort(self, errored: bool = True) -> None:
443471 "job_id" : os .environ .get ("JOB_ID" , "unknown" ),
444472 "replica_id" : self ._replica_id ,
445473 "rank" : self ._rank ,
474+ "quorum_id" : self ._quorum_id ,
446475 "error" : "process_group_abort" ,
447476 },
448477 )
@@ -615,6 +644,8 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
615644 # pyre-fixme[16]: no attribute ProcessGroupGloo
616645 backend_class = BaseProcessGroupGloo (store , rank , world_size , self ._timeout )
617646 backend_class ._set_sequence_number_for_group ()
647+ backend_class .options .global_ranks_in_group = list (range (world_size ))
648+ backend_class .options .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
618649 pg ._register_backend (
619650 torch .device ("cpu" ), ProcessGroup .BackendType .GLOO , backend_class
620651 )
@@ -813,6 +844,7 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGro
813844 opts = BaseProcessGroupNCCL .Options ()
814845 opts .config .blocking = False
815846 opts .global_ranks_in_group = list (range (world_size ))
847+ opts .group_name = f"torchft_quorum_{ self ._quorum_id } _rank_{ self ._group_rank % self ._group_world_size } "
816848
817849 pg = BaseProcessGroup (store , rank , world_size )
818850 pg ._set_default_backend (ProcessGroup .BackendType .NCCL )
@@ -979,7 +1011,12 @@ def __init__(self, rank: int, world: int) -> None:
9791011 self .configure_count = 0
9801012
9811013 def configure (
982- self , store_addr : str , replica_id : str , rank : int , world_size : int
1014+ self ,
1015+ store_addr : str ,
1016+ replica_id : str ,
1017+ rank : int ,
1018+ world_size : int ,
1019+ quorum_id : Optional [int ] = None ,
9831020 ) -> None :
9841021 self .configure_count += 1
9851022
@@ -1138,11 +1175,26 @@ def __init__(self, pg: ProcessGroup) -> None:
11381175 self ._error : Optional [Exception ] = None
11391176
def configure(
    self,
    store_addr: str,
    replica_id: str,
    rank: int,
    world_size: int,
    quorum_id: Optional[int] = None,
    group_rank: int = 0,
    group_world_size: int = 1,
) -> None:
    """
    Reset the recorded error and delegate reconfiguration to the parent class.

    Every argument is handed to the parent ``configure`` unchanged and in
    the same positional order as it is declared here.
    """
    # Clear the stored error first so it is reset even if the parent
    # configure call below raises.
    self._error = None

    forwarded = (
        store_addr,
        replica_id,
        rank,
        world_size,
        quorum_id,
        group_rank,
        group_world_size,
    )
    super().configure(*forwarded)
11461198
11471199 def report_error (self , e : Exception ) -> None :
11481200 """
@@ -1194,11 +1246,16 @@ def __init__(self, pg: ProcessGroup) -> None:
11941246 self ._future_error : Optional [Exception ] = None
11951247
def configure(
    self,
    store_addr: str,
    replica_id: str,
    rank: int,
    world_size: int,
    quorum_id: Optional[int] = None,
    group_rank: int = 0,
    group_world_size: int = 1,
) -> None:
    """
    Clear the pending future error and delegate reconfiguration to the parent.

    The signature mirrors the other ``configure`` overrides in this file so
    that callers passing ``group_rank`` / ``group_world_size`` (positionally
    or by keyword) work against this class as well.

    Args:
        store_addr: forwarded unchanged to the parent ``configure``
        replica_id: forwarded unchanged to the parent ``configure``
        rank: forwarded unchanged to the parent ``configure``
        world_size: forwarded unchanged to the parent ``configure``
        quorum_id: forwarded unchanged to the parent ``configure``
        group_rank: forwarded unchanged to the parent ``configure``
        group_world_size: forwarded unchanged to the parent ``configure``
    """
    # Reset before delegating so the stored error is cleared even if the
    # parent configure call raises.
    self._future_error = None

    super().configure(
        store_addr,
        replica_id,
        rank,
        world_size,
        quorum_id,
        group_rank,
        group_world_size,
    )
12021259
12031260 def report_future_error (self , e : Exception ) -> None :
12041261 """
@@ -1412,7 +1469,12 @@ def shutdown(self) -> None:
14121469 self ._p .kill ()
14131470
14141471 def configure (
1415- self , store_addr : str , replica_id : str , rank : int , world_size : int
1472+ self ,
1473+ store_addr : str ,
1474+ replica_id : str ,
1475+ rank : int ,
1476+ world_size : int ,
1477+ quorum_id : Optional [int ] = None ,
14161478 ) -> None :
14171479 self ._world_size = world_size
14181480
0 commit comments