@@ -5715,6 +5715,35 @@ def _mv_head_recv(self, program):
57155715 backward_insert_index += 1
57165716 block ._sync_with_cpp ()
57175717
def _check_pipeline_persist_var(self, program):
    """
    Warn when a persistable var written by a forward op is later read by
    a backward op.

    The ops of the program's global block are visited in order. Names of
    persistable outputs of forward ops are collected; any backward op
    visited afterwards that takes one of those names as an input marks
    that name as suspicious. As the warning text explains, the pipeline
    runs several forward calculations before backward, so such a var may
    already have been overwritten by the time backward reads it.

    Args:
        program: object exposing ``global_block()``; the returned block
            must provide ``ops`` (iterable of ops) and ``vars`` (mapping
            from var name to a var with a ``persistable`` attribute).

    Returns:
        None. A warning is issued via ``warnings.warn`` only when at
        least one such var is found.

    Raises:
        KeyError: if a forward op output name is missing from
            ``block.vars``.
    """
    global_block = program.global_block()

    written_in_forward = set()
    read_in_backward = set()
    for op in global_block.ops:
        if self._is_forward_op(op):
            # Remember every persistable var this forward op writes.
            written_in_forward.update(
                name for name in op.output_arg_names
                if global_block.vars[name].persistable)
            continue
        if self._is_backward_op(op):
            # Only vars already written by an earlier forward op count.
            read_in_backward.update(
                name for name in op.input_arg_names
                if name in written_in_forward)

    if not read_in_backward:
        return

    message = (
        "The pipeline requires multiple forward calculations before backward, "
        "so when the persistable var is changed in the forward, it may cause "
        "errors in the backward calculation who using this persistable var. "
        "However, some backward op don't need this var(NoNeedBufferVars), "
        "there will be no error at this time.\n "
        "So please check these persistable vars which changed in "
        "forward and used in backward:\n {}")
    warnings.warn(message.format(read_in_backward))
5746+
57185747 def minimize (self ,
57195748 loss ,
57205749 startup_program = None ,
@@ -5831,6 +5860,11 @@ def device_cmp(device1, device2):
58315860 # A pass to move the recv op to the beginning of
58325861 # the forward/backward phase
58335862 self ._mv_head_recv (program_list [self .local_rank ])
5863+
5864+ # A pass to check pipeline persist var which changed in
5865+ # forward and used in backward
5866+ self ._check_pipeline_persist_var (program_list [self .local_rank ])
5867+
58345868 main_program ._pipeline_opt = {
58355869 "trainer" : "PipelineTrainer" ,
58365870 "device_worker" : "Section" ,
0 commit comments