66
66
"pd_op.sin" ,
67
67
"pd_op.cos" ,
68
68
"pd_op.add_n" ,
69
- "pd_op.any" ,
69
+ # "pd_op.any",
70
70
"pd_op.cast" ,
71
71
"pd_op.concat" ,
72
72
"pd_op.full_with_tensor" ,
80
80
"pd_op.slice" ,
81
81
"pd_op.squeeze" ,
82
82
"pd_op.unsqueeze" ,
83
- "pd_op.transpose" ,
83
+ # "pd_op.transpose",
84
84
# "pd_op.prod",
85
85
"pd_op.log" ,
86
86
"pd_op.log1p" ,
@@ -431,6 +431,22 @@ def auto_recompute(
431
431
432
432
fusible_ops = recomputable_ops | set (random_ops )
433
433
434
+ def _get_bw_no_need_buffer_values (program , backward_op_start_idx ):
435
+ need_buffer_values = backward_utils .ValueSet ()
436
+ all_values = backward_utils .ValueSet ()
437
+ for op in program .global_block ().ops [backward_op_start_idx :]:
438
+ for op_operand_source in op .operands_source ():
439
+ all_values .add (op_operand_source )
440
+ if op .is_no_need_buffer (op_operand_source ):
441
+ continue
442
+ need_buffer_values .add (op_operand_source )
443
+ bw_no_need_buffer_values = all_values - need_buffer_values
444
+ return bw_no_need_buffer_values
445
+
446
+ bw_no_need_buffer_values = _get_bw_no_need_buffer_values (
447
+ program , backward_op_start_idx
448
+ )
449
+
434
450
def _is_fusible (value_node1 , value_node2 ):
435
451
return (
436
452
value_node1 .get_defining_op ().name () in fusible_ops
@@ -442,7 +458,9 @@ def _is_materialized_backwards(value_node):
442
458
cur_value_nodes .add (value_node )
443
459
while len (cur_value_nodes ) > 0 :
444
460
cur_value_node = cur_value_nodes .pop ()
445
- users = find_value_node_users (cur_value_node )
461
+ users = find_value_node_users (
462
+ cur_value_node , bw_no_need_buffer_values , True
463
+ )
446
464
for user in users :
447
465
if user not in required_fw_value_nodes and not _is_fusible (
448
466
cur_value_node , user
@@ -458,34 +476,19 @@ def _is_materialized_backwards(value_node):
458
476
def _is_materialized (value_node , placeholder_value_nodes ):
459
477
if value_node in placeholder_value_nodes :
460
478
return True
461
- users = find_value_node_users (value_node )
479
+ users = find_value_node_users (
480
+ value_node , bw_no_need_buffer_values , True
481
+ )
462
482
return not all (_is_fusible (value_node , user ) for user in users )
463
483
464
- def _get_no_need_buffer_values_from_program (program ):
465
- need_buffer_values = backward_utils .ValueSet ()
466
- all_values = backward_utils .ValueSet ()
467
- for op in program .global_block ().ops :
468
- for op_operand_source in op .operands_source ():
469
- all_values .add (op_operand_source )
470
- if op .is_no_need_buffer (op_operand_source ):
471
- continue
472
- need_buffer_values .add (op_operand_source )
473
- no_need_buffer_values = all_values - need_buffer_values
474
- return no_need_buffer_values
475
-
476
- def _get_node_weight (
477
- value_node , no_need_buffer_values , placeholder_value_nodes
478
- ):
479
- if value_node in no_need_buffer_values :
480
- return MINIMUM_WEIGHT
481
-
484
+ def _get_node_weight (value_node , placeholder_value_nodes ):
482
485
mem_sz = cal_value_node_size (value_node )
483
486
484
487
if (
485
488
value_node .get_defining_op ().name () in tending_to_recompute_ops
486
489
and mem_sz == 0
487
490
):
488
- return 0.1
491
+ return MINIMUM_WEIGHT
489
492
490
493
# Heuristic to bias towards nodes closer to the backwards pass
491
494
mem_sz = int (
@@ -532,7 +535,6 @@ def _ban_recomputation(value_node):
532
535
533
536
judge_fusion_loop = JudgeFusionLoop (program , unrecomputable_ops )
534
537
forward_ops = set (program .global_block ().ops [: fwd_op_end_idx + 1 ])
535
- no_need_buffer_values = _get_no_need_buffer_values_from_program (program )
536
538
537
539
for value_node in (
538
540
required_fw_value_nodes
@@ -592,7 +594,6 @@ def _ban_recomputation(value_node):
592
594
593
595
weight = _get_node_weight (
594
596
value_node ,
595
- no_need_buffer_values ,
596
597
placeholder_value_nodes = inputs | outputs ,
597
598
)
598
599
@@ -602,7 +603,9 @@ def _ban_recomputation(value_node):
602
603
)
603
604
value_id_dict [value_node .id ] = value_node
604
605
605
- users = find_value_node_users (value_node )
606
+ users = find_value_node_users (
607
+ value_node , bw_no_need_buffer_values , True
608
+ )
606
609
for user in users :
607
610
DebugPrint (
608
611
"add edge link from: " ,
@@ -669,6 +672,7 @@ def _ban_recomputation(value_node):
669
672
saved_values ,
670
673
inputs ,
671
674
outputs ,
675
+ bw_no_need_buffer_values ,
672
676
fwd_op_end_idx ,
673
677
backward_op_start_idx ,
674
678
)
@@ -685,6 +689,7 @@ def partition_joint_graph(
685
689
saved_values : list [pir .Value ],
686
690
inputs : list [pir .Value ],
687
691
outputs : list [pir .Value ],
692
+ bw_no_need_buffer_values : list [pir .Value ],
688
693
fwd_op_end_idx : int ,
689
694
backward_op_start_idx : int ,
690
695
) -> tuple [paddle .static .Program , int ]:
@@ -715,6 +720,7 @@ def partition_joint_graph(
715
720
saved_values ,
716
721
inputs ,
717
722
outputs ,
723
+ bw_no_need_buffer_values ,
718
724
fwd_op_end_idx ,
719
725
backward_op_start_idx ,
720
726
)
@@ -917,7 +923,10 @@ def classify_value_node(program, grad_outputs, fwd_op_end_idx):
917
923
)
918
924
919
925
920
- def find_value_node_users (value_node ):
926
+ # Sometimes we need to discard no_need_buffer values because they're not REAL tensor users.
927
+ def find_value_node_users (
928
+ value_node , bw_no_need_buffer_values = {}, without_no_need_buffer = False
929
+ ):
921
930
'''
922
931
Find all the value nodes which use the same value node to be computed.
923
932
'''
@@ -939,6 +948,9 @@ def find_value_node_users(value_node):
939
948
else :
940
949
users .add (result )
941
950
else :
951
+ if without_no_need_buffer :
952
+ if value_node in bw_no_need_buffer_values :
953
+ continue
942
954
results = op .results ()
943
955
for result in results :
944
956
if len (result .all_used_ops ()) == 1 and result .all_used_ops ()[
@@ -1057,6 +1069,7 @@ def analyze_mid_hold_values(
1057
1069
saved_values ,
1058
1070
inputs ,
1059
1071
outputs ,
1072
+ no_need_buffer_values ,
1060
1073
fwd_op_end_idx ,
1061
1074
backward_op_start_idx ,
1062
1075
):
@@ -1067,10 +1080,11 @@ def analyze_mid_hold_values(
1067
1080
for result in op .results ():
1068
1081
all_used_ops = all_used_op_consider_combine (program , result )
1069
1082
if (
1070
- any (op in backward_ops for op in all_used_ops )
1083
+ any (used_op in backward_ops for used_op in all_used_ops )
1071
1084
and result not in saved_values
1072
1085
and result not in outputs
1073
1086
and result not in inputs
1087
+ and result not in no_need_buffer_values
1074
1088
):
1075
1089
mid_hold_values .add (result )
1076
1090
return mid_hold_values
0 commit comments