
Commit b037523

Author: tushar00jain
✨ add profiling to manager (#178)
Summary: Fixes #137. Add profiler annotations to manager.py

Test Plan:
[profiler trace screenshot]
https://github.com/user-attachments/assets/b34b3701-66b1-4b90-9cab-eec8db35bc38

Reviewers: @d4l3k

Co-authored-by: tushar00jain <tushar00jain@devvm5549.pnb0.facebook.com>
Parent: d7f6d1b

File tree: 2 files changed (+57, -22 lines)

torchft/manager.py

Lines changed: 36 additions & 22 deletions
@@ -448,6 +448,7 @@ def wait_quorum(self) -> None:
         ), "must call start_quorum before wait_quorum"
         self._quorum_future.result()
 
+    @torch.profiler.record_function("torchft::manager::_async_quorum")
     def _async_quorum(
         self,
         allow_heal: bool,
@@ -459,14 +460,17 @@ def _async_quorum(
 
         if curr_device >= 0 and torch.cuda.is_available():
             torch.cuda.set_device(curr_device)
-        quorum = self._client._quorum(
-            rank=self._rank,
-            step=self._step,
-            checkpoint_metadata=self._checkpoint_transport.metadata(),
-            shrink_only=shrink_only,
-            timeout=quorum_timeout,
-            init_sync=self._init_sync,
-        )
+
+        quorum = None
+        with torch.profiler.record_function("torchft::manager::_client::_quorum"):
+            quorum = self._client._quorum(
+                rank=self._rank,
+                step=self._step,
+                checkpoint_metadata=self._checkpoint_transport.metadata(),
+                shrink_only=shrink_only,
+                timeout=quorum_timeout,
+                init_sync=self._init_sync,
+            )
 
         quorum_id = quorum.quorum_id
         replica_rank = quorum.replica_rank
@@ -505,7 +509,10 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             # TODO: handle configure errors
-            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
+            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                self._pg.configure(
+                    store_prefixed_addr, replica_rank, replica_world_size
+                )
             self._quorum_id = quorum_id
 
         if allow_heal:
@@ -520,12 +527,15 @@ def _async_quorum(
                 self._logger.info(
                     f"peers need recovery from us {quorum.recover_dst_ranks}"
                 )
-                self._checkpoint_transport.send_checkpoint(
-                    dst_ranks=quorum.recover_dst_ranks,
-                    step=max_step,
-                    state_dict=self._manager_state_dict(),
-                    timeout=self._timeout,
-                )
+                with torch.profiler.record_function(
+                    "torchft::manager::_checkpoint_transport::send_checkpoint"
+                ):
+                    self._checkpoint_transport.send_checkpoint(
+                        dst_ranks=quorum.recover_dst_ranks,
+                        step=max_step,
+                        state_dict=self._manager_state_dict(),
+                        timeout=self._timeout,
+                    )
 
             # See manager.rs for healing conditions
             if heal:
@@ -551,14 +561,17 @@ def _async_quorum(
 
                 # we apply the user state dict only when safe from the main thread
                 # save it for now
-                self._pending_state_dict = (
-                    self._checkpoint_transport.recv_checkpoint(
-                        src_rank=recover_src_rank,
-                        metadata=checkpoint_metadata,
-                        step=max_step,
-                        timeout=self._timeout,
+                with torch.profiler.record_function(
+                    "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                ):
+                    self._pending_state_dict = (
+                        self._checkpoint_transport.recv_checkpoint(
+                            src_rank=recover_src_rank,
+                            metadata=checkpoint_metadata,
+                            step=max_step,
+                            timeout=self._timeout,
+                        )
                     )
-                )
 
                 # pyre-fixme[6]: got object
                 self.load_state_dict(self._pending_state_dict["torchft"])
@@ -584,6 +597,7 @@ def _apply_pending_state_dict(self) -> None:
         self._pending_state_dict = None
        self._logger.info("Loaded state dict.")
 
+    @torch.profiler.record_function("torchft::manager::should_commit")
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         """
         .. note::
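
For reference, torch.profiler.record_function marks a region of code as a named range in profiler output; the diff above uses it both as a decorator (on _async_quorum and should_commit) and as a context manager (around the quorum, process-group configure, and checkpoint-transport calls). A minimal standalone sketch of both patterns follows — the example:: labels and the matmul workload are illustrative, not part of torchft:

import torch
from torch.profiler import ProfilerActivity, profile, record_function

# Decorator form: every call to the function becomes a named range.
@record_function("example::expensive_step")
def expensive_step(x: torch.Tensor) -> torch.Tensor:
    return x @ x

x = torch.randn(256, 256)
with profile(activities=[ProfilerActivity.CPU]) as prof:
    expensive_step(x)
    # Context-manager form: label an arbitrary block inline.
    with record_function("example::inline_block"):
        x.sum()

# Both labels show up as rows in the aggregated table.
print(prof.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))

Either form adds a row to key_averages() and a span to an exported Chrome trace, which is how the manager annotations above surface in the test-plan screenshot.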

train_ddp.py

Lines changed: 21 additions & 0 deletions
@@ -143,10 +143,30 @@ def forward(self, x):
     num_params = sum(p.numel() for p in m.parameters())
     print(f"Total number of parameters: {num_params}")
 
+    sort_by_keyword = "self_" + device + "_time_total"
+
+    def trace_handler(p):
+        output = p.key_averages().table(
+            sort_by=sort_by_keyword,
+            row_limit=100,
+        )
+        print(output)
+        p.export_chrome_trace("/tmp/trace_" + str(p.step_num) + ".json")
+
     # You can use an epoch based training but with faults it's easier to use step
     # based training.
+    prof = torch.profiler.profile(
+        schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
+        on_trace_ready=trace_handler,
+        record_shapes=True,
+        profile_memory=True,
+    )
+
+    prof.start()
     while True:
         for i, (inputs, labels) in enumerate(trainloader):
+            prof.step()
+
             inputs = inputs.to(device)
             labels = labels.to(device)
 
@@ -178,6 +198,7 @@ def forward(self, x):
 
             if manager.current_step() >= 10000:
                 # complete training
+                prof.stop()
                 exit()
 
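
The schedule here drives a per-step state machine: with wait=5, warmup=1, active=10, repeat=2, each cycle skips 5 steps, warms up for 1, records 10, and then calls trace_handler once at the end of the active window; the cycle runs twice and profiling goes idle afterwards. A self-contained sketch of the same pattern — the matmul stand-in for a training step is illustrative:

import torch

def trace_handler(p: torch.profiler.profile) -> None:
    # Invoked once per completed `active` window (twice here, since repeat=2).
    print(p.key_averages().table(sort_by="self_cpu_time_total", row_limit=10))
    p.export_chrome_trace(f"/tmp/trace_{p.step_num}.json")

prof = torch.profiler.profile(
    # Per cycle: first 5 steps idle, 1 warmup, 10 recorded; repeats twice.
    schedule=torch.profiler.schedule(wait=5, warmup=1, active=10, repeat=2),
    on_trace_ready=trace_handler,
    record_shapes=True,
    profile_memory=True,
)

prof.start()
for step in range(40):  # enough iterations to cover both profiling cycles
    prof.step()  # advance the schedule once per training iteration
    torch.randn(128, 128) @ torch.randn(128, 128)  # stand-in for the real step
prof.stop()

Calling prof.step() at the top of the loop, as the diff does, keeps the schedule aligned with training iterations; the explicit start()/stop() pair is used instead of a context manager because the training loop only exits via exit().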
