diff --git a/test/dygraph_to_static/test_bmn.py b/test/dygraph_to_static/test_bmn.py index 57aca8f9ada7f..aaf31a87b52ef 100644 --- a/test/dygraph_to_static/test_bmn.py +++ b/test/dygraph_to_static/test_bmn.py @@ -18,6 +18,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest from predictor_utils import PredictorTools import paddle @@ -637,6 +638,7 @@ def val_bmn(model, args): return loss_data +@dy2static_unittest class TestTrain(unittest.TestCase): def setUp(self): self.args = Args() @@ -667,6 +669,7 @@ def train_bmn(self, args, place, to_static): local_random = np.random.RandomState(SEED) bmn = BMN(args) + bmn = paddle.jit.to_static(bmn) adam = optimizer(args, parameter_list=bmn.parameters()) train_reader = fake_data_reader(args, 'train') @@ -749,6 +752,7 @@ def train_bmn(self, args, place, to_static): break return np.array(loss_data) + @ast_only_test def test_train(self): static_res = self.train_bmn(self.args, self.place, to_static=True) diff --git a/test/dygraph_to_static/test_lstm.py b/test/dygraph_to_static/test_lstm.py index 1c114a50914ce..7a54ae0c1c04f 100644 --- a/test/dygraph_to_static/test_lstm.py +++ b/test/dygraph_to_static/test_lstm.py @@ -17,6 +17,7 @@ import unittest import numpy as np +from dygraph_to_static_util import ast_only_test, dy2static_unittest import paddle from paddle import nn @@ -44,6 +45,7 @@ def forward(self, x): return x +@dy2static_unittest class TestLstm(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -69,6 +71,7 @@ def test_lstm_to_static(self): static_out = self.run_lstm(to_static=True) np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05) + @ast_only_test def test_save_in_eval(self, with_training=True): paddle.jit.enable_to_static(True) net = Net(12, 2) @@ -133,6 +136,7 @@ def forward(self, x): return y +@dy2static_unittest class TestSaveInEvalMode(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -140,6 +144,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @ast_only_test def test_save_in_eval(self): paddle.jit.enable_to_static(True) net = LinearNet() @@ -178,6 +183,7 @@ def test_save_in_eval(self): ) +@dy2static_unittest class TestEvalAfterSave(unittest.TestCase): def setUp(self): self.temp_dir = tempfile.TemporaryDirectory() @@ -185,6 +191,7 @@ def setUp(self): def tearDown(self): self.temp_dir.cleanup() + @ast_only_test def test_eval_after_save(self): x = paddle.randn((2, 10, 12)).astype('float32') net = Net(12, 2)