Skip to content

Commit

Permalink
Merge pull request PaddlePaddle#11 from 0x45f/fix-ut2
Browse files Browse the repository at this point in the history
Ast only some uts
  • Loading branch information
2742195759 committed Jun 19, 2023
2 parents c2fe354 + 84894f5 commit a60e744
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 0 deletions.
4 changes: 4 additions & 0 deletions test/dygraph_to_static/test_bmn.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -637,6 +638,7 @@ def val_bmn(model, args):
return loss_data


@dy2static_unittest
class TestTrain(unittest.TestCase):
def setUp(self):
self.args = Args()
Expand Down Expand Up @@ -667,6 +669,7 @@ def train_bmn(self, args, place, to_static):
local_random = np.random.RandomState(SEED)

bmn = BMN(args)
bmn = paddle.jit.to_static(bmn)
adam = optimizer(args, parameter_list=bmn.parameters())

train_reader = fake_data_reader(args, 'train')
Expand Down Expand Up @@ -749,6 +752,7 @@ def train_bmn(self, args, place, to_static):
break
return np.array(loss_data)

@ast_only_test
def test_train(self):

static_res = self.train_bmn(self.args, self.place, to_static=True)
Expand Down
7 changes: 7 additions & 0 deletions test/dygraph_to_static/test_lstm.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
import unittest

import numpy as np
from dygraph_to_static_util import ast_only_test, dy2static_unittest

import paddle
from paddle import nn
Expand Down Expand Up @@ -44,6 +45,7 @@ def forward(self, x):
return x


@dy2static_unittest
class TestLstm(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()
Expand All @@ -69,6 +71,7 @@ def test_lstm_to_static(self):
static_out = self.run_lstm(to_static=True)
np.testing.assert_allclose(dygraph_out, static_out, rtol=1e-05)

@ast_only_test
def test_save_in_eval(self, with_training=True):
paddle.jit.enable_to_static(True)
net = Net(12, 2)
Expand Down Expand Up @@ -133,13 +136,15 @@ def forward(self, x):
return y


@dy2static_unittest
class TestSaveInEvalMode(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_save_in_eval(self):
paddle.jit.enable_to_static(True)
net = LinearNet()
Expand Down Expand Up @@ -178,13 +183,15 @@ def test_save_in_eval(self):
)


@dy2static_unittest
class TestEvalAfterSave(unittest.TestCase):
def setUp(self):
self.temp_dir = tempfile.TemporaryDirectory()

def tearDown(self):
self.temp_dir.cleanup()

@ast_only_test
def test_eval_after_save(self):
x = paddle.randn((2, 10, 12)).astype('float32')
net = Net(12, 2)
Expand Down

0 comments on commit a60e744

Please sign in to comment.