Commit ca070c5

fix pp hook

1 parent 50861cf commit ca070c5
File tree

1 file changed: +51 −10 lines changed

python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py

Lines changed: 51 additions & 10 deletions
@@ -594,7 +594,9 @@ def forward_backward_pipeline(
             )

             self._record_stamp("F", step_id, '"B"', self._forward_color)
-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=step_id
+            )
             self._record_stamp("F", step_id, '"E"', self._forward_color)
             self._p2p_helper.send_forward(
                 output_tensor,
@@ -626,7 +628,9 @@ def forward_backward_pipeline(
             self._record_stamp(
                 "F", startup_steps + i, '"B"', self._forward_color
             )
-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=startup_steps + i
+            )
             self._record_stamp(
                 "F", startup_steps + i, '"E"', self._forward_color
             )
@@ -649,7 +653,7 @@ def forward_backward_pipeline(

             self._record_stamp("B", i, '"B"', self._backward_color)
             input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor, output_tensor, output_tensor_grad, step_id=i
             )
             self._record_stamp("B", i, '"E"', self._backward_color)

@@ -684,7 +688,10 @@ def forward_backward_pipeline(
                 "B", steady_steps + i, '"B"', self._backward_color
             )
             input_tensor_grad = self._backward_step(
-                input_tensor, output_tensor, output_tensor_grad
+                input_tensor,
+                output_tensor,
+                output_tensor_grad,
+                step_id=steady_steps + i,
             )
             self._record_stamp(
                 "B", steady_steps + i, '"E"', self._backward_color
@@ -844,7 +851,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
                 self.is_pipeline_first_stage()
             )

-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=None
+            )
             self._p2p_helper.send_forward(
                 output_tensor,
                 self.is_pipeline_last_stage(),
@@ -862,7 +871,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):
         for i in range(steady_steps):
             last_iter = i == (steady_steps - 1)

-            output_tensor = self._forward_step(input_tensor, micro_dataset)
+            output_tensor = self._forward_step(
+                input_tensor, micro_dataset, step_id=None
+            )
             self._p2p_helper.send_forward(
                 output_tensor,
                 self.is_pipeline_last_stage(),
@@ -884,7 +895,9 @@ def eval_batch(self, data, compute_loss=False, loss_fn_idx=0):

         return self.train_loss

-    def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
+    def _forward_step(
+        self, input_tensor, micro_dataset, chunk_id=None, step_id=None
+    ):
         if self._enable_timer:
             self.timers("forward_step").start()
         if self.is_pipeline_first_stage():
@@ -893,7 +906,18 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):

         assert chunk_id is None or isinstance(chunk_id, int)

+        self.callbacks.on_location(
+            PipelineParallelMicroStepLocations.FORWARD_BEGIN,
+            input_tensor=input_tensor,
+            step_id=step_id,
+        )
         output_tensor = self._layers.forward(input_tensor, chunk_id=chunk_id)
+        self.callbacks.on_location(
+            PipelineParallelMicroStepLocations.FORWARD_END,
+            input_tensor=input_tensor,
+            output_tensor=output_tensor,
+            step_id=step_id,
+        )

         if self.is_pipeline_last_stage():
             # train calculate loss for train
@@ -935,10 +959,19 @@ def _forward_step(self, input_tensor, micro_dataset, chunk_id=None):
             return backward_loss_tensor
         return output_tensor

-    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+    def _backward_step(
+        self, input_tensor, output_tensor, output_tensor_grad, step_id=None
+    ):
         if self._enable_timer:
             self.timers("backward_step").start()
         with paddle.amp.auto_cast(enable=False):
+            self.callbacks.on_location(
+                PipelineParallelMicroStepLocations.BACKWARD_BEGIN,
+                input_tensor=input_tensor,
+                output_tensor=output_tensor,
+                output_tensor_grad=output_tensor_grad,
+                step_id=step_id,
+            )
             if self.is_pipeline_last_stage():
                 assert output_tensor_grad is None
                 if self.scaler:
@@ -969,6 +1002,14 @@ def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
                     input_tensor_grad = input_tensor.grad
             if self._enable_timer:
                 self.timers("backward_step").stop()
+            self.callbacks.on_location(
+                PipelineParallelMicroStepLocations.BACKWARD_END,
+                input_tensor=input_tensor,
+                output_tensor=output_tensor,
+                input_tensor_grad=input_tensor_grad,
+                output_tensor_grad=output_tensor_grad,
+                step_id=step_id,
+            )
             return input_tensor_grad

     def _check_micro_batch_data_valid(self, micro_batch_data):
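For context on what the new hooks expose, here is a minimal sketch of a callback that could consume them. The diff only shows that self.callbacks.on_location(...) is invoked with a PipelineParallelMicroStepLocations value plus keyword tensors and a step_id; the Location enum stand-in, the MicroStepTimer class, and any registration mechanism below are assumptions for illustration, not Paddle's actual API.

import time
from enum import Enum


# Stand-in for PipelineParallelMicroStepLocations so this sketch runs
# without Paddle; the real enum is whatever the hooks module defines.
class Location(Enum):
    FORWARD_BEGIN = "forward_begin"
    FORWARD_END = "forward_end"
    BACKWARD_BEGIN = "backward_begin"
    BACKWARD_END = "backward_end"


class MicroStepTimer:
    """Hypothetical callback: times each forward/backward micro-step."""

    def __init__(self):
        self._starts = {}
        self.records = []  # (phase, step_id, seconds)

    def on_location(self, location, **kwargs):
        step_id = kwargs.get("step_id")
        now = time.perf_counter()
        if location in (Location.FORWARD_BEGIN, Location.BACKWARD_BEGIN):
            # Remember when this micro-step's forward/backward started.
            self._starts[(location, step_id)] = now
        elif location is Location.FORWARD_END:
            begin = self._starts.pop((Location.FORWARD_BEGIN, step_id), None)
            if begin is not None:
                self.records.append(("F", step_id, now - begin))
        elif location is Location.BACKWARD_END:
            begin = self._starts.pop((Location.BACKWARD_BEGIN, step_id), None)
            if begin is not None:
                self.records.append(("B", step_id, now - begin))

With this commit, such a callback receives input_tensor/output_tensor (and the gradients at the backward locations) together with step_id, so per-micro-step profiling or debugging logic can key off the micro-step index instead of only the location.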
@@ -1217,7 +1258,7 @@ def _forward_step_helper(self, micro_dataset, micro_step):
             )
         input_tensor = self.input_tensors[virtual_pp_rank][-1]
         output_tensor = self._forward_step(
-            input_tensor, micro_dataset, virtual_pp_rank
+            input_tensor, micro_dataset, virtual_pp_rank, step_id=micro_step
         )
         self.output_tensors[virtual_pp_rank].append(output_tensor)

@@ -1281,7 +1322,7 @@ def _backward_step_helper(self, micro_step):
         output_tensor = self.output_tensors[virtual_pp_rank].pop(0)
         output_tensor_grad = self.output_tensor_grads[virtual_pp_rank].pop(0)
         input_tensor_grad = self._backward_step(
-            input_tensor, output_tensor, output_tensor_grad
+            input_tensor, output_tensor, output_tensor_grad, step_id=micro_step
         )

         self._overlap_comm_grads()
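To make the step_id bookkeeping concrete, the toy loop below replays the order and indices with which the 1F1B schedule above fires the hooks. The loop bounds mirror the warmup/steady/cooldown structure visible in the hunks of forward_backward_pipeline, but the concrete numbers (4 micro-batches, 2 warmup steps) are made up for illustration and no tensors or p2p communication are involved.

# Replay of the step_id values passed to the FORWARD_* / BACKWARD_* hooks
# by the 1F1B schedule in this diff (indices only, no real pipeline work).
num_micro_batches = 4   # illustrative
startup_steps = 2       # illustrative warmup length for this stage
steady_steps = num_micro_batches - startup_steps

events = []

# Warmup: forward-only micro-steps, hooks see step_id = step_id.
for step_id in range(startup_steps):
    events.append(("FORWARD", step_id))

# Steady 1F1B: each iteration runs one forward (startup_steps + i)
# and one backward (i).
for i in range(steady_steps):
    events.append(("FORWARD", startup_steps + i))
    events.append(("BACKWARD", i))

# Cooldown: backward-only micro-steps, hooks see steady_steps + i.
for i in range(startup_steps):
    events.append(("BACKWARD", steady_steps + i))

for phase, step_id in events:
    print(phase, step_id)

eval_batch still fires the forward hooks but passes step_id=None, and the interleaved helpers (_forward_step_helper / _backward_step_helper) forward their global micro_step as step_id.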
