Commit c5a3407

Authored and committed by tushar00jain (facebook-github-bot)
reset flight recorder trace (#283)
Summary:
- Call the FR API to reset the trace after every quorum. We reset so that after every quorum we start a fresh FR trace, since the PGs could have changed and we already dumped the FR trace from previous errors.
- Change the env var that's used to determine the dump file after every quorum.
- Return replica ids in the quorum response so we can determine global ranks in the PG. This is used to set the metadata on the PG for flight recorder to work.

Reviewed By: d4l3k

Differential Revision: D84260745
1 parent cc91bfb commit c5a3407
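
For illustration, here is a minimal Python sketch of how the replica ids returned in the quorum response map to global ranks, mirroring the extract_trailing_digits helper and the ranks_in_quorum computation in torchft/manager.py below. The replica id strings, group world size, and group rank used here are made-up example values, not values from the commit; the "<name><index>:<uuid>" shape is only assumed from the split(":") in the diff.

# Sketch: deriving global ranks from replica ids (illustrative values only).
def extract_trailing_digits(s: str) -> int:
    """Return the integer formed by the trailing digits of s, or 0 if there are none."""
    i = len(s) - 1
    while i >= 0 and s[i].isdigit():
        i -= 1
    return int(s[i + 1 :]) if i < len(s) - 1 else 0

group_world_size = 8  # assumed number of ranks per replica group
group_rank = 3        # assumed rank of this process within its group

# Assumed replica id format for the example.
replica_ids = ["train_replica0:aaaa", "train_replica1:bbbb", "train_replica2:cccc"]

ranks_in_quorum = [
    extract_trailing_digits(rid.split(":")[0]) * group_world_size + group_rank
    for rid in replica_ids
]
print(ranks_in_quorum)  # [3, 11, 19] for the assumed values above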

File tree: 6 files changed, +177 −14 lines changed

proto/torchft.proto

Lines changed: 1 addition & 0 deletions
@@ -96,6 +96,7 @@ message ManagerQuorumResponse {
   int64 replica_world_size = 10;
   bool heal = 11;
   int64 commit_failures = 12;
+  repeated string replica_ids = 13;
 }

 message CheckpointMetadataRequest {

src/lib.rs

Lines changed: 3 additions & 0 deletions
@@ -213,6 +213,7 @@ impl ManagerClient {
                 max_replica_rank: resp.max_replica_rank,
                 max_world_size: resp.max_world_size,
                 heal: resp.heal,
+                replica_ids: resp.replica_ids,
             })
         })
     }
@@ -293,6 +294,7 @@ struct QuorumResult {
     max_replica_rank: Option<i64>,
     max_world_size: i64,
     heal: bool,
+    replica_ids: Vec<String>,
 }

 #[pymethods]
@@ -311,6 +313,7 @@ impl QuorumResult {
             max_replica_rank: None,
             max_world_size: 1,
             heal: false,
+            replica_ids: Vec::new(),
         }
     }
 }

src/manager.rs

Lines changed: 1 addition & 0 deletions
@@ -620,6 +620,7 @@ fn compute_quorum_results(
             .map(|p| p.commit_failures)
             .max()
             .unwrap_or(0),
+        replica_ids: participants.iter().map(|p| p.replica_id.clone()).collect(),
     })
 }

torchft/_torchft.pyi

Lines changed: 1 addition & 0 deletions
@@ -36,6 +36,7 @@ class QuorumResult:
     max_world_size: int
     heal: bool
     commit_failures: int
+    replica_ids: list[str]

 class ManagerServer:
     def __init__(

torchft/manager.py

Lines changed: 64 additions & 4 deletions
@@ -88,6 +88,8 @@
 # crash if call to quorum fails, all replicas will crash.
 QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

+TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
+
 T = TypeVar("T")

@@ -109,6 +111,17 @@ def get_timeout(
     return default_timeout_sec


+def extract_trailing_digits(s: str) -> int:
+    """
+    Extracts the trailing digits from the end of the string s.
+    Returns 0 if no trailing digits are found.
+    """
+    i = len(s) - 1
+    while i >= 0 and s[i].isdigit():
+        i -= 1
+    return int(s[i + 1 :]) if i < len(s) - 1 else 0
+
+
 class WorldSizeMode(Enum):
     """
     This controls the numerics for the job when doing allreduces across replicas
@@ -223,6 +236,9 @@ def __init__(
         self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
         self._user_state_dicts: Dict[str, Callable[[], object]] = {}

+        self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
+            TORCH_FR_DUMP_TEMP_FILE_ENV
+        )
         self._replica_id = replica_id

         # Protects state dict
@@ -257,7 +273,7 @@ def __init__(
         store_port = store_port or int(os.environ["MASTER_PORT"])
         self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
         group_rank = self._group_rank
-        group_world_size = world_size or int(os.environ["WORLD_SIZE"])
+        self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
         self._min_replica_size = min_replica_size

         if checkpoint_transport is None:
@@ -310,7 +326,7 @@ def __init__(
             hostname=hostname,
             bind=bind,
             store_addr=f"{store_addr}:{store_port}",
-            world_size=group_world_size,
+            world_size=self._group_world_size,
             heartbeat_interval=heartbeat_interval,
             connect_timeout=connect_timeout,
             quorum_retries=self._quorum_retries,
@@ -338,6 +354,17 @@ def __init__(
         self._participating_replica_world_size: int = 0
         self._is_state_dict_read_allowed = True

+        self._global_rank: int = (
+            self._group_rank
+            if self._replica_id is None
+            else (
+                extract_trailing_digits(self._replica_id) * self._group_world_size
+                + self._group_rank
+            )
+        )
+
+        self._update_fr_path()
+
     def allow_state_dict_read(self) -> None:
         if self._is_state_dict_read_allowed:
             return
@@ -446,7 +473,7 @@ def allreduce(
         # on the Future
         @torch.profiler.record_function("torchft::manager::allreduce::callback")
         def callback(
-            fut: torch.futures.Future[list[torch.Tensor]],
+            fut: torch.futures.Future[torch.Tensor],
         ) -> torch.Tensor:
             nonlocal tensor
             if reduce_op == ReduceOp.AVG:
@@ -455,6 +482,7 @@ def callback(

         managed_work = _ManagedWork(self, work, tensor)
         fut = managed_work.get_future()
+        fut = cast(torch.futures.Future[torch.Tensor], fut)
         fut = fut.then(callback)
         return managed_work

@@ -634,6 +662,13 @@ def _async_quorum(
            max_replica_rank = quorum.max_replica_rank
            max_replica_world_size = quorum.max_world_size
            heal = quorum.heal
+            replica_ids = quorum.replica_ids
+
+            ranks_in_quorum = [
+                extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
+                + self._group_rank
+                for replica_id in replica_ids
+            ]

            # When using async quorum we need to take the recovered workers.
            # When not using async quorum we need to take the max world size as all
@@ -674,16 +709,30 @@ def _async_quorum(
            self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
            # We use the replica rank and world as we want all replicas in the PG.
            try:
+                self._quorum_id = quorum_id
                with torch.profiler.record_function("torchft::manager::_pg::configure"):
+                    # Reset GPU state for Flight Recorder
                    if torch.accelerator.is_available():
                        torch.accelerator.synchronize()
+
                    self._pg.configure(
                        store_prefixed_addr,
                        self._replica_id if self._replica_id is not None else "0",
                        replica_rank,
                        replica_world_size,
+                        quorum_id,
+                        self._group_rank,
+                        self._group_world_size,
+                        ranks_in_quorum,
                    )
-                self._quorum_id = quorum_id
+
+                # We need to reset the trace after reconfiguring the PG because that
+                # calls abort which may trigger a dump
+                self._logger.info(
+                    f"resetting fr recording for quorum id {self._quorum_id}"
+                )
+                self._update_fr_path()
+                torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
            except Exception as e:
                self._logger.exception(f"got exception in pg configure: {e}")
                self.report_error(e)
@@ -758,6 +807,17 @@ def _async_quorum(
                else None
            )

+    def _update_fr_path(self) -> None:
+        """
+        Update the path that flight recorder will dump the traces to.
+        The format is
+        <TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
+        """
+        if self._original_fr_dump_temp_file is not None:
+            folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
+            os.makedirs(folder, exist_ok=True)
+            os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
+
    def _apply_pending_state_dict(self) -> None:
        assert self._healing, "must be in healing state"
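
As a follow-up illustration, a small standalone sketch of how the flight recorder dump path is rebuilt per quorum, following the _update_fr_path method added above; the base path, quorum id, and global rank below are placeholder values, not values from the commit.

import os

TORCH_FR_DUMP_TEMP_FILE_ENV = "TORCH_FR_DUMP_TEMP_FILE"

# Placeholder values; in the manager these come from the original env var,
# the current quorum id, and the computed global rank.
original_fr_dump_temp_file = "/tmp/fr_trace"
quorum_id = 5
global_rank = 11

folder = f"{original_fr_dump_temp_file}_quorum_{quorum_id}"
os.makedirs(folder, exist_ok=True)
os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{global_rank}"
# Flight recorder on this rank would then dump to /tmp/fr_trace_quorum_5/11.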
