|
| 1 | +"""Tests for predictions.""" |
| 2 | +from __future__ import absolute_import |
| 3 | +from __future__ import division |
| 4 | +from __future__ import print_function |
| 5 | +from absl.testing import parameterized |
| 6 | +import numpy as np |
| 7 | +import tensorflow as tf |
| 8 | +from deepmath.deephol import mock_predictions_lib |
| 9 | +from deepmath.deephol import predictions |
| 10 | + |
| 11 | +TEST_ARRAY = np.reshape(np.arange(100), (10, 10)).astype(float) |
| 12 | +MOCK_PREDICTOR = mock_predictions_lib.MockPredictionsLib |
| 13 | + |
| 14 | + |
| 15 | +def double(x): |
| 16 | + if x is None: |
| 17 | + return x |
| 18 | + else: |
| 19 | + return 2 * x |
| 20 | + |
| 21 | + |
| 22 | +class PredictionsTest(tf.test.TestCase, parameterized.TestCase): |
| 23 | + |
| 24 | + def test_batch_array_with_none(self): |
| 25 | + result = predictions.batch_array(TEST_ARRAY, None) |
| 26 | + self.assertEqual(len(result), 1) |
| 27 | + self.assertAllEqual(TEST_ARRAY, result[0]) |
| 28 | + |
| 29 | + def test_batch_array_with_batch_size_1(self): |
| 30 | + result = predictions.batch_array(TEST_ARRAY, 1) |
| 31 | + self.assertEqual(len(result), 10) |
| 32 | + for i in range(10): |
| 33 | + self.assertAllEqual(np.expand_dims(TEST_ARRAY[i, :], 0), result[i]) |
| 34 | + |
| 35 | + def test_batch_array_with_batch_size_3(self): |
| 36 | + result = predictions.batch_array(TEST_ARRAY, 3) |
| 37 | + expected = [ |
| 38 | + TEST_ARRAY[:3, :], TEST_ARRAY[3:6, :], TEST_ARRAY[6:9, :], |
| 39 | + TEST_ARRAY[9:, :] |
| 40 | + ] |
| 41 | + self.assertEqual(len(result), len(expected)) |
| 42 | + for i in range(len(expected)): |
| 43 | + self.assertAllEqual(expected[i], result[i]) |
| 44 | + |
| 45 | + def test_batch_array_with_batch_size_10(self): |
| 46 | + result = predictions.batch_array(TEST_ARRAY, 10) |
| 47 | + self.assertEqual(len(result), 1) |
| 48 | + self.assertAllEqual(TEST_ARRAY, result[0]) |
| 49 | + |
| 50 | + def test_batch_array_with_batch_size_15(self): |
| 51 | + result = predictions.batch_array(TEST_ARRAY, 15) |
| 52 | + self.assertEqual(len(result), 1) |
| 53 | + self.assertAllEqual(TEST_ARRAY, result[0]) |
| 54 | + |
| 55 | + def test_batch_array_strlist_with_batch_size_3(self): |
| 56 | + strlist = [str(i) for i in range(10)] |
| 57 | + result = predictions.batch_array(strlist, 3) |
| 58 | + expected = [strlist[:3], strlist[3:6], strlist[6:9], [strlist[9]]] |
| 59 | + print('result:', result) |
| 60 | + self.assertEqual(len(expected), len(result)) |
| 61 | + for i in range(len(expected)): |
| 62 | + self.assertAllEqual(expected[i], result[i]) |
| 63 | + |
| 64 | + def test_batch_array_strlist_with_batch_size_none(self): |
| 65 | + strlist = [str(i) for i in range(10)] |
| 66 | + result = predictions.batch_array(strlist, None) |
| 67 | + self.assertEqual(len(result), 1) |
| 68 | + self.assertAllEqual(result[0], strlist) |
| 69 | + |
| 70 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 71 | + def test_batched_run_identity(self, max_batch_size): |
| 72 | + result = predictions.batched_run([TEST_ARRAY], lambda x: x, max_batch_size) |
| 73 | + self.assertAllEqual(result, TEST_ARRAY) |
| 74 | + |
| 75 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 76 | + def test_batched_run_add(self, max_batch_size): |
| 77 | + result = predictions.batched_run( |
| 78 | + [TEST_ARRAY, TEST_ARRAY], lambda x, y: x + y, max_batch_size) |
| 79 | + self.assertAllEqual(result, 2.0 * TEST_ARRAY) |
| 80 | + |
| 81 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 82 | + def test_batched_run_str_to_int_and_back(self, max_batch_size): |
| 83 | + strlist = [str(i) for i in range(10)] |
| 84 | + result = predictions.batched_run( |
| 85 | + [strlist], lambda l: np.array([[float(x)] for x in l]), max_batch_size) |
| 86 | + self.assertAllEqual(result, [[float(i)] for i in range(10)]) |
| 87 | + |
| 88 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 89 | + def test_predict_goal_embedding(self, max_batch_size): |
| 90 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 91 | + self.assertAllEqual( |
| 92 | + predictor.goal_embedding('goal'), |
| 93 | + predictor.batch_goal_embedding(['goal'])[0]) |
| 94 | + |
| 95 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 96 | + def test_predict_thm_embedding(self, max_batch_size): |
| 97 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 98 | + self.assertAllEqual( |
| 99 | + predictor.thm_embedding('thm'), |
| 100 | + predictor.batch_thm_embedding(['thm'])[0]) |
| 101 | + |
| 102 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 103 | + def test_predict_batch_goal_embedding(self, max_batch_size): |
| 104 | + strlist = [str(i) for i in range(10)] |
| 105 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 106 | + self.assertAllEqual( |
| 107 | + predictor.batch_goal_embedding(strlist), |
| 108 | + predictor._batch_goal_embedding(strlist)) |
| 109 | + |
| 110 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 111 | + def test_predict_batch_thm_embedding(self, max_batch_size): |
| 112 | + strlist = [str(i) for i in range(10)] |
| 113 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 114 | + self.assertAllEqual( |
| 115 | + predictor.batch_thm_embedding(strlist), |
| 116 | + predictor._batch_thm_embedding(strlist)) |
| 117 | + |
| 118 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 119 | + def test_predict_batch_tactic_scores(self, max_batch_size): |
| 120 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 121 | + self.assertAllEqual( |
| 122 | + predictor.batch_tactic_scores(TEST_ARRAY), |
| 123 | + predictor._batch_tactic_scores(TEST_ARRAY)) |
| 124 | + |
| 125 | + @parameterized.parameters(1, 2, 3, 10, 15, None) |
| 126 | + def test_predict_batch_thm_scores(self, max_batch_size): |
| 127 | + predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size)) |
| 128 | + state = np.arange(10) |
| 129 | + dup_state = np.tile(np.arange(10), [10, 1]) |
| 130 | + self.assertAllEqual( |
| 131 | + predictor.batch_thm_scores(state, TEST_ARRAY), |
| 132 | + predictor._batch_thm_scores(dup_state, TEST_ARRAY)) |
| 133 | + self.assertAllEqual( |
| 134 | + predictor.batch_thm_scores(state, TEST_ARRAY, tactic_id=4), |
| 135 | + predictor._batch_thm_scores(dup_state, TEST_ARRAY, tactic_id=4)) |
| 136 | + |
| 137 | + |
| 138 | +if __name__ == '__main__': |
| 139 | + tf.test.main() |
0 commit comments