Skip to content

Commit ff57d6f

Browse files
authored
[PIR Unittest] Fix 4 PIR unit tests: test_elementwise_add_op, test_elementwise_div_op, test_elementwise_nn_grad, test_elementwise_sub_op (#67403)
1 parent 015dfc6 commit ff57d6f

File tree

4 files changed

+73
-61
lines changed

4 files changed

+73
-61
lines changed

test/legacy_test/test_elementwise_add_op.py

Lines changed: 31 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -620,12 +620,13 @@ def _executed_api(self, x, y, name=None):
620620
return paddle.add(x, y, name)
621621

622622
def test_name(self):
623-
with base.program_guard(base.Program()):
624-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
625-
y = paddle.static.data(name='y', shape=[2, 3], dtype='float32')
623+
with paddle.pir_utils.OldIrGuard():
624+
with base.program_guard(base.Program()):
625+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
626+
y = paddle.static.data(name='y', shape=[2, 3], dtype='float32')
626627

627-
y_1 = self._executed_api(x, y, name='add_res')
628-
self.assertEqual(('add_res' in y_1.name), True)
628+
y_1 = self._executed_api(x, y, name='add_res')
629+
self.assertEqual(('add_res' in y_1.name), True)
629630

630631
def test_declarative(self):
631632
with base.program_guard(base.Program()):
@@ -642,7 +643,7 @@ def gen_data():
642643

643644
place = base.CPUPlace()
644645
exe = base.Executor(place)
645-
z_value = exe.run(feed=gen_data(), fetch_list=[z.name])
646+
z_value = exe.run(feed=gen_data(), fetch_list=[z])
646647
z_expected = np.array([3.0, 8.0, 6.0])
647648
self.assertEqual((z_value == z_expected).all(), True)
648649

@@ -857,27 +858,30 @@ def test_float16_add(self):
857858

858859
class TestTensorAddAPIWarnings(unittest.TestCase):
859860
def test_warnings(self):
860-
with warnings.catch_warnings(record=True) as context:
861-
warnings.simplefilter("always")
862-
863-
paddle.enable_static()
864-
helper = LayerHelper("elementwise_add")
865-
data = paddle.static.data(
866-
name='data', shape=[None, 3, 32, 32], dtype='float32'
867-
)
868-
out = helper.create_variable_for_type_inference(dtype=data.dtype)
869-
os.environ['FLAGS_print_extra_attrs'] = "1"
870-
helper.append_op(
871-
type="elementwise_add",
872-
inputs={'X': data, 'Y': data},
873-
outputs={'Out': out},
874-
attrs={'axis': 1, 'use_mkldnn': False},
875-
)
876-
self.assertTrue(
877-
"op elementwise_add's attr axis = 1 is not the default value: -1"
878-
in str(context[-1].message)
879-
)
880-
os.environ['FLAGS_print_extra_attrs'] = "0"
861+
with paddle.pir_utils.OldIrGuard():
862+
with warnings.catch_warnings(record=True) as context:
863+
warnings.simplefilter("always")
864+
865+
paddle.enable_static()
866+
helper = LayerHelper("elementwise_add")
867+
data = paddle.static.data(
868+
name='data', shape=[None, 3, 32, 32], dtype='float32'
869+
)
870+
out = helper.create_variable_for_type_inference(
871+
dtype=data.dtype
872+
)
873+
os.environ['FLAGS_print_extra_attrs'] = "1"
874+
helper.append_op(
875+
type="elementwise_add",
876+
inputs={'X': data, 'Y': data},
877+
outputs={'Out': out},
878+
attrs={'axis': 1, 'use_mkldnn': False},
879+
)
880+
self.assertTrue(
881+
"op elementwise_add's attr axis = 1 is not the default value: -1"
882+
in str(context[-1].message)
883+
)
884+
os.environ['FLAGS_print_extra_attrs'] = "0"
881885

882886

883887
class TestTensorFloat32Bfloat16OrFloat16Add(unittest.TestCase):

test/legacy_test/test_elementwise_div_op.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
from op_test import OpTest, convert_float_to_uint16, skip_check_grad_ci
1919

2020
import paddle
21+
import paddle.static
2122
from paddle import base
2223
from paddle.base import core
2324
from paddle.pir_utils import test_with_pir_api
@@ -509,14 +510,15 @@ def test_shape_with_batch_sizes(self):
509510
class TestDivideOp(unittest.TestCase):
510511
def test_name(self):
511512
paddle.enable_static()
512-
main_program = paddle.static.Program()
513-
with paddle.static.program_guard(main_program):
514-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
515-
y = paddle.static.data(name='y', shape=[2, 3], dtype='float32')
513+
with paddle.pir_utils.OldIrGuard():
514+
main_program = paddle.static.Program()
515+
with paddle.static.program_guard(main_program):
516+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
517+
y = paddle.static.data(name='y', shape=[2, 3], dtype='float32')
516518

517-
y_1 = paddle.divide(x, y, name='div_res')
519+
y_1 = paddle.divide(x, y, name='div_res')
518520

519-
self.assertEqual(('div_res' in y_1.name), True)
521+
self.assertEqual(('div_res' in y_1.name), True)
520522

521523
paddle.disable_static()
522524

test/legacy_test/test_elementwise_nn_grad.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -552,7 +552,8 @@ def test_grad(self):
552552
if core.is_compiled_with_cuda():
553553
places.append(base.CUDAPlace(0))
554554
for p in places:
555-
self.func(p)
555+
with paddle.pir_utils.OldIrGuard():
556+
self.func(p)
556557

557558

558559
class TestElementwiseMulTripleGradCheck(unittest.TestCase):
@@ -621,7 +622,8 @@ def test_grad(self):
621622
if core.is_compiled_with_cuda():
622623
places.append(base.CUDAPlace(0))
623624
for p in places:
624-
self.func(p)
625+
with paddle.pir_utils.OldIrGuard():
626+
self.func(p)
625627

626628

627629
if __name__ == "__main__":

test/legacy_test/test_elementwise_sub_op.py

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -897,12 +897,13 @@ def _executed_api(self, x, y, name=None):
897897
return paddle.subtract(x, y, name)
898898

899899
def test_name(self):
900-
with base.program_guard(base.Program()):
901-
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
902-
y = paddle.static.data(name='y', shape=[2, 3], dtype=np.float32)
900+
with paddle.pir_utils.OldIrGuard():
901+
with base.program_guard(base.Program()):
902+
x = paddle.static.data(name="x", shape=[2, 3], dtype="float32")
903+
y = paddle.static.data(name='y', shape=[2, 3], dtype=np.float32)
903904

904-
y_1 = self._executed_api(x, y, name='subtract_res')
905-
self.assertEqual(('subtract_res' in y_1.name), True)
905+
y_1 = self._executed_api(x, y, name='subtract_res')
906+
self.assertEqual(('subtract_res' in y_1.name), True)
906907

907908
@test_with_pir_api
908909
def test_declarative(self):
@@ -1063,27 +1064,30 @@ def test_dygraph_sub(self):
10631064

10641065
class TestTensorSubAPIWarnings(unittest.TestCase):
10651066
def test_warnings(self):
1066-
with warnings.catch_warnings(record=True) as context:
1067-
warnings.simplefilter("always")
1068-
1069-
paddle.enable_static()
1070-
helper = LayerHelper("elementwise_sub")
1071-
data = paddle.static.data(
1072-
name='data', shape=[None, 3, 32, 32], dtype=np.float32
1073-
)
1074-
out = helper.create_variable_for_type_inference(dtype=data.dtype)
1075-
os.environ['FLAGS_print_extra_attrs'] = "1"
1076-
helper.append_op(
1077-
type="elementwise_sub",
1078-
inputs={'X': data, 'Y': data},
1079-
outputs={'Out': out},
1080-
attrs={'axis': 1, 'use_mkldnn': False},
1081-
)
1082-
self.assertTrue(
1083-
"op elementwise_sub's attr axis = 1 is not the default value: -1"
1084-
in str(context[-1].message)
1085-
)
1086-
os.environ['FLAGS_print_extra_attrs'] = "0"
1067+
with paddle.pir_utils.OldIrGuard():
1068+
with warnings.catch_warnings(record=True) as context:
1069+
warnings.simplefilter("always")
1070+
1071+
paddle.enable_static()
1072+
helper = LayerHelper("elementwise_sub")
1073+
data = paddle.static.data(
1074+
name='data', shape=[None, 3, 32, 32], dtype=np.float32
1075+
)
1076+
out = helper.create_variable_for_type_inference(
1077+
dtype=data.dtype
1078+
)
1079+
os.environ['FLAGS_print_extra_attrs'] = "1"
1080+
helper.append_op(
1081+
type="elementwise_sub",
1082+
inputs={'X': data, 'Y': data},
1083+
outputs={'Out': out},
1084+
attrs={'axis': 1, 'use_mkldnn': False},
1085+
)
1086+
self.assertTrue(
1087+
"op elementwise_sub's attr axis = 1 is not the default value: -1"
1088+
in str(context[-1].message)
1089+
)
1090+
os.environ['FLAGS_print_extra_attrs'] = "0"
10871091

10881092

10891093
if __name__ == '__main__':

0 commit comments

Comments (0)