Skip to content

Commit 3bbff25

Browse files
authored
Merge pull request #7453 from reyoung/feature/fix_seed_for_dynrnn_test
Fix random seed of dynamic rnn gradient check
2 parents 8d253e4 + e5b2378 commit 3bbff25

File tree

1 file changed

+19
-2
lines changed

1 file changed

+19
-2
lines changed

python/paddle/v2/fluid/tests/test_dynrnn_gradient_check.py

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,7 +197,24 @@ def _exe_mean_out_(self):
197197
return numpy.array([o.mean() for o in outs.itervalues()]).mean()
198198

199199

200-
class TestSimpleMul(unittest.TestCase):
200+
class SeedFixedTestCase(unittest.TestCase):
201+
@classmethod
202+
def setUpClass(cls):
203+
"""Fix random seeds to remove randomness from tests"""
204+
cls._np_rand_state = numpy.random.get_state()
205+
cls._py_rand_state = random.getstate()
206+
207+
numpy.random.seed(123)
208+
random.seed(124)
209+
210+
@classmethod
211+
def tearDownClass(cls):
212+
"""Restore random seeds"""
213+
numpy.random.set_state(cls._np_rand_state)
214+
random.setstate(cls._py_rand_state)
215+
216+
217+
class TestSimpleMul(SeedFixedTestCase):
201218
DATA_NAME = 'X'
202219
DATA_WIDTH = 32
203220
PARAM_NAME = 'W'
@@ -263,7 +280,7 @@ def test_forward_backward(self):
263280
self.assertTrue(numpy.allclose(i_g_num, i_g, rtol=0.05))
264281

265282

266-
class TestSimpleMulWithMemory(unittest.TestCase):
283+
class TestSimpleMulWithMemory(SeedFixedTestCase):
267284
DATA_WIDTH = 32
268285
HIDDEN_WIDTH = 20
269286
DATA_NAME = 'X'

0 commit comments

Comments
 (0)