Skip to content

Commit 649bf3b

Browse files
【dygraph】Split backward api from api.h to backward_api.h (#71142)
* add gard_api inplace version * split api.h and backward_api.h * modify build wrong * modify build bug * change backward_api_yaml args * change backward_api_yaml args * modify impl to base
1 parent d58f2c6 commit 649bf3b

26 files changed

+284
-87
lines changed

.gitignore

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,10 @@ paddle/fluid/op_use_default_grad_maker_PR.spec
88
paddle/fluid/operators/ops_extra_info.cc
99
paddle/phi/api/backward/backward_api.h
1010
paddle/phi/api/backward/fused_backward_api.h
11-
paddle/phi/api/backward/sparse_bw_api.h
11+
paddle/phi/api/backward/sparse_backward_api.h
12+
paddle/phi/api/backward/backward_api_base.h
13+
paddle/phi/api/backward/fused_backward_api_base.h
14+
paddle/phi/api/backward/sparse_backward_api_base.h
1215
paddle/phi/api/include/api.h
1316
paddle/phi/api/include/fused_api.h
1417
paddle/phi/api/include/operants_base.h
@@ -19,12 +22,15 @@ paddle/phi/api/include/tensor_operants.h
1922
paddle/phi/api/lib/api.cc
2023
paddle/phi/api/lib/fused_api.cc
2124
paddle/phi/api/lib/dygraph_api.*
22-
paddle/phi/api/lib/backward_api.cc
23-
paddle/phi/api/lib/fused_backward_api.cc
2425
paddle/phi/api/lib/operants_manager.cc
2526
paddle/phi/api/lib/sparse_api.cc
2627
paddle/phi/api/lib/strings_api.cc
27-
paddle/phi/api/lib/sparse_bw_api.cc
28+
paddle/phi/api/lib/backward_api.cc
29+
paddle/phi/api/lib/fused_backward_api.cc
30+
paddle/phi/api/lib/sparse_backward_api.cc
31+
paddle/phi/api/lib/backward_api_base.cc
32+
paddle/phi/api/lib/fused_backward_api_base.cc
33+
paddle/phi/api/lib/sparse_backward_api_base.cc
2834
paddle/phi/api/lib/tensor_api.cc
2935
paddle/phi/api/lib/tensor_operants.cc
3036
paddle/phi/extension.h

paddle/fluid/eager/api/generated/eager_generated/forwards/CMakeLists.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@ cc_library(
66
if(NOT (NOT WITH_PYTHON AND ON_INFER))
77
cc_library(
88
final_dygraph_function
9-
SRCS dygraph_functions.cc ${eager_manual_functions}
9+
SRCS dygraph_functions.cc dygraph_grad_functions.cc
10+
${eager_manual_functions}
1011
DEPS ${eager_deps} final_dygraph_node)
1112
add_dependencies(final_dygraph_function eager_codegen)
1213
endif()

paddle/fluid/eager/api/manual/eager_manual/nodes/conv2d_nodes.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020
#include "paddle/fluid/framework/op_registry.h"
2121
#include "paddle/fluid/imperative/tracer.h"
2222
#include "paddle/phi/api/all.h"
23-
#include "paddle/phi/api/backward/backward_api.h"
24-
#include "paddle/phi/api/backward/sparse_bw_api.h"
23+
#include "paddle/phi/api/backward/backward_api_base.h"
24+
#include "paddle/phi/api/backward/sparse_backward_api_base.h"
2525
#include "paddle/phi/core/platform/profiler/event_tracing.h"
2626

2727
#include "paddle/common/flags.h"

paddle/fluid/eager/api/manual/eager_manual/nodes/multiply_node.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,8 @@
2525
#include "paddle/fluid/prim/api/composite_backward/composite_backward_api.h"
2626
#include "paddle/fluid/prim/utils/utils.h"
2727
#include "paddle/phi/api/all.h"
28-
#include "paddle/phi/api/backward/backward_api.h"
29-
#include "paddle/phi/api/backward/sparse_bw_api.h"
28+
#include "paddle/phi/api/backward/backward_api_base.h"
29+
#include "paddle/phi/api/backward/sparse_backward_api_base.h"
3030
#include "paddle/phi/api/include/sparse_api.h"
3131
#include "paddle/phi/api/lib/api_custom_impl.h"
3232
#include "paddle/phi/core/platform/profiler/event_tracing.h"

paddle/fluid/eager/api/manual/eager_manual/nodes/sync_batch_norm_node.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,8 @@
2222
#include "paddle/fluid/framework/op_registry.h"
2323
#include "paddle/fluid/imperative/tracer.h"
2424
#include "paddle/phi/api/all.h"
25-
#include "paddle/phi/api/backward/backward_api.h"
26-
#include "paddle/phi/api/backward/sparse_bw_api.h"
25+
#include "paddle/phi/api/backward/backward_api_base.h"
26+
#include "paddle/phi/api/backward/sparse_backward_api_base.h"
2727
#include "paddle/phi/api/include/sparse_api.h"
2828
#include "paddle/phi/api/lib/api_custom_impl.h"
2929
#include "paddle/phi/core/platform/profiler/event_tracing.h"

paddle/fluid/eager/auto_code_generator/generate_file_structures.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@ def GenerateFileStructureForFinalDygraph(eager_dir):
2626
| |- forwards
2727
| |- "dygraph_functions.cc"
2828
| |- "dygraph_functions.h"
29+
| |- "dygraph_grad_functions.cc"
30+
| |- "dygraph_grad_functions.h"
2931
|
3032
| |- backwards
3133
| |- "nodes.cc"
@@ -42,10 +44,12 @@ def GenerateFileStructureForFinalDygraph(eager_dir):
4244

4345
# Empty files
4446
dygraph_forward_api_h_path = os.path.join(
45-
generated_dir, "dygraph_functions.h"
47+
forwards_dir, "dygraph_functions.h"
4648
)
4749
empty_files = [dygraph_forward_api_h_path]
50+
empty_files.append(os.path.join(forwards_dir, "dygraph_grad_functions.h"))
4851
empty_files.append(os.path.join(forwards_dir, "dygraph_functions.cc"))
52+
empty_files.append(os.path.join(forwards_dir, "dygraph_grad_functions.cc"))
4953
empty_files.append(os.path.join(nodes_dir, "nodes.cc"))
5054
empty_files.append(os.path.join(nodes_dir, "nodes.h"))
5155

paddle/fluid/eager/auto_code_generator/generator/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,12 @@ set(tmp_forwards_cc_path
1111
set(tmp_forwards_h_path
1212
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_functions.h"
1313
)
14+
set(tmp_backwards_cc_path
15+
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_grad_functions.cc"
16+
)
17+
set(tmp_backwards_h_path
18+
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/tmp_dygraph_grad_functions.h"
19+
)
1420
set(tmp_nodes_cc_path
1521
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/tmp_nodes.cc"
1622
)
@@ -23,6 +29,12 @@ set(forwards_cc_path
2329
set(forwards_h_path
2430
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
2531
)
32+
set(backwards_cc_path
33+
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_grad_functions.cc"
34+
)
35+
set(backwards_h_path
36+
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_grad_functions.h"
37+
)
2638
set(nodes_cc_path
2739
"${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.cc"
2840
)
@@ -43,11 +55,17 @@ add_custom_target(
4355
"--backward_yaml_path=${backward_yaml_path}"
4456
"--forwards_cc_path=${tmp_forwards_cc_path}"
4557
"--forwards_h_path=${tmp_forwards_h_path}"
58+
"--backwards_cc_path=${tmp_backwards_cc_path}"
59+
"--backwards_h_path=${tmp_backwards_h_path}"
4660
"--nodes_cc_path=${tmp_nodes_cc_path}" "--nodes_h_path=${tmp_nodes_h_path}"
4761
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_cc_path}
4862
${forwards_cc_path}
4963
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_forwards_h_path}
5064
${forwards_h_path}
65+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_backwards_cc_path}
66+
${backwards_cc_path}
67+
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_backwards_h_path}
68+
${backwards_h_path}
5169
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_cc_path}
5270
${nodes_cc_path}
5371
COMMAND ${CMAKE_COMMAND} -E copy_if_different ${tmp_nodes_h_path}

paddle/fluid/eager/auto_code_generator/generator/eager_gen.py

Lines changed: 55 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,8 @@ def ParseArguments():
237237
parser.add_argument('--nodes_cc_path', type=str)
238238
parser.add_argument('--forwards_h_path', type=str)
239239
parser.add_argument('--forwards_cc_path', type=str)
240+
parser.add_argument('--backwards_h_path', type=str)
241+
parser.add_argument('--backwards_cc_path', type=str)
240242
parser.add_argument('--api_yaml_path', type=str)
241243
parser.add_argument('--backward_yaml_path', type=str)
242244

@@ -546,9 +548,9 @@ class {} : public egr::GradNodeBase {{
546548
NODE_CC_FILE_TEMPLATE = """
547549
#include "glog/logging.h"
548550
#include "paddle/phi/api/all.h"
549-
#include "paddle/phi/api/backward/backward_api.h"
550-
#include "paddle/phi/api/backward/fused_backward_api.h"
551-
#include "paddle/phi/api/backward/sparse_bw_api.h"
551+
#include "paddle/phi/api/backward/backward_api_base.h"
552+
#include "paddle/phi/api/backward/fused_backward_api_base.h"
553+
#include "paddle/phi/api/backward/sparse_backward_api_base.h"
552554
#include "paddle/fluid/imperative/tracer.h"
553555
#include "paddle/fluid/framework/op_registry.h"
554556
#include "paddle/phi/core/platform/profiler/event_tracing.h"
@@ -578,19 +580,28 @@ class {} : public egr::GradNodeBase {{
578580
579581
{}
580582
"""
583+
FORWARD_CC_HEADER = """
584+
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
585+
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
586+
"""
581587

588+
BACKWARD_CC_HEADER = """
589+
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_grad_functions.h"
590+
#include "paddle/phi/api/backward/sparse_backward_api.h"
591+
#include "paddle/phi/api/backward/fused_backward_api.h"
592+
#include "paddle/phi/api/backward/backward_api.h"
593+
"""
582594
FORWARD_CC_FILE_TEMPLATE = """
583595
#include "paddle/phi/api/lib/dygraph_api.h"
584-
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
585596
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
586597
#include "paddle/fluid/eager/eager_layout_auto_tune.h"
587598
#include "paddle/phi/api/include/strings_api.h"
588-
#include "paddle/phi/api/include/sparse_api.h"
599+
589600
#include "paddle/fluid/eager/api/utils/global_utils.h"
590601
#include "paddle/phi/core/platform/profiler/event_tracing.h"
591602
#include "paddle/phi/backends/gpu/gpu_info.h"
592603
#include "paddle/fluid/eager/nan_inf_utils.h"
593-
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
604+
594605
#include "paddle/common/flags.h"
595606
#include "paddle/phi/api/lib/data_transform.h"
596607
#include "paddle/fluid/eager/type_promotion_utils.h"
@@ -612,9 +623,8 @@ class {} : public egr::GradNodeBase {{
612623
#include "paddle/phi/api/all.h"
613624
#include "paddle/fluid/eager/utils.h"
614625
#include "paddle/fluid/framework/op_registry.h"
615-
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
616626
#include "paddle/utils/test_macros.h"
617-
627+
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
618628
using CPUPlace = phi::CPUPlace;
619629
{}
620630
{}
@@ -3299,23 +3309,33 @@ def GenerateNodeHFile(filepath, node_declaration_str):
32993309
f.write(file_contents)
33003310

33013311

3302-
def GenerateForwardCCFile(filepath, forward_definition_str):
3312+
def GenerateForwardCCFile(filepath, forward_definition_str, grad_flag):
33033313
if os.path.exists(filepath):
33043314
os.remove(filepath)
33053315

3306-
core_ops_info_str = GenerateCoreOpInfoDefinition()
3307-
file_contents = FORWARD_CC_FILE_TEMPLATE.format(
3316+
if not grad_flag:
3317+
file_contents = FORWARD_CC_HEADER
3318+
core_ops_info_str = " "
3319+
else:
3320+
file_contents = BACKWARD_CC_HEADER
3321+
core_ops_info_str = GenerateCoreOpInfoDefinition()
3322+
3323+
file_contents += FORWARD_CC_FILE_TEMPLATE.format(
33083324
core_ops_info_str, forward_definition_str
33093325
)
3326+
33103327
with open(filepath, 'a') as f:
33113328
f.write(file_contents)
33123329

33133330

3314-
def GenerateForwardHFile(filepath, forward_function_declaration_str):
3331+
def GenerateForwardHFile(filepath, forward_function_declaration_str, grad_flag):
33153332
if os.path.exists(filepath):
33163333
os.remove(filepath)
3334+
if not grad_flag:
3335+
core_ops_info_str = ""
3336+
else:
3337+
core_ops_info_str = GenerateCoreOpInfoDeclaration()
33173338

3318-
core_ops_info_str = GenerateCoreOpInfoDeclaration()
33193339
file_contents = FORWARD_H_FILE_TEMPLATE.format(
33203340
core_ops_info_str, forward_function_declaration_str
33213341
)
@@ -3336,6 +3356,8 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
33363356
forward_declaration_str = ""
33373357
forward_definition_str = ""
33383358

3359+
backward_declaration_str = ""
3360+
backward_definition_str = ""
33393361
# merge dygraph_ops.yaml and ops.yaml, dygraph_backward.yaml and backward.yaml
33403362
all_ops = []
33413363
all_bw = []
@@ -3380,6 +3402,20 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
33803402
forward_declaration_str += generator.forward_declaration_str + "\n"
33813403
forward_definition_str += generator.forward_definition_str + "\n"
33823404

3405+
# Generate Files
3406+
nodes_h_path = args.nodes_h_path
3407+
nodes_cc_path = args.nodes_cc_path
3408+
forwards_h_path = args.forwards_h_path
3409+
forwards_cc_path = args.forwards_cc_path
3410+
3411+
GenerateNodeCCFile(nodes_cc_path, node_definition_str)
3412+
GenerateNodeHFile(nodes_h_path, node_declaration_str)
3413+
GenerateForwardCCFile(forwards_cc_path, forward_definition_str, False)
3414+
GenerateForwardHFile(forwards_h_path, forward_declaration_str, False)
3415+
3416+
backwards_h_path = args.backwards_h_path
3417+
backwards_cc_path = args.backwards_cc_path
3418+
33833419
for i in range(len(backward_yaml_paths)):
33843420
backward_yaml_path = backward_yaml_paths[i]
33853421
if backward_yaml_path.endswith('/backward.yaml'):
@@ -3395,18 +3431,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
33953431

33963432
generator_grad.run(True)
33973433

3398-
node_declaration_str += generator_grad.node_declaration_str + "\n"
3399-
node_definition_str += generator_grad.node_definition_str + "\n"
3400-
3401-
forward_declaration_str += generator_grad.forward_declaration_str + "\n"
3402-
forward_definition_str += generator_grad.forward_definition_str + "\n"
3403-
# Generate Files
3404-
nodes_h_path = args.nodes_h_path
3405-
nodes_cc_path = args.nodes_cc_path
3406-
forwards_h_path = args.forwards_h_path
3407-
forwards_cc_path = args.forwards_cc_path
3434+
backward_declaration_str += (
3435+
generator_grad.forward_declaration_str + "\n"
3436+
)
3437+
backward_definition_str += generator_grad.forward_definition_str + "\n"
34083438

3409-
GenerateNodeCCFile(nodes_cc_path, node_definition_str)
3410-
GenerateNodeHFile(nodes_h_path, node_declaration_str)
3411-
GenerateForwardCCFile(forwards_cc_path, forward_definition_str)
3412-
GenerateForwardHFile(forwards_h_path, forward_declaration_str)
3439+
GenerateForwardCCFile(backwards_cc_path, backward_definition_str, True)
3440+
GenerateForwardHFile(backwards_h_path, backward_declaration_str, True)

paddle/fluid/eager/auto_code_generator/generator/python_c_gen.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,7 @@ def FindParsingFunctionFromAttributeType(atype):
198198
#include "paddle/phi/core/platform/profiler/event_tracing.h"
199199
#include "paddle/fluid/pybind/op_function_common.h"
200200
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
201+
#include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_grad_functions.h"
201202
#include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
202203
#include "paddle/fluid/eager/utils.h"
203204
#include "paddle/fluid/pybind/eager_custom_python_api.h"

paddle/fluid/operators/custom_device_common_op_registry.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ limitations under the License. */
1717
#include "paddle/fluid/operators/collective/c_concat_op.h"
1818
#include "paddle/fluid/operators/load_combine_op.h"
1919
#include "paddle/fluid/operators/save_combine_op.h"
20-
#include "paddle/phi/api/backward/backward_api.h"
20+
#include "paddle/phi/api/backward/backward_api_base.h"
2121
#include "paddle/phi/api/include/api.h"
2222
#include "paddle/phi/backends/device_manager.h"
2323
#include "paddle/phi/core/distributed/comm_context_manager.h"

0 commit comments

Comments
 (0)