17
17
#include " paddle/fluid/prim/utils/static/static_global_utils.h"
18
18
19
19
PADDLE_DEFINE_EXPORTED_bool (prim_enabled, false , " enable_prim or not" );
20
- PADDLE_DEFINE_EXPORTED_string (prim_blacklist, " " , " prim ops blacklist" );
20
+ PADDLE_DEFINE_EXPORTED_bool (prim_all, false , " enable prim_all or not" );
21
+ PADDLE_DEFINE_EXPORTED_bool (prim_forward, false , " enable prim_forward or not" );
22
+ PADDLE_DEFINE_EXPORTED_bool (prim_backward, false , " enable prim_backward not" );
21
23
22
24
namespace paddle {
23
25
namespace prim {
24
-
25
26
bool PrimCommonUtils::IsBwdPrimEnabled () {
26
- return StaticCompositeContext::Instance ().IsBwdPrimEnabled ();
27
+ bool res = StaticCompositeContext::Instance ().IsBwdPrimEnabled ();
28
+ return res || FLAGS_prim_all || FLAGS_prim_backward;
27
29
}
28
30
29
31
void PrimCommonUtils::SetBwdPrimEnabled (bool enable_prim) {
@@ -39,16 +41,15 @@ void PrimCommonUtils::SetEagerPrimEnabled(bool enable_prim) {
39
41
}
40
42
41
43
bool PrimCommonUtils::IsFwdPrimEnabled () {
42
- return StaticCompositeContext::Instance ().IsFwdPrimEnabled ();
44
+ bool res = StaticCompositeContext::Instance ().IsFwdPrimEnabled ();
45
+ return res || FLAGS_prim_all || FLAGS_prim_forward;
43
46
}
44
47
45
48
void PrimCommonUtils::SetFwdPrimEnabled (bool enable_prim) {
46
- VLOG (0 ) << " FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled;
47
49
StaticCompositeContext::Instance ().SetFwdPrimEnabled (enable_prim);
48
50
}
49
51
50
52
void PrimCommonUtils::SetAllPrimEnabled (bool enable_prim) {
51
- VLOG (0 ) << " FLAGS_prim_enabled ====================== " << FLAGS_prim_enabled;
52
53
StaticCompositeContext::Instance ().SetAllPrimEnabled (enable_prim);
53
54
}
54
55
0 commit comments