 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
 
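+# Env var that controls where Flight Recorder dumps its traces; torchft
+# rewrites it per quorum in Manager._update_fr_path below.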
+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")
 
 
@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec
 
 
+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
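+    e.g. extract_trailing_digits("replica_7") -> 7; extract_trailing_digits("abc") -> 0.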
118+ """
119+ i = len (s ) - 1
120+ while i >= 0 and s [i ].isdigit ():
121+ i -= 1
122+ return int (s [i + 1 :]) if i < len (s ) - 1 else 0
123+
124+
112125class WorldSizeMode (Enum ):
113126 """
114127 This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
223236 self ._load_state_dict_fns : Dict [str , Callable [[object ], None ]] = {}
224237 self ._user_state_dicts : Dict [str , Callable [[], object ]] = {}
225238
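+        # Remember the original Flight Recorder dump path so _update_fr_path
+        # can derive a fresh per-quorum path from it on every reconfigure.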
+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id
 
         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size
 
         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True
 
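+        # Job-wide rank: the trailing digits of the replica id are its replica
+        # index, so global_rank = replica_index * group_world_size + group_rank.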
+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -446,7 +473,7 @@ def allreduce(
         # on the Future
         @torch.profiler.record_function("torchft::manager::allreduce::callback")
         def callback(
-            fut: torch.futures.Future[list[torch.Tensor]],
+            fut: torch.futures.Future[torch.Tensor],
         ) -> torch.Tensor:
             nonlocal tensor
             if reduce_op == ReduceOp.AVG:
@@ -455,6 +482,7 @@ def callback(
 
         managed_work = _ManagedWork(self, work, tensor)
         fut = managed_work.get_future()
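+        # get_future() is typed loosely; narrow the static type so that
+        # .then(callback) matches callback's Future[torch.Tensor] parameter.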
+        fut = cast(torch.futures.Future[torch.Tensor], fut)
         fut = fut.then(callback)
         return managed_work
 
@@ -634,6 +662,13 @@ def _async_quorum(
         max_replica_rank = quorum.max_replica_rank
         max_replica_world_size = quorum.max_world_size
         heal = quorum.heal
+        replica_ids = quorum.replica_ids
+
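+        # For each replica in the quorum, the global rank of the worker that
+        # shares this process's group rank; replica ids carry a ":"-delimited
+        # suffix, so strip it before extracting the trailing replica index.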
+        ranks_in_quorum = [
+            extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+            + self._group_rank
+            for replica_id in replica_ids
+        ]
 
         # When using async quorum we need to take the recovered workers.
         # When not using async quorum we need to take the max world size as all
@@ -674,16 +709,30 @@ def _async_quorum(
         self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
         # We use the replica rank and world as we want all replicas in the PG.
         try:
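+            # Record the new quorum id up front so the Flight Recorder path
+            # derived in _update_fr_path below reflects this quorum.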
+            self._quorum_id = quorum_id
             with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                # Reset GPU state for Flight Recorder
                 if torch.accelerator.is_available():
                     torch.accelerator.synchronize()
+
                 self._pg.configure(
                     store_prefixed_addr,
                     self._replica_id if self._replica_id is not None else "0",
                     replica_rank,
                     replica_world_size,
+                    quorum_id,
+                    self._group_rank,
+                    self._group_world_size,
+                    ranks_in_quorum,
                 )
-            self._quorum_id = quorum_id
+
+            # We need to reset the trace after reconfiguring the PG because that
+            # calls abort, which may trigger a dump.
+            self._logger.info(
+                f"resetting fr recording for quorum id {self._quorum_id}"
+            )
+            self._update_fr_path()
+            torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
         except Exception as e:
             self._logger.exception(f"got exception in pg configure: {e}")
             self.report_error(e)
@@ -758,6 +807,17 @@ def _async_quorum(
             else None
         )
 
+    def _update_fr_path(self) -> None:
+        """
+        Update the path that Flight Recorder will dump traces to.
+        The format is:
+        <TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
+        """
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
 