Commit 15bd67c

tushar00jain authored and facebook-github-bot committed
reset flight recorder trace (#283)
Summary:

- Call the Flight Recorder (FR) API to reset the trace after every quorum. We reset so that each quorum starts with a fresh FR trace, since the PGs could have changed and the FR trace from previous errors has already been dumped.
- Change the env var that determines the dump file path after every quorum.
- Return replica ids in the quorum response so we can determine the global ranks in the PG; these are used to set the metadata on the PG that Flight Recorder needs to work. (A sketch of the resulting flow follows below.)

Reviewed By: d4l3k

Differential Revision: D84260745
1 parent 75ee24f commit 15bd67c
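
The core of the change is this per-quorum reset flow. A minimal sketch, assuming a PyTorch build that exposes the private _reset_fr_recording_nccl hook used by this commit; the reset_flight_recorder helper here is illustrative, not torchft API:

import os

import torch

# Env var PyTorch's NCCL Flight Recorder consults for the dump location.
TORCH_FR_DUMP_TEMP_FILE_ENV = "TORCH_FR_DUMP_TEMP_FILE"


def reset_flight_recorder(base_path: str, quorum_id: int, global_rank: int) -> None:
    """Point the next FR dump at a fresh per-quorum folder, then reset the trace."""
    # One folder per quorum so dumps from different quorums never mix.
    folder = f"{base_path}_quorum_{quorum_id}"
    os.makedirs(folder, exist_ok=True)
    # One file per global rank inside the quorum's folder.
    os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{global_rank}"
    # Drop trace entries recorded against the previous quorum's process groups.
    torch._C._distributed_c10d._reset_fr_recording_nccl()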

File tree

6 files changed: +175 -13 lines changed

proto/torchft.proto

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ message ManagerQuorumResponse {
   int64 replica_world_size = 10;
   bool heal = 11;
   int64 commit_failures = 12;
+  repeated string replica_ids = 13;
 }

 message CheckpointMetadataRequest {
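
Because replica_ids is a new repeated field with a previously unused tag (13), the change is wire-compatible: an old server simply never sets it, and proto3 gives readers an empty list rather than an error. A hedged sketch of the defensive check a caller might want (the quorum object stands in for any response carrying the new field):

def has_replica_ids(quorum) -> bool:
    # proto3 repeated fields default to empty, so a response from a server
    # that predates field 13 looks identical to an empty membership list.
    return len(quorum.replica_ids) > 0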

src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -213,6 +213,7 @@ impl ManagerClient {
                 max_replica_rank: resp.max_replica_rank,
                 max_world_size: resp.max_world_size,
                 heal: resp.heal,
+                replica_ids: resp.replica_ids,
             })
         })
     }
@@ -293,6 +294,7 @@ struct QuorumResult {
     max_replica_rank: Option<i64>,
     max_world_size: i64,
     heal: bool,
+    replica_ids: Vec<String>,
 }

 #[pymethods]
@@ -311,6 +313,7 @@ impl QuorumResult {
             max_replica_rank: None,
             max_world_size: 1,
             heal: false,
+            replica_ids: Vec::new(),
         }
     }
 }

src/manager.rs

Lines changed: 1 addition & 0 deletions
@@ -620,6 +620,7 @@ fn compute_quorum_results(
             .map(|p| p.commit_failures)
             .max()
             .unwrap_or(0),
+        replica_ids: participants.iter().map(|p| p.replica_id.clone()).collect(),
     })
 }

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool
    commit_failures: int
+    replica_ids: list[str]

 class ManagerServer:
     def __init__(
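
With the stub updated, typed Python callers can inspect quorum membership directly. A small illustrative use (assumes the extension module exposes QuorumResult as declared in this .pyi; how you obtain a quorum is elided):

from torchft._torchft import QuorumResult


def log_quorum_members(quorum: QuorumResult) -> None:
    # The manager derives global ranks positionally from this list (see
    # torchft/manager.py in this commit), so order is significant.
    for position, replica_id in enumerate(quorum.replica_ids):
        print(f"quorum member {position}: {replica_id}")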

torchft/manager.py

Lines changed: 62 additions & 3 deletions
@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")


@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec


+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+    """
+    i = len(s) - 1
+    while i >= 0 and s[i].isdigit():
+        i -= 1
+    return int(s[i + 1 :]) if i < len(s) - 1 else 0
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id

         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size

         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True

+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -635,6 +662,13 @@ def _async_quorum(
         max_replica_rank = quorum.max_replica_rank
         max_replica_world_size = quorum.max_world_size
         heal = quorum.heal
+        replica_ids = quorum.replica_ids
+
+        ranks_in_quorum = [
+            extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+            + self._group_rank
+            for replica_id in replica_ids
+        ]

         # When using async quorum we need to take the recovered workers.
         # When not using async quorum we need to take the max world size as all
@@ -675,16 +709,30 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             try:
+                self._quorum_id = quorum_id
                 with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                    # Reset GPU state for Flight Recorder
                     if torch.accelerator.is_available():
                         torch.accelerator.synchronize()
+
                     self._pg.configure(
                         store_prefixed_addr,
                         self._replica_id if self._replica_id is not None else "0",
                         replica_rank,
                         replica_world_size,
+                        quorum_id,
+                        self._group_rank,
+                        self._group_world_size,
+                        ranks_in_quorum,
                     )
-                self._quorum_id = quorum_id
+
+                # We need to reset the trace after reconfiguring the PG because that
+                # calls abort which may trigger a dump
+                self._logger.info(
+                    f"resetting fr recording for quorum id {self._quorum_id}"
+                )
+                self._update_fr_path()
+                torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
             except Exception as e:
                 self._logger.exception(f"got exception in pg configure: {e}")
                 self.report_error(e)
@@ -759,6 +807,17 @@ def _async_quorum(
             else None
         )

+    def _update_fr_path(self) -> None:
+        """
+        Update the path that flight recorder will dump the traces to.
+        The format is
+        <TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
+        """
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
     def _apply_pending_state_dict(self) -> None:
         assert self._healing, "must be in healing state"
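
The global-rank arithmetic introduced above is easiest to see with concrete numbers. A short worked example, assuming replica ids end in their replica index (the id "train_ft_3" is hypothetical; ids returned by quorum may carry a ":<uuid>" suffix, which _async_quorum strips first):

# Worked example of the mapping added in torchft/manager.py above.
from torchft.manager import extract_trailing_digits

assert extract_trailing_digits("train_ft_3") == 3
assert extract_trailing_digits("no_digits") == 0  # no trailing digits -> 0

# With 8 ranks per replica, each replica owns a contiguous block of global
# ranks, so group_rank 5 on replica 3 lands at 3 * 8 + 5 = 29.
group_world_size, group_rank = 8, 5
global_rank = extract_trailing_digits("train_ft_3") * group_world_size + group_rank
assert global_rank == 29

# _async_quorum applies the same formula to every replica id in the quorum
# (after dropping the ":<uuid>" suffix) to build ranks_in_quorum: this
# process's global rank within each replica's block, i.e. the members of
# its cross-replica process group.
replica_ids = ["train_ft_0:abc", "train_ft_1:def", "train_ft_3:ghi"]
ranks_in_quorum = [
    extract_trailing_digits(rid.split(":")[0]) * group_world_size + group_rank
    for rid in replica_ids
]
assert ranks_in_quorum == [5, 13, 29]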

torchft/process_group.py

Lines changed: 107 additions & 10 deletions
@@ -276,7 +276,15 @@ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         """
         This reconfigures the ProcessGroup to use a new store, rank and world size.
@@ -292,6 +300,10 @@ def configure(
             replica_id: the replica_id for this group
             rank: rank of this process
             world_size: world size of this process group
+            quorum_id: current quorum's identifier
+            group_rank: local rank within the replica group
+            group_world_size: the number of ranks within a replica
+            global_ranks: the global ranks that are part of this process group
         """
         raise NotImplementedError("not implemented")

@@ -406,6 +418,10 @@ def __init__(
         self._timeout = timeout
         self._replica_id: str | None = None
         self._rank: int | None = None
+        self._quorum_id: int | None = None
+        self._group_rank: int | None = None
+        self._group_world_size: int | None = None
+        self._global_ranks: list[int] | None = None

         self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

@@ -417,13 +433,34 @@ def getBackendName(self) -> str:
         raise NotImplementedError("not implemented")

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         pg = self._pg
         self._replica_id = replica_id
+        self._quorum_id = quorum_id
+        self._group_rank = group_rank
+        self._group_world_size = group_world_size
         self._rank = rank
+        self._global_ranks = global_ranks
         if isinstance(pg, ProcessGroup):
-            pg.configure(store_addr, replica_id, rank, world_size)
+            pg.configure(
+                store_addr,
+                replica_id,
+                rank,
+                world_size,
+                quorum_id,
+                group_rank,
+                group_world_size,
+                global_ranks,
+            )
             return

         # abort if already initialized
@@ -441,6 +478,7 @@ def abort(self, errored: bool = True) -> None:
                 "job_id": os.environ.get("JOB_ID", "unknown"),
                 "replica_id": self._replica_id,
                 "rank": self._rank,
+                "quorum_id": self._quorum_id,
                 "error": "process_group_abort",
             },
         )
@@ -613,6 +651,12 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupGloo
         backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
         backend_class._set_sequence_number_for_group()
+
+        if self._global_ranks:
+            backend_class.options.global_ranks_in_group = self._global_ranks
+        if self._group_rank and self._group_world_size:
+            backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
+
         pg._register_backend(
             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
         )
@@ -810,7 +854,10 @@ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
         # pyre-fixme[16]: no attribute ProcessGroupNCCL
         opts = BaseProcessGroupNCCL.Options()
         opts.config.blocking = False
-        opts.global_ranks_in_group = list(range(world_size))
+        if self._global_ranks:
+            opts.global_ranks_in_group = self._global_ranks
+        if self._group_rank and self._group_world_size:
+            opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

         pg = BaseProcessGroup(store, rank, world_size)
         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
@@ -977,7 +1024,15 @@ def __init__(self, rank: int, world: int) -> None:
         self.configure_count = 0

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self.configure_count += 1

@@ -1136,11 +1191,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )

     def report_error(self, e: Exception) -> None:
         """
@@ -1192,11 +1264,28 @@ def __init__(self, pg: ProcessGroup) -> None:
         self._future_error: Optional[Exception] = None

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._future_error = None

-        super().configure(store_addr, replica_id, rank, world_size)
+        super().configure(
+            store_addr,
+            replica_id,
+            rank,
+            world_size,
+            quorum_id,
+            group_rank,
+            group_world_size,
+            global_ranks,
+        )

     def report_future_error(self, e: Exception) -> None:
         """
@@ -1410,7 +1499,15 @@ def shutdown(self) -> None:
         self._p.kill()

     def configure(
-        self, store_addr: str, replica_id: str, rank: int, world_size: int
+        self,
+        store_addr: str,
+        replica_id: str,
+        rank: int,
+        world_size: int,
+        quorum_id: Optional[int] = None,
+        group_rank: Optional[int] = None,
+        group_world_size: Optional[int] = None,
+        global_ranks: Optional[list[int]] = None,
     ) -> None:
         self._world_size = world_size
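
The two backend options set above are what make the per-quorum Flight Recorder dumps attributable: global_ranks_in_group records which global ranks the PG spans, and group_name is made unique per quorum. A minimal sketch of the naming scheme (this version guards with `is not None`, unlike the truthiness checks in the diff, so a group_rank of 0 also produces a name):

from typing import Optional


def quorum_group_name(
    quorum_id: Optional[int],
    group_rank: Optional[int],
    group_world_size: Optional[int],
) -> Optional[str]:
    # Only meaningful once the manager has supplied quorum metadata.
    if group_rank is None or group_world_size is None:
        return None
    # group_rank % group_world_size is the rank within the replica, so all
    # replicas' PGs for the same local rank share one name per quorum.
    return f"torchft_quorum_{quorum_id}_rank_{group_rank % group_world_size}"


assert quorum_group_name(7, 0, 8) == "torchft_quorum_7_rank_0"
assert quorum_group_name(7, 5, 8) == "torchft_quorum_7_rank_5"
assert quorum_group_name(None, None, None) is None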
