Skip to content

Commit 36dd23a

Browse files
gouzilzeroRains
authored andcommitted
[Dy2St] pir dy2st unittest verification - Part 1 (PaddlePaddle#58630)
1 parent 906e183 commit 36dd23a

File tree

1 file changed

+41
-17
lines changed

1 file changed

+41
-17
lines changed

test/dygraph_to_static/dygraph_to_static_utils_new.py

Lines changed: 41 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222
import numpy as np
2323

24+
import paddle
2425
from paddle import set_flags, static
2526
from paddle.base import core
2627
from paddle.jit.api import sot_mode_guard
@@ -29,9 +30,9 @@
2930
# Usage:
3031
class MyTest(Dy2StTestBase):
3132
@set_to_static_mode(
32-
ToStaticMode.LEGACY_AST | ToStaticMode.SOT | ToStaticMode.PIR_AST
33+
ToStaticMode.AST | ToStaticMode.SOT
3334
)
34-
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)
35+
@set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE | IrMode.PIR_API)
3536
def test_case1(self):
3637
raise ValueError("MyTest 1")
3738
@@ -49,8 +50,7 @@ def test_case1(self):
4950

5051

5152
class ToStaticMode(Flag):
52-
LEGACY_AST = auto()
53-
PIR_AST = auto()
53+
AST = auto()
5454
SOT = auto()
5555

5656
def lower_case_name(self):
@@ -59,13 +59,16 @@ def lower_case_name(self):
5959

6060
class IrMode(Flag):
6161
LEGACY_IR = auto()
62-
PIR = auto()
62+
# pir translator mode, Reference link: https://github.com/PaddlePaddle/community/blob/master/pfcc/paddle-code-reading/IR_Dialect/program_translator.md
63+
PIR_EXE = auto()
64+
# using native pir api mode
65+
PIR_API = auto()
6366

6467
def lower_case_name(self):
6568
return self.name.lower()
6669

6770

68-
DEFAULT_TO_STATIC_MODE = ToStaticMode.LEGACY_AST | ToStaticMode.SOT
71+
DEFAULT_TO_STATIC_MODE = ToStaticMode.AST | ToStaticMode.SOT
6972
DEFAULT_IR_MODE = IrMode.LEGACY_IR
7073

7174

@@ -98,13 +101,24 @@ def impl(*args, **kwargs):
98101

99102

100103
def to_pir_ast_test(fn):
101-
raise TypeError("Don't enable PIR AST mode now!")
104+
@wraps(fn)
105+
def impl(*args, **kwargs):
106+
logger.info("[PIR][AST] running pir api")
107+
ir_outs = None
108+
try:
109+
with paddle.pir_utils.IrGuard():
110+
paddle.disable_static()
111+
ir_outs = fn(*args, **kwargs)
112+
finally:
113+
paddle.enable_static()
114+
return ir_outs
115+
116+
return impl
102117

103118

104119
def to_legacy_ir_test(fn):
105120
def impl(*args, **kwargs):
106121
logger.info("[Program] running legacy ir")
107-
# breakpoint()
108122
return fn(*args, **kwargs)
109123

110124
return impl
@@ -136,13 +150,13 @@ def impl(*args, **kwargs):
136150
class Dy2StTestMeta(type):
137151
TO_STATIC_HANDLER_MAP = {
138152
ToStaticMode.SOT: to_sot_test,
139-
ToStaticMode.LEGACY_AST: to_legacy_ast_test,
140-
ToStaticMode.PIR_AST: to_pir_ast_test,
153+
ToStaticMode.AST: to_legacy_ast_test,
141154
}
142155

143156
IR_HANDLER_MAP = {
144157
IrMode.LEGACY_IR: to_legacy_ir_test,
145-
IrMode.PIR: to_pir_test,
158+
IrMode.PIR_EXE: to_pir_test,
159+
IrMode.PIR_API: to_pir_ast_test,
146160
}
147161

148162
def __new__(cls, name, bases, attrs):
@@ -191,11 +205,11 @@ def __new__(cls, name, bases, attrs):
191205
)
192206
# Generate all test cases
193207
for to_static_mode, ir_mode in to_static_with_ir_modes:
208+
# NOTE(gouzil): Temporarily not supported SOT + PIR, link: https://github.com/PaddlePaddle/Paddle/pull/58630
194209
if (
195-
to_static_mode == ToStaticMode.PIR_AST
196-
and ir_mode == IrMode.LEGACY_IR
210+
to_static_mode == ToStaticMode.SOT
211+
and ir_mode == IrMode.PIR_API
197212
):
198-
# PIR with LEGACY_IR is not a valid combination
199213
continue
200214
new_attrs[
201215
Dy2StTestMeta.test_case_name(
@@ -250,7 +264,7 @@ def decorator(fn):
250264
# Suger decorators
251265
# These decorators can be simply composed by base decorators
252266
def test_ast_only(fn):
253-
fn = set_to_static_mode(ToStaticMode.LEGACY_AST)(fn)
267+
fn = set_to_static_mode(ToStaticMode.AST)(fn)
254268
return fn
255269

256270

@@ -260,12 +274,22 @@ def test_sot_only(fn):
260274

261275

262276
def test_pir_only(fn):
263-
fn = set_ir_mode(IrMode.PIR)(fn)
277+
fn = set_ir_mode(IrMode.PIR_EXE)(fn)
264278
return fn
265279

266280

267281
def test_legacy_and_pir(fn):
268-
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR)(fn)
282+
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_EXE)(fn)
283+
return fn
284+
285+
286+
def test_legacy_and_pir_api(fn):
287+
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API)
288+
return fn
289+
290+
291+
def test_legacy_and_pir_api_and_pir_exe(fn):
292+
fn = set_ir_mode(IrMode.LEGACY_IR | IrMode.PIR_API | IrMode.PIR_EXE)
269293
return fn
270294

271295

0 commit comments

Comments
 (0)