@@ -448,6 +448,7 @@ def wait_quorum(self) -> None:
         ), "must call start_quorum before wait_quorum"
         self._quorum_future.result()

+    @torch.profiler.record_function("torchft::manager::_async_quorum")
     def _async_quorum(
         self,
         allow_heal: bool,
@@ -459,14 +460,17 @@ def _async_quorum(

         if curr_device >= 0 and torch.cuda.is_available():
             torch.cuda.set_device(curr_device)
-        quorum = self._client._quorum(
-            rank=self._rank,
-            step=self._step,
-            checkpoint_metadata=self._checkpoint_transport.metadata(),
-            shrink_only=shrink_only,
-            timeout=quorum_timeout,
-            init_sync=self._init_sync,
-        )
+
+        quorum = None
+        with torch.profiler.record_function("torchft::manager::_client::_quorum"):
+            quorum = self._client._quorum(
+                rank=self._rank,
+                step=self._step,
+                checkpoint_metadata=self._checkpoint_transport.metadata(),
+                shrink_only=shrink_only,
+                timeout=quorum_timeout,
+                init_sync=self._init_sync,
+            )

         quorum_id = quorum.quorum_id
         replica_rank = quorum.replica_rank
@@ -505,7 +509,10 @@ def _async_quorum(
             self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
             # We use the replica rank and world as we want all replicas in the PG.
             # TODO: handle configure errors
-            self._pg.configure(store_prefixed_addr, replica_rank, replica_world_size)
+            with torch.profiler.record_function("torchft::manager::_pg.configure"):
+                self._pg.configure(
+                    store_prefixed_addr, replica_rank, replica_world_size
+                )
             self._quorum_id = quorum_id

         if allow_heal:
@@ -520,12 +527,15 @@ def _async_quorum(
                 self._logger.info(
                     f"peers need recovery from us {quorum.recover_dst_ranks}"
                 )
-                self._checkpoint_transport.send_checkpoint(
-                    dst_ranks=quorum.recover_dst_ranks,
-                    step=max_step,
-                    state_dict=self._manager_state_dict(),
-                    timeout=self._timeout,
-                )
+                with torch.profiler.record_function(
+                    "torchft::manager::_checkpoint_transport::send_checkpoint"
+                ):
+                    self._checkpoint_transport.send_checkpoint(
+                        dst_ranks=quorum.recover_dst_ranks,
+                        step=max_step,
+                        state_dict=self._manager_state_dict(),
+                        timeout=self._timeout,
+                    )

             # See manager.rs for healing conditions
             if heal:
@@ -551,14 +561,17 @@ def _async_quorum(

                 # we apply the user state dict only when safe from the main thread
                 # save it for now
-                self._pending_state_dict = (
-                    self._checkpoint_transport.recv_checkpoint(
-                        src_rank=recover_src_rank,
-                        metadata=checkpoint_metadata,
-                        step=max_step,
-                        timeout=self._timeout,
+                with torch.profiler.record_function(
+                    "torchft::manager::_checkpoint_transport::recv_checkpoint"
+                ):
+                    self._pending_state_dict = (
+                        self._checkpoint_transport.recv_checkpoint(
+                            src_rank=recover_src_rank,
+                            metadata=checkpoint_metadata,
+                            step=max_step,
+                            timeout=self._timeout,
+                        )
                     )
-                )

                 # pyre-fixme[6]: got object
                 self.load_state_dict(self._pending_state_dict["torchft"])
@@ -584,6 +597,7 @@ def _apply_pending_state_dict(self) -> None:
         self._pending_state_dict = None
         self._logger.info("Loaded state dict.")

+    @torch.profiler.record_function("torchft::manager::should_commit")
     def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
         """
         .. note::
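For context on the instrumentation above: torch.profiler.record_function is used in two forms in this diff, as a decorator (on _async_quorum and should_commit) and as a context manager (around _client._quorum, _pg.configure, and the checkpoint transport calls). Below is a minimal standalone sketch of both forms; the label strings, function names, and toy workload are made up for illustration and are not torchft code.

import torch
from torch.profiler import ProfilerActivity, profile, record_function


# Hypothetical label and function names, for illustration only.
@record_function("example::decorated_step")
def decorated_step(x: torch.Tensor) -> torch.Tensor:
    # Work inside a decorated function is attributed to its label.
    return x * 2


with profile(activities=[ProfilerActivity.CPU]) as prof:
    x = torch.randn(8)
    # Context-manager form: everything in this block is attributed to the label.
    with record_function("example::inline_region"):
        y = decorated_step(x)

# Each labeled region appears as its own row in the profiler summary.
print(prof.key_averages().table(sort_by="cpu_time_total"))

Outside a profiling session the labels are effectively no-ops, which is why they can be left in hot paths like quorum and checkpoint transfer.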