Skip to content

Commit d41e766

Browse files
merge one test case (#66477)
1 parent 0c6cbf9 commit d41e766

File tree

4 files changed

+59
-85
lines changed

4 files changed

+59
-85
lines changed

test/deprecated/legacy_test/test_compiled_program_deprecated.py

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,11 @@
1515
import sys
1616
import unittest
1717

18+
import numpy as np
19+
from simple_nets import simple_fc_net
20+
1821
sys.path.append("../../legacy_test")
22+
from test_imperative_base import new_program_scope
1923

2024
import paddle
2125
from paddle import base
@@ -24,6 +28,58 @@
2428
paddle.enable_static()
2529

2630

31+
class TestCompiledProgram(unittest.TestCase):
32+
def setUp(self):
33+
self.seed = 100
34+
self.img = np.random.random(size=(16, 784)).astype('float32')
35+
self.label = np.random.randint(
36+
low=0, high=10, size=[16, 1], dtype=np.int64
37+
)
38+
paddle.enable_static()
39+
with new_program_scope():
40+
paddle.seed(self.seed)
41+
paddle.framework.random._manual_program_seed(self.seed)
42+
place = (
43+
base.CUDAPlace(0)
44+
if core.is_compiled_with_cuda()
45+
else base.CPUPlace()
46+
)
47+
exe = base.Executor(place)
48+
49+
loss = simple_fc_net()
50+
exe.run(base.default_startup_program())
51+
52+
(loss_data,) = exe.run(
53+
base.default_main_program(),
54+
feed={"image": self.img, "label": self.label},
55+
fetch_list=[loss],
56+
)
57+
self.loss = float(loss_data)
58+
59+
def test_compiled_program_base(self):
60+
paddle.enable_static()
61+
with new_program_scope():
62+
paddle.seed(self.seed)
63+
paddle.framework.random._manual_program_seed(self.seed)
64+
place = (
65+
base.CUDAPlace(0)
66+
if core.is_compiled_with_cuda()
67+
else base.CPUPlace()
68+
)
69+
exe = base.Executor(place)
70+
71+
loss = simple_fc_net()
72+
exe.run(base.default_startup_program())
73+
compiled_prog = base.CompiledProgram(base.default_main_program())
74+
75+
(loss_data,) = exe.run(
76+
compiled_prog,
77+
feed={"image": self.img, "label": self.label},
78+
fetch_list=[loss],
79+
)
80+
np.testing.assert_array_equal(float(loss_data), self.loss)
81+
82+
2783
class TestCompiledProgramError(unittest.TestCase):
2884
def test_program_or_graph_error(self):
2985
self.assertRaises(TypeError, base.CompiledProgram, "program")

test/deprecated/legacy_test/test_compiled_program_deprecated2.py

Lines changed: 0 additions & 82 deletions
This file was deleted.

tools/parallel_UT_rule.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1394,7 +1394,7 @@
13941394
'test_imperative_ptb_rnn_sorted_gradient',
13951395
'test_hapi_hub',
13961396
'test_reverse_op',
1397-
'test_compiled_program_deprecated2',
1397+
'test_compiled_program_deprecated',
13981398
'test_lambda',
13991399
'test_adadelta_op',
14001400
'test_nn_sigmoid_op',
@@ -2363,7 +2363,7 @@
23632363
'test_graph_send_recv_op',
23642364
'test_fill_constant_op',
23652365
'test_distribution',
2366-
'test_compiled_program_deprecated2',
2366+
'test_compiled_program_deprecated',
23672367
'test_compare_op',
23682368
'test_bitwise_op',
23692369
'test_bce_with_logits_loss',

tools/static_mode_white_list.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@
9494
'test_clip_op',
9595
'test_collect_fpn_proposals_op',
9696
'test_compare_reduce_op',
97-
'test_compiled_program_deprecated2',
97+
'test_compiled_program_deprecated',
9898
'test_cond',
9999
'test_conditional_block_deprecated',
100100
'test_context_manager',

0 commit comments

Comments
 (0)