Skip to content

Commit ca3d34d

Browse files
authored
[CINN] Fix no_need_buffer bug (#70349)
* fix no_need_buffer bug
* update
1 parent 6a4f37f commit ca3d34d

File tree

1 file changed

+42
-28
lines changed

1 file changed

+42
-28
lines changed

python/paddle/decomposition/recompute.py

Lines changed: 42 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@
6666
"pd_op.sin",
6767
"pd_op.cos",
6868
"pd_op.add_n",
69-
"pd_op.any",
69+
# "pd_op.any",
7070
"pd_op.cast",
7171
"pd_op.concat",
7272
"pd_op.full_with_tensor",
@@ -80,7 +80,7 @@
8080
"pd_op.slice",
8181
"pd_op.squeeze",
8282
"pd_op.unsqueeze",
83-
"pd_op.transpose",
83+
# "pd_op.transpose",
8484
# "pd_op.prod",
8585
"pd_op.log",
8686
"pd_op.log1p",
@@ -431,6 +431,22 @@ def auto_recompute(
431431

432432
fusible_ops = recomputable_ops | set(random_ops)
433433

434+
def _get_bw_no_need_buffer_values(program, backward_op_start_idx):
435+
need_buffer_values = backward_utils.ValueSet()
436+
all_values = backward_utils.ValueSet()
437+
for op in program.global_block().ops[backward_op_start_idx:]:
438+
for op_operand_source in op.operands_source():
439+
all_values.add(op_operand_source)
440+
if op.is_no_need_buffer(op_operand_source):
441+
continue
442+
need_buffer_values.add(op_operand_source)
443+
bw_no_need_buffer_values = all_values - need_buffer_values
444+
return bw_no_need_buffer_values
445+
446+
bw_no_need_buffer_values = _get_bw_no_need_buffer_values(
447+
program, backward_op_start_idx
448+
)
449+
434450
def _is_fusible(value_node1, value_node2):
435451
return (
436452
value_node1.get_defining_op().name() in fusible_ops
@@ -442,7 +458,9 @@ def _is_materialized_backwards(value_node):
442458
cur_value_nodes.add(value_node)
443459
while len(cur_value_nodes) > 0:
444460
cur_value_node = cur_value_nodes.pop()
445-
users = find_value_node_users(cur_value_node)
461+
users = find_value_node_users(
462+
cur_value_node, bw_no_need_buffer_values, True
463+
)
446464
for user in users:
447465
if user not in required_fw_value_nodes and not _is_fusible(
448466
cur_value_node, user
@@ -458,34 +476,19 @@ def _is_materialized_backwards(value_node):
458476
def _is_materialized(value_node, placeholder_value_nodes):
459477
if value_node in placeholder_value_nodes:
460478
return True
461-
users = find_value_node_users(value_node)
479+
users = find_value_node_users(
480+
value_node, bw_no_need_buffer_values, True
481+
)
462482
return not all(_is_fusible(value_node, user) for user in users)
463483

464-
def _get_no_need_buffer_values_from_program(program):
465-
need_buffer_values = backward_utils.ValueSet()
466-
all_values = backward_utils.ValueSet()
467-
for op in program.global_block().ops:
468-
for op_operand_source in op.operands_source():
469-
all_values.add(op_operand_source)
470-
if op.is_no_need_buffer(op_operand_source):
471-
continue
472-
need_buffer_values.add(op_operand_source)
473-
no_need_buffer_values = all_values - need_buffer_values
474-
return no_need_buffer_values
475-
476-
def _get_node_weight(
477-
value_node, no_need_buffer_values, placeholder_value_nodes
478-
):
479-
if value_node in no_need_buffer_values:
480-
return MINIMUM_WEIGHT
481-
484+
def _get_node_weight(value_node, placeholder_value_nodes):
482485
mem_sz = cal_value_node_size(value_node)
483486

484487
if (
485488
value_node.get_defining_op().name() in tending_to_recompute_ops
486489
and mem_sz == 0
487490
):
488-
return 0.1
491+
return MINIMUM_WEIGHT
489492

490493
# Heuristic to bias towards nodes closer to the backwards pass
491494
mem_sz = int(
@@ -532,7 +535,6 @@ def _ban_recomputation(value_node):
532535

533536
judge_fusion_loop = JudgeFusionLoop(program, unrecomputable_ops)
534537
forward_ops = set(program.global_block().ops[: fwd_op_end_idx + 1])
535-
no_need_buffer_values = _get_no_need_buffer_values_from_program(program)
536538

537539
for value_node in (
538540
required_fw_value_nodes
@@ -592,7 +594,6 @@ def _ban_recomputation(value_node):
592594

593595
weight = _get_node_weight(
594596
value_node,
595-
no_need_buffer_values,
596597
placeholder_value_nodes=inputs | outputs,
597598
)
598599

@@ -602,7 +603,9 @@ def _ban_recomputation(value_node):
602603
)
603604
value_id_dict[value_node.id] = value_node
604605

605-
users = find_value_node_users(value_node)
606+
users = find_value_node_users(
607+
value_node, bw_no_need_buffer_values, True
608+
)
606609
for user in users:
607610
DebugPrint(
608611
"add edge link from: ",
@@ -669,6 +672,7 @@ def _ban_recomputation(value_node):
669672
saved_values,
670673
inputs,
671674
outputs,
675+
bw_no_need_buffer_values,
672676
fwd_op_end_idx,
673677
backward_op_start_idx,
674678
)
@@ -685,6 +689,7 @@ def partition_joint_graph(
685689
saved_values: list[pir.Value],
686690
inputs: list[pir.Value],
687691
outputs: list[pir.Value],
692+
bw_no_need_buffer_values: list[pir.Value],
688693
fwd_op_end_idx: int,
689694
backward_op_start_idx: int,
690695
) -> tuple[paddle.static.Program, int]:
@@ -715,6 +720,7 @@ def partition_joint_graph(
715720
saved_values,
716721
inputs,
717722
outputs,
723+
bw_no_need_buffer_values,
718724
fwd_op_end_idx,
719725
backward_op_start_idx,
720726
)
@@ -917,7 +923,10 @@ def classify_value_node(program, grad_outputs, fwd_op_end_idx):
917923
)
918924

919925

920-
def find_value_node_users(value_node):
926+
# Sometimes we need to discard no_need_buffer values because they're not REAL tensor users.
927+
def find_value_node_users(
928+
value_node, bw_no_need_buffer_values={}, without_no_need_buffer=False
929+
):
921930
'''
922931
Find all the value nodes which use the same value node to be computed.
923932
'''
@@ -939,6 +948,9 @@ def find_value_node_users(value_node):
939948
else:
940949
users.add(result)
941950
else:
951+
if without_no_need_buffer:
952+
if value_node in bw_no_need_buffer_values:
953+
continue
942954
results = op.results()
943955
for result in results:
944956
if len(result.all_used_ops()) == 1 and result.all_used_ops()[
@@ -1057,6 +1069,7 @@ def analyze_mid_hold_values(
10571069
saved_values,
10581070
inputs,
10591071
outputs,
1072+
no_need_buffer_values,
10601073
fwd_op_end_idx,
10611074
backward_op_start_idx,
10621075
):
@@ -1067,10 +1080,11 @@ def analyze_mid_hold_values(
10671080
for result in op.results():
10681081
all_used_ops = all_used_op_consider_combine(program, result)
10691082
if (
1070-
any(op in backward_ops for op in all_used_ops)
1083+
any(used_op in backward_ops for used_op in all_used_ops)
10711084
and result not in saved_values
10721085
and result not in outputs
10731086
and result not in inputs
1087+
and result not in no_need_buffer_values
10741088
):
10751089
mid_hold_values.add(result)
10761090
return mid_hold_values

0 commit comments

Comments
 (0)