@@ -603,27 +603,22 @@ def step_microbatches(
         losses: Optional[List] = None,
     ):
         """
-        # n_loop = n_stage / n_pp
-        # run microbatches in sequences of NPp
+        Operate on the microbatches for the interleaved 1F1B schedule (https://arxiv.org/pdf/2104.04473.pdf).
 
-        schedule operates at the rank level
-
-        highest rank has a warmup (F only) count of [len(stages) - 1] * seq_size
-        each hop away from highest rank adds 2 warmup stages
+        The highest rank has a warmup (fwd only) count of [len(stages) - 1] * number of PP ranks,
+        and each rank away from the highest rank adds 2 warmup steps due to:
         - one happened before highest rank's warmup started,
         - one waiting for backward result to trickle down from highest rank
-        dist_from_highest = (worldsize - 1) - rank
-
-        total_steps = warmup_steps + (num_stages * num_microbatch)
 
-        Rank 0: 0F 0F 0F 0F 2F 2F 2F 2F
-        Rank 1: 1F 1F 1F 1F 3F3B 3F 3F 3F
+        TODO: Interleaved 1F1B does not support using sorted_batch_isend_irecv()
+        because it requires recvs and sends from different peers
+        to execute in the same coalesced operation. As a result, this schedule does
+        not support models with skip connections.
         """
         arg_mbs, kwarg_mbs = self._check_inputs(
             arg_mbs, kwarg_mbs, target_mbs, losses
         )
 
-        # warmup steps for latest pp stage is trivial to compute
         # increment warmup_steps by 2 for each hop away
         warmup_steps = (self.n_local_stages - 1) * self.pp_group_size
         warmup_steps += 2 * ((self.pp_group_size - 1) - self.rank)
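To make the warmup accounting concrete, here is a standalone sketch of the two lines above with assumed values (2 local stages per rank, 4 pipeline-parallel ranks; the numbers are illustrative, not taken from this PR):

```python
# Warmup-step arithmetic from the schedule above, with assumed sizes.
n_local_stages = 2  # interleaved stages hosted by each rank
pp_group_size = 4   # pipeline-parallel world size

for rank in range(pp_group_size):
    warmup_steps = (n_local_stages - 1) * pp_group_size
    warmup_steps += 2 * ((pp_group_size - 1) - rank)  # +2 per hop from the highest rank
    print(f"{rank=} {warmup_steps=}")

# rank=0 warmup_steps=10
# rank=1 warmup_steps=8
# rank=2 warmup_steps=6
# rank=3 warmup_steps=4  -> the highest rank matches (len(stages) - 1) * pp_group_size
```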
@@ -641,7 +636,7 @@ def step_microbatches(
             warmup_steps + fwd_bwd_steps * 2 + cooldown_steps
             == self.n_local_stages * self._n_microbatches * 2
         )
-        self.total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
+        total_steps = warmup_steps + fwd_bwd_steps + cooldown_steps
 
         logger.debug(
             f"""
@@ -669,104 +664,110 @@ def backward_stage_local_index(step):
         # Delay send waits
         sends_to_wait: List[dist.Work] = []
 
-        for step in range(self.total_steps):
-            # warmup, forward only
-            if step < warmup_steps:
-                logger.debug(f"{forward_stage_local_index(step)=}")
+        # Store ops (potentially across steps)
+        ops: List[dist.P2POp] = []
 
-                fwd_stage = self._stages[forward_stage_local_index(step)]
-                # assigns the current microbatch index and updates it for future steps
-                fwd_stage_mb_index[fwd_stage] = (
-                    mb_index := fwd_stage_mb_index[fwd_stage]
-                ) + 1
+        # Warmup Phase (forward only)
+        for step in range(warmup_steps):
+            fwd_stage = self._stages[forward_stage_local_index(step)]
 
-                logger.debug(
-                    f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
-                )
+            # This will assign the current microbatch index and update it for future steps
+            fwd_stage_mb_index[fwd_stage] = (
+                mb_index := fwd_stage_mb_index[fwd_stage]
+            ) + 1
 
-                with record_function(f"Forward {step}"):
-                    ops = fwd_stage.get_fwd_recv_ops()
-                    works = sorted_batch_isend_irecv(ops)
-                    for work in works.values():
-                        work.wait()
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {mb_index=}"
+            )
 
-                    output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index])  # type: ignore[index]
+            with record_function(f"Forward {step}"):
+                ops.extend(fwd_stage.get_fwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
 
-                    ops = fwd_stage.get_fwd_send_ops()
-                    works = sorted_batch_isend_irecv(ops)
-                    sends_to_wait.extend(works.values())
+                output = fwd_stage.forward_one_chunk(arg_mbs[mb_index], kwarg_mbs[mb_index])  # type: ignore[index]
 
-                    self._maybe_compute_loss(
-                        fwd_stage, output, target_mbs, mb_index
-                    )
+                ops.extend(fwd_stage.get_fwd_send_ops())
+                # If we are right before the fwd-bwd steps, we need to delay the send to the next step,
+                # because fwd-bwd send/recvs among ranks need to be aligned to prevent a hang.
+                # In the edge case where there are no fwd-bwd steps and cooldown is immediate, no delay is needed.
+                if ops and (step != warmup_steps - 1 or fwd_bwd_steps == 0):
+                    work = dist.batch_isend_irecv(ops).pop()
+                    sends_to_wait.append(work)
+                    ops.clear()
 
-            # 1f1b
-            elif warmup_steps <= step < warmup_steps + fwd_bwd_steps:
-                logger.debug(f"{forward_stage_local_index(step)=}")
-                logger.debug(f"{backward_stage_local_index(step)=}")
+                self._maybe_compute_loss(
+                    fwd_stage, output, target_mbs, mb_index
+                )
 
-                fwd_stage = self._stages[forward_stage_local_index(step)]
-                bwd_stage = self._stages[backward_stage_local_index(step)]
+        # 1F1B Phase (forward and backward)
+        for step in range(warmup_steps, warmup_steps + fwd_bwd_steps):
+            fwd_stage = self._stages[forward_stage_local_index(step)]
+            bwd_stage = self._stages[backward_stage_local_index(step)]
 
-                fwd_stage_mb_index[fwd_stage] = (
-                    fwd_mb_index := fwd_stage_mb_index[fwd_stage]
-                ) + 1
-                bwd_stage_mb_index[bwd_stage] = (
-                    bwd_mb_index := bwd_stage_mb_index[bwd_stage]
-                ) + 1
+            fwd_stage_mb_index[fwd_stage] = (
+                fwd_mb_index := fwd_stage_mb_index[fwd_stage]
+            ) + 1
+            bwd_stage_mb_index[bwd_stage] = (
+                bwd_mb_index := bwd_stage_mb_index[bwd_stage]
+            ) + 1
 
-                bwd_stage._configure_data_parallel_mode(
-                    bwd_mb_index == self._n_microbatches - 1
-                )
-                logger.debug(
-                    f"{self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
-                )
-                with record_function(f"1F1B {step}"):
-                    ops = fwd_stage.get_fwd_recv_ops()
-                    ops.extend(bwd_stage.get_bwd_recv_ops())
-                    works = sorted_batch_isend_irecv(ops)
-                    for work in works.values():
-                        work.wait()
+            bwd_stage._configure_data_parallel_mode(
+                bwd_mb_index == self._n_microbatches - 1
+            )
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {fwd_stage.stage_index=}, {bwd_stage.stage_index=}, {fwd_mb_index=}, {bwd_mb_index=}"
+            )
+            with record_function(f"1F1B {step}"):
+                ops.extend(fwd_stage.get_fwd_recv_ops())
+                ops.extend(bwd_stage.get_bwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
 
-                    # fwd
-                    output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
-                    ops = fwd_stage.get_fwd_send_ops()
-                    self._maybe_compute_loss(
-                        fwd_stage, output, target_mbs, fwd_mb_index
-                    )
+                # Forward
+                output = fwd_stage.forward_one_chunk(arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index])  # type: ignore[index]
+                ops.extend(fwd_stage.get_fwd_send_ops())
+                self._maybe_compute_loss(
+                    fwd_stage, output, target_mbs, fwd_mb_index
+                )
 
-                    # bwd
-                    loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
-                    bwd_stage.backward_one_chunk(loss=loss)
-                    ops.extend(bwd_stage.get_bwd_send_ops())
+                # Backward
+                loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
+                bwd_stage.backward_one_chunk(loss=loss)
+                ops.extend(bwd_stage.get_bwd_send_ops())
+
+        # Cooldown Phase (backward only)
+        for step in range(warmup_steps + fwd_bwd_steps, total_steps):
+            bwd_stage = self._stages[backward_stage_local_index(step)]
+            bwd_stage_mb_index[bwd_stage] = (
+                bwd_mb_index := bwd_stage_mb_index[bwd_stage]
+            ) + 1
+            bwd_stage._configure_data_parallel_mode(
+                bwd_mb_index == self._n_microbatches - 1
+            )
 
-                works = sorted_batch_isend_irecv(ops)
-                sends_to_wait.extend(works.values())
-
-            # cooldown
-            else:
-                bwd_stage = self._stages[backward_stage_local_index(step)]
-                bwd_stage_mb_index[bwd_stage] = (
-                    bwd_mb_index := bwd_stage_mb_index[bwd_stage]
-                ) + 1
-                bwd_stage._configure_data_parallel_mode(
-                    bwd_mb_index == self._n_microbatches - 1
-                )
-                logger.debug(
-                    f"{self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
-                )
-                with record_function(f"Cooldown (backward) {step}"):
-                    ops = bwd_stage.get_bwd_recv_ops()
-                    works = sorted_batch_isend_irecv(ops)
-                    for work in works.values():
-                        work.wait()
+            logger.debug(
+                f"Rank {self.rank}: {step=}, {bwd_stage.stage_index=}, {bwd_mb_index=}"
+            )
+            with record_function(f"Cooldown {step}"):
+                ops.extend(bwd_stage.get_bwd_recv_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    work.wait()
+                    ops.clear()
 
-                    loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
-                    bwd_stage.backward_one_chunk(loss=loss)
+                loss = self._maybe_get_loss(bwd_stage, bwd_mb_index)
+                bwd_stage.backward_one_chunk(loss=loss)
 
-                    ops = bwd_stage.get_bwd_send_ops()
-                    works = sorted_batch_isend_irecv(ops)
-                    sends_to_wait.extend(works.values())
+                ops.extend(bwd_stage.get_bwd_send_ops())
+                if ops:
+                    work = dist.batch_isend_irecv(ops).pop()
+                    sends_to_wait.append(work)
+                    ops.clear()
 
         # Make sure all sends are finished
         for work in sends_to_wait:
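All three phases share one communication pattern: accumulate `dist.P2POp`s in `ops`, flush them with a single `dist.batch_isend_irecv` call, then wait immediately on recvs but defer the wait for sends. The sketch below isolates that pattern with assumed tensor shapes and peers; it waits on every returned work (the schedule pops a single work, relying on the batch being coalesced) and uses the gloo backend only to keep it CPU-runnable. Launch with `torchrun --nproc_per_node=2`:

```python
from typing import List

import torch
import torch.distributed as dist


def main() -> None:
    dist.init_process_group("gloo")  # the schedule targets "nccl"; gloo keeps this CPU-only
    rank = dist.get_rank()

    ops: List[dist.P2POp] = []
    recv_buf = torch.empty(4)
    if rank == 0:
        ops.append(dist.P2POp(dist.isend, torch.arange(4.0), peer=1))
    elif rank == 1:
        ops.append(dist.P2POp(dist.irecv, recv_buf, peer=0))

    # Flush the accumulated batch in one call, then wait and reset,
    # mirroring the "if ops: ... ops.clear()" blocks in the schedule.
    if ops:
        for work in dist.batch_isend_irecv(ops):
            work.wait()
        ops.clear()

    if rank == 1:
        print(f"rank 1 received {recv_buf.tolist()}")
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
```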