@@ -594,7 +594,9 @@ def forward_backward_pipeline(
             )

             self._record_stamp("F", step_id, '"B"', self._forward_color)
-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=step_id
+            )
             self._record_stamp("F", step_id, '"E"', self._forward_color)
             self._p2p_helper.send_forward(
                 output_tensor,
@@ -626,7 +628,9 @@ def forward_backward_pipeline(
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=startup_steps + i
+            )
             self._record_stamp(
                 "F", startup_steps + i, '"E"', self._forward_color
             )
@@ -649,7 +653,7 @@ def forward_backward_pipeline(

             self._record_stamp("B", i, '"B"', self._backward_color)
             input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor, output_tensor, output_tensor_grad, step_id=i
             )
             self._record_stamp("B", i, '"E"', self._backward_color)

@@ -684,7 +688,10 @@ def forward_backward_pipeline(
                 "B", steady_steps + i, '"B"', self._backward_color
             )
             input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                step_id=steady_steps + i,
             )
             self._record_stamp(
                 "B", steady_steps + i, '"E"', self._backward_color
@@ -844,7 +851,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 self.is_pipeline_first_stage()
             )

-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=None
+            )
             self._p2p_helper.send_forward(
                 output_tensor,
                 self.is_pipeline_last_stage(),
@@ -862,7 +871,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=None
+            )
             self._p2p_helper.send_forward(
                 output_tensor,
                 self.is_pipeline_last_stage(),
@@ -884,7 +895,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):

         return self.train_loss

-    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
+    def _forward_step(
+        self, input_tensor, micro_dataset, chunk_id=None, step_id=None
+    ):
         if self._enable_timer:
             self.timers("forward_step").start()
         if self.is_pipeline_first_stage():
@@ -893,7 +906,18 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):

         assert chunk_id is None or isinstance(chunk_id, int)

+        self.callbacks.on_location(
+            PipelineParallelMicroStepLocations.FORWARD_BEGIN,
+            input_tensor=input_tensor,
+            step_id=step_id,
+        )
         output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)
+        self.callbacks.on_location(
+            PipelineParallelMicroStepLocations.FORWARD_END,
+            input_tensor=input_tensor,
+            output_tensor=output_tensor,
+            step_id=step_id,
+        )

         if self.is_pipeline_last_stage():
             # train calculate loss for train
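The hunk above fires the new micro-step hooks around the layer forward. As a minimal sketch of a consumer, assuming only what the diff shows (handlers receive the location enum member plus keyword arguments such as input_tensor, output_tensor, and step_id; the ShapeLoggingCallback class itself is hypothetical, not part of this PR):

class ShapeLoggingCallback:
    """Logs output shapes at FORWARD_END and ignores other locations."""

    def on_location(self, location, **kwargs):
        # Each hook site passes only the tensors relevant to it, so read
        # kwargs defensively rather than assuming every key is present.
        output_tensor = kwargs.get("output_tensor")
        step_id = kwargs.get("step_id")
        if output_tensor is not None:
            print(f"{location}: step_id={step_id}, shape={output_tensor.shape}")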
@@ -935,10 +959,19 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
             return backward_loss_tensor
         return output_tensor

-    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+    def _backward_step(
+        self, input_tensor, output_tensor, output_tensor_grad, step_id=None
+    ):
         if self._enable_timer:
             self.timers("backward_step").start()
         with paddle.amp.auto_cast(enable=False):
+            self.callbacks.on_location(
+                PipelineParallelMicroStepLocations.BACKWARD_BEGIN,
+                input_tensor=input_tensor,
+                output_tensor=output_tensor,
+                output_tensor_grad=output_tensor_grad,
+                step_id=step_id,
+            )
             if self.is_pipeline_last_stage():
                 assert output_tensor_grad is None
                 if self.scaler:
@@ -969,6 +1002,14 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
                 input_tensor_grad = input_tensor.grad
             if self._enable_timer:
                 self.timers("backward_step").stop()
+            self.callbacks.on_location(
+                PipelineParallelMicroStepLocations.BACKWARD_END,
+                input_tensor=input_tensor,
+                output_tensor=output_tensor,
+                input_tensor_grad=input_tensor_grad,
+                output_tensor_grad=output_tensor_grad,
+                step_id=step_id,
+            )
             return input_tensor_grad

     def _check_micro_batch_data_valid(self, micro_batch_data):
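For orientation, the callbacks container used at these call sites only needs to fan each on_location event out to its registered handlers. An illustrative sketch under that assumption (the real Paddle class may differ):

class CallbackList:
    """Dispatches each on_location event to every registered handler."""

    def __init__(self, callbacks=None):
        self._callbacks = list(callbacks or [])

    def register(self, callback):
        self._callbacks.append(callback)

    def on_location(self, location, **kwargs):
        for cb in self._callbacks:
            handler = getattr(cb, "on_location", None)
            if handler is not None:
                handler(location, **kwargs)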
@@ -1217,7 +1258,7 @@ def _forward_step_helper(self, micro_dataset, micro_step):
             )
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
         output_tensor = self._forward_step(
-            input_tensor, micro_dataset, virtual_pp_rank
+            input_tensor, micro_dataset, virtual_pp_rank, step_id=micro_step
         )
         self.output_tensors[virtual_pp_rank].append(output_tensor)

@@ -1281,7 +1322,7 @@ def _backward_step_helper(self, micro_step):
         output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
         output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
         input_tensor_grad = self._backward_step(
-            input_tensor, output_tensor, output_tensor_grad
+            input_tensor, output_tensor, output_tensor_grad, step_id=micro_step
         )

         self._overlap_comm_grads()
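Putting it together, a hypothetical end-to-end use of the hooks: pair each *_BEGIN event with its matching *_END to time individual micro-steps. Only the location names and the step_id keyword come from this diff; the timer class and the registration call are assumptions for illustration:

import time

class MicroStepTimer:
    """Records wall-clock duration between matching BEGIN/END events."""

    def __init__(self):
        self._starts = {}
        self.durations = []  # (end_location_name, step_id, seconds)

    def on_location(self, location, **kwargs):
        step_id = kwargs.get("step_id")
        name = getattr(location, "name", str(location))
        if name.endswith("_BEGIN"):
            # Key by phase ("FORWARD"/"BACKWARD") plus micro-step index.
            self._starts[(name[: -len("_BEGIN")], step_id)] = time.perf_counter()
        elif name.endswith("_END"):
            start = self._starts.pop((name[: -len("_END")], step_id), None)
            if start is not None:
                self.durations.append(
                    (name, step_id, time.perf_counter() - start)
                )

# Hypothetical registration, assuming a register() method on the model's
# callback list:
#   model.callbacks.register(MicroStepTimer())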