@@ -237,6 +237,8 @@ def ParseArguments():
237
237
parser .add_argument ('--nodes_cc_path' , type = str )
238
238
parser .add_argument ('--forwards_h_path' , type = str )
239
239
parser .add_argument ('--forwards_cc_path' , type = str )
240
+ parser .add_argument ('--backwards_h_path' , type = str )
241
+ parser .add_argument ('--backwards_cc_path' , type = str )
240
242
parser .add_argument ('--api_yaml_path' , type = str )
241
243
parser .add_argument ('--backward_yaml_path' , type = str )
242
244
@@ -546,9 +548,9 @@ class {} : public egr::GradNodeBase {{
546
548
NODE_CC_FILE_TEMPLATE = """
547
549
#include "glog/logging.h"
548
550
#include "paddle/phi/api/all.h"
549
- #include "paddle/phi/api/backward/backward_api .h"
550
- #include "paddle/phi/api/backward/fused_backward_api .h"
551
- #include "paddle/phi/api/backward/sparse_bw_api .h"
551
+ #include "paddle/phi/api/backward/backward_api_base .h"
552
+ #include "paddle/phi/api/backward/fused_backward_api_base .h"
553
+ #include "paddle/phi/api/backward/sparse_backward_api_base .h"
552
554
#include "paddle/fluid/imperative/tracer.h"
553
555
#include "paddle/fluid/framework/op_registry.h"
554
556
#include "paddle/phi/core/platform/profiler/event_tracing.h"
@@ -578,19 +580,28 @@ class {} : public egr::GradNodeBase {{
578
580
579
581
{}
580
582
"""
583
+ FORWARD_CC_HEADER = """
584
+ #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
585
+ #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
586
+ """
581
587
588
+ BACKWARD_CC_HEADER = """
589
+ #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_grad_functions.h"
590
+ #include "paddle/phi/api/backward/sparse_backward_api.h"
591
+ #include "paddle/phi/api/backward/fused_backward_api.h"
592
+ #include "paddle/phi/api/backward/backward_api.h"
593
+ """
582
594
FORWARD_CC_FILE_TEMPLATE = """
583
595
#include "paddle/phi/api/lib/dygraph_api.h"
584
- #include "paddle/fluid/eager/api/generated/eager_generated/forwards/dygraph_functions.h"
585
596
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
586
597
#include "paddle/fluid/eager/eager_layout_auto_tune.h"
587
598
#include "paddle/phi/api/include/strings_api.h"
588
- #include "paddle/phi/api/include/sparse_api.h"
599
+
589
600
#include "paddle/fluid/eager/api/utils/global_utils.h"
590
601
#include "paddle/phi/core/platform/profiler/event_tracing.h"
591
602
#include "paddle/phi/backends/gpu/gpu_info.h"
592
603
#include "paddle/fluid/eager/nan_inf_utils.h"
593
- #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
604
+
594
605
#include "paddle/common/flags.h"
595
606
#include "paddle/phi/api/lib/data_transform.h"
596
607
#include "paddle/fluid/eager/type_promotion_utils.h"
@@ -612,9 +623,8 @@ class {} : public egr::GradNodeBase {{
612
623
#include "paddle/phi/api/all.h"
613
624
#include "paddle/fluid/eager/utils.h"
614
625
#include "paddle/fluid/framework/op_registry.h"
615
- #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
616
626
#include "paddle/utils/test_macros.h"
617
-
627
+ #include "paddle/fluid/eager/api/manual/eager_manual/dygraph_forward_api.h"
618
628
using CPUPlace = phi::CPUPlace;
619
629
{}
620
630
{}
@@ -3299,23 +3309,33 @@ def GenerateNodeHFile(filepath, node_declaration_str):
3299
3309
f .write (file_contents )
3300
3310
3301
3311
3302
- def GenerateForwardCCFile (filepath , forward_definition_str ):
3312
+ def GenerateForwardCCFile (filepath , forward_definition_str , grad_flag ):
3303
3313
if os .path .exists (filepath ):
3304
3314
os .remove (filepath )
3305
3315
3306
- core_ops_info_str = GenerateCoreOpInfoDefinition ()
3307
- file_contents = FORWARD_CC_FILE_TEMPLATE .format (
3316
+ if not grad_flag :
3317
+ file_contents = FORWARD_CC_HEADER
3318
+ core_ops_info_str = " "
3319
+ else :
3320
+ file_contents = BACKWARD_CC_HEADER
3321
+ core_ops_info_str = GenerateCoreOpInfoDefinition ()
3322
+
3323
+ file_contents += FORWARD_CC_FILE_TEMPLATE .format (
3308
3324
core_ops_info_str , forward_definition_str
3309
3325
)
3326
+
3310
3327
with open (filepath , 'a' ) as f :
3311
3328
f .write (file_contents )
3312
3329
3313
3330
3314
- def GenerateForwardHFile (filepath , forward_function_declaration_str ):
3331
+ def GenerateForwardHFile (filepath , forward_function_declaration_str , grad_flag ):
3315
3332
if os .path .exists (filepath ):
3316
3333
os .remove (filepath )
3334
+ if not grad_flag :
3335
+ core_ops_info_str = ""
3336
+ else :
3337
+ core_ops_info_str = GenerateCoreOpInfoDeclaration ()
3317
3338
3318
- core_ops_info_str = GenerateCoreOpInfoDeclaration ()
3319
3339
file_contents = FORWARD_H_FILE_TEMPLATE .format (
3320
3340
core_ops_info_str , forward_function_declaration_str
3321
3341
)
@@ -3336,6 +3356,8 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
3336
3356
forward_declaration_str = ""
3337
3357
forward_definition_str = ""
3338
3358
3359
+ backward_declaration_str = ""
3360
+ backward_definition_str = ""
3339
3361
# merge dygraph_ops.yaml and ops.yaml, dygraph_backward.yaml and backward.yaml
3340
3362
all_ops = []
3341
3363
all_bw = []
@@ -3380,6 +3402,20 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
3380
3402
forward_declaration_str += generator .forward_declaration_str + "\n "
3381
3403
forward_definition_str += generator .forward_definition_str + "\n "
3382
3404
3405
+ # Generate Files
3406
+ nodes_h_path = args .nodes_h_path
3407
+ nodes_cc_path = args .nodes_cc_path
3408
+ forwards_h_path = args .forwards_h_path
3409
+ forwards_cc_path = args .forwards_cc_path
3410
+
3411
+ GenerateNodeCCFile (nodes_cc_path , node_definition_str )
3412
+ GenerateNodeHFile (nodes_h_path , node_declaration_str )
3413
+ GenerateForwardCCFile (forwards_cc_path , forward_definition_str , False )
3414
+ GenerateForwardHFile (forwards_h_path , forward_declaration_str , False )
3415
+
3416
+ backwards_h_path = args .backwards_h_path
3417
+ backwards_cc_path = args .backwards_cc_path
3418
+
3383
3419
for i in range (len (backward_yaml_paths )):
3384
3420
backward_yaml_path = backward_yaml_paths [i ]
3385
3421
if backward_yaml_path .endswith ('/backward.yaml' ):
@@ -3395,18 +3431,10 @@ def GenerateForwardHFile(filepath, forward_function_declaration_str):
3395
3431
3396
3432
generator_grad .run (True )
3397
3433
3398
- node_declaration_str += generator_grad .node_declaration_str + "\n "
3399
- node_definition_str += generator_grad .node_definition_str + "\n "
3400
-
3401
- forward_declaration_str += generator_grad .forward_declaration_str + "\n "
3402
- forward_definition_str += generator_grad .forward_definition_str + "\n "
3403
- # Generate Files
3404
- nodes_h_path = args .nodes_h_path
3405
- nodes_cc_path = args .nodes_cc_path
3406
- forwards_h_path = args .forwards_h_path
3407
- forwards_cc_path = args .forwards_cc_path
3434
+ backward_declaration_str += (
3435
+ generator_grad .forward_declaration_str + "\n "
3436
+ )
3437
+ backward_definition_str += generator_grad .forward_definition_str + "\n "
3408
3438
3409
- GenerateNodeCCFile (nodes_cc_path , node_definition_str )
3410
- GenerateNodeHFile (nodes_h_path , node_declaration_str )
3411
- GenerateForwardCCFile (forwards_cc_path , forward_definition_str )
3412
- GenerateForwardHFile (forwards_h_path , forward_declaration_str )
3439
+ GenerateForwardCCFile (backwards_cc_path , backward_definition_str , True )
3440
+ GenerateForwardHFile (backwards_h_path , backward_declaration_str , True )
0 commit comments