21
21
22
22
import numpy as np
23
23
24
+ import paddle
24
25
from paddle import set_flags , static
25
26
from paddle .base import core
26
27
from paddle .jit .api import sot_mode_guard
29
30
# Usage:
30
31
class MyTest(Dy2StTestBase):
31
32
@set_to_static_mode(
32
- ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST
33
+ ToStaticMode.AST | ToStaticMode.SOT
33
34
)
34
- @set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR )
35
+ @set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE | IrMode.PIR_API )
35
36
def test_case1(self):
36
37
raise ValueError("MyTest 1")
37
38
@@ -49,8 +50,7 @@ def test_case1(self):
49
50
50
51
51
52
class ToStaticMode (Flag ):
52
- LEGACY_AST = auto ()
53
- PIR_AST = auto ()
53
+ AST = auto ()
54
54
SOT = auto ()
55
55
56
56
def lower_case_name (self ):
@@ -59,13 +59,16 @@ def lower_case_name(self):
59
59
60
60
class IrMode (Flag ):
61
61
LEGACY_IR = auto ()
62
- PIR = auto ()
62
+ # pir translator mode, Reference link: https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/IR_Dialect/program_translator.md
63
+ PIR_EXE = auto ()
64
+ # using native pir api mode
65
+ PIR_API = auto ()
63
66
64
67
def lower_case_name (self ):
65
68
return self .name .lower ()
66
69
67
70
68
- DEFAULT_TO_STATIC_MODE = ToStaticMode .LEGACY_AST | ToStaticMode .SOT
71
+ DEFAULT_TO_STATIC_MODE = ToStaticMode .AST | ToStaticMode .SOT
69
72
DEFAULT_IR_MODE = IrMode .LEGACY_IR
70
73
71
74
@@ -98,13 +101,24 @@ def impl(*args, **kwargs):
98
101
99
102
100
103
def to_pir_ast_test (fn ):
101
- raise TypeError ("Don't enable PIR AST mode now!" )
104
+ @wraps (fn )
105
+ def impl (* args , ** kwargs ):
106
+ logger .info ("[PIR][AST] running pir api" )
107
+ ir_outs = None
108
+ try :
109
+ with paddle .pir_utils .IrGuard ():
110
+ paddle .disable_static ()
111
+ ir_outs = fn (* args , ** kwargs )
112
+ finally :
113
+ paddle .enable_static ()
114
+ return ir_outs
115
+
116
+ return impl
102
117
103
118
104
119
def to_legacy_ir_test (fn ):
105
120
def impl (* args , ** kwargs ):
106
121
logger .info ("[Program] running legacy ir" )
107
- # breakpoint()
108
122
return fn (* args , ** kwargs )
109
123
110
124
return impl
@@ -136,13 +150,13 @@ def impl(*args, **kwargs):
136
150
class Dy2StTestMeta (type ):
137
151
TO_STATIC_HANDLER_MAP = {
138
152
ToStaticMode .SOT : to_sot_test ,
139
- ToStaticMode .LEGACY_AST : to_legacy_ast_test ,
140
- ToStaticMode .PIR_AST : to_pir_ast_test ,
153
+ ToStaticMode .AST : to_legacy_ast_test ,
141
154
}
142
155
143
156
IR_HANDLER_MAP = {
144
157
IrMode .LEGACY_IR : to_legacy_ir_test ,
145
- IrMode .PIR : to_pir_test ,
158
+ IrMode .PIR_EXE : to_pir_test ,
159
+ IrMode .PIR_API : to_pir_ast_test ,
146
160
}
147
161
148
162
def __new__ (cls , name , bases , attrs ):
@@ -191,11 +205,11 @@ def __new__(cls, name, bases, attrs):
191
205
)
192
206
# Generate all test cases
193
207
for to_static_mode , ir_mode in to_static_with_ir_modes :
208
+ # NOTE(gouzil): Temporarily not supported SOT + PIR, link: https://github.com/PaddlePaddle/Paddle/pull/58630
194
209
if (
195
- to_static_mode == ToStaticMode .PIR_AST
196
- and ir_mode == IrMode .LEGACY_IR
210
+ to_static_mode == ToStaticMode .SOT
211
+ and ir_mode == IrMode .PIR_API
197
212
):
198
- # PIR with LEGACY_IR is not a valid combination
199
213
continue
200
214
new_attrs [
201
215
Dy2StTestMeta .test_case_name (
@@ -250,7 +264,7 @@ def decorator(fn):
250
264
# Suger decorators
251
265
# These decorators can be simply composed by base decorators
252
266
def test_ast_only (fn ):
253
- fn = set_to_static_mode (ToStaticMode .LEGACY_AST )(fn )
267
+ fn = set_to_static_mode (ToStaticMode .AST )(fn )
254
268
return fn
255
269
256
270
@@ -260,12 +274,22 @@ def test_sot_only(fn):
260
274
261
275
262
276
def test_pir_only (fn ):
263
- fn = set_ir_mode (IrMode .PIR )(fn )
277
+ fn = set_ir_mode (IrMode .PIR_EXE )(fn )
264
278
return fn
265
279
266
280
267
281
def test_legacy_and_pir (fn ):
268
- fn = set_ir_mode (IrMode .LEGACY_IR | IrMode .PIR )(fn )
282
+ fn = set_ir_mode (IrMode .LEGACY_IR | IrMode .PIR_EXE )(fn )
283
+ return fn
284
+
285
+
286
+ def test_legacy_and_pir_api (fn ):
287
+ fn = set_ir_mode (IrMode .LEGACY_IR | IrMode .PIR_API )
288
+ return fn
289
+
290
+
291
+ def test_legacy_and_pir_api_and_pir_exe (fn ):
292
+ fn = set_ir_mode (IrMode .LEGACY_IR | IrMode .PIR_API | IrMode .PIR_EXE )
269
293
return fn
270
294
271
295
0 commit comments