Skip to content
This repository was archived by the owner on May 15, 2023. It is now read-only.

Commit 9b82a36

Browse files
committed
Merge remote-tracking branch 'origin/master' into local
2 parents a86f189 + 9ac02ea commit 9b82a36

File tree

2 files changed

+152
-0
lines changed

2 files changed

+152
-0
lines changed

deepmath/deephol/BUILD

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,3 +42,16 @@ py_library(
4242
"//third_party/py/numpy",
4343
],
4444
)
45+
46+
py_test(
47+
name = "predictions_test",
48+
size = "medium",
49+
srcs = ["predictions_test.py"],
50+
deps = [
51+
":mock_predictions_lib",
52+
":predictions",
53+
"@absl_py//absl/testing:parameterized",
54+
"//third_party/py/numpy",
55+
"@org_tensorflow//tensorflow:tensorflow_py",
56+
],
57+
)

deepmath/deephol/predictions_test.py

Lines changed: 139 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,139 @@
1+
"""Tests for predictions."""
2+
from __future__ import absolute_import
3+
from __future__ import division
4+
from __future__ import print_function
5+
from absl.testing import parameterized
6+
import numpy as np
7+
import tensorflow as tf
8+
from deepmath.deephol import mock_predictions_lib
9+
from deepmath.deephol import predictions
10+
11+
TEST_ARRAY = np.reshape(np.arange(100), (10, 10)).astype(float)
12+
MOCK_PREDICTOR = mock_predictions_lib.MockPredictionsLib
13+
14+
15+
def double(x):
16+
if x is None:
17+
return x
18+
else:
19+
return 2 * x
20+
21+
22+
class PredictionsTest(tf.test.TestCase, parameterized.TestCase):
23+
24+
def test_batch_array_with_none(self):
25+
result = predictions.batch_array(TEST_ARRAY, None)
26+
self.assertEqual(len(result), 1)
27+
self.assertAllEqual(TEST_ARRAY, result[0])
28+
29+
def test_batch_array_with_batch_size_1(self):
30+
result = predictions.batch_array(TEST_ARRAY, 1)
31+
self.assertEqual(len(result), 10)
32+
for i in range(10):
33+
self.assertAllEqual(np.expand_dims(TEST_ARRAY[i, :], 0), result[i])
34+
35+
def test_batch_array_with_batch_size_3(self):
36+
result = predictions.batch_array(TEST_ARRAY, 3)
37+
expected = [
38+
TEST_ARRAY[:3, :], TEST_ARRAY[3:6, :], TEST_ARRAY[6:9, :],
39+
TEST_ARRAY[9:, :]
40+
]
41+
self.assertEqual(len(result), len(expected))
42+
for i in range(len(expected)):
43+
self.assertAllEqual(expected[i], result[i])
44+
45+
def test_batch_array_with_batch_size_10(self):
46+
result = predictions.batch_array(TEST_ARRAY, 10)
47+
self.assertEqual(len(result), 1)
48+
self.assertAllEqual(TEST_ARRAY, result[0])
49+
50+
def test_batch_array_with_batch_size_15(self):
51+
result = predictions.batch_array(TEST_ARRAY, 15)
52+
self.assertEqual(len(result), 1)
53+
self.assertAllEqual(TEST_ARRAY, result[0])
54+
55+
def test_batch_array_strlist_with_batch_size_3(self):
56+
strlist = [str(i) for i in range(10)]
57+
result = predictions.batch_array(strlist, 3)
58+
expected = [strlist[:3], strlist[3:6], strlist[6:9], [strlist[9]]]
59+
print('result:', result)
60+
self.assertEqual(len(expected), len(result))
61+
for i in range(len(expected)):
62+
self.assertAllEqual(expected[i], result[i])
63+
64+
def test_batch_array_strlist_with_batch_size_none(self):
65+
strlist = [str(i) for i in range(10)]
66+
result = predictions.batch_array(strlist, None)
67+
self.assertEqual(len(result), 1)
68+
self.assertAllEqual(result[0], strlist)
69+
70+
@parameterized.parameters(1, 2, 3, 10, 15, None)
71+
def test_batched_run_identity(self, max_batch_size):
72+
result = predictions.batched_run([TEST_ARRAY], lambda x: x, max_batch_size)
73+
self.assertAllEqual(result, TEST_ARRAY)
74+
75+
@parameterized.parameters(1, 2, 3, 10, 15, None)
76+
def test_batched_run_add(self, max_batch_size):
77+
result = predictions.batched_run(
78+
[TEST_ARRAY, TEST_ARRAY], lambda x, y: x + y, max_batch_size)
79+
self.assertAllEqual(result, 2.0 * TEST_ARRAY)
80+
81+
@parameterized.parameters(1, 2, 3, 10, 15, None)
82+
def test_batched_run_str_to_int_and_back(self, max_batch_size):
83+
strlist = [str(i) for i in range(10)]
84+
result = predictions.batched_run(
85+
[strlist], lambda l: np.array([[float(x)] for x in l]), max_batch_size)
86+
self.assertAllEqual(result, [[float(i)] for i in range(10)])
87+
88+
@parameterized.parameters(1, 2, 3, 10, 15, None)
89+
def test_predict_goal_embedding(self, max_batch_size):
90+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
91+
self.assertAllEqual(
92+
predictor.goal_embedding('goal'),
93+
predictor.batch_goal_embedding(['goal'])[0])
94+
95+
@parameterized.parameters(1, 2, 3, 10, 15, None)
96+
def test_predict_thm_embedding(self, max_batch_size):
97+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
98+
self.assertAllEqual(
99+
predictor.thm_embedding('thm'),
100+
predictor.batch_thm_embedding(['thm'])[0])
101+
102+
@parameterized.parameters(1, 2, 3, 10, 15, None)
103+
def test_predict_batch_goal_embedding(self, max_batch_size):
104+
strlist = [str(i) for i in range(10)]
105+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
106+
self.assertAllEqual(
107+
predictor.batch_goal_embedding(strlist),
108+
predictor._batch_goal_embedding(strlist))
109+
110+
@parameterized.parameters(1, 2, 3, 10, 15, None)
111+
def test_predict_batch_thm_embedding(self, max_batch_size):
112+
strlist = [str(i) for i in range(10)]
113+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
114+
self.assertAllEqual(
115+
predictor.batch_thm_embedding(strlist),
116+
predictor._batch_thm_embedding(strlist))
117+
118+
@parameterized.parameters(1, 2, 3, 10, 15, None)
119+
def test_predict_batch_tactic_scores(self, max_batch_size):
120+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
121+
self.assertAllEqual(
122+
predictor.batch_tactic_scores(TEST_ARRAY),
123+
predictor._batch_tactic_scores(TEST_ARRAY))
124+
125+
@parameterized.parameters(1, 2, 3, 10, 15, None)
126+
def test_predict_batch_thm_scores(self, max_batch_size):
127+
predictor = MOCK_PREDICTOR(max_batch_size, double(max_batch_size))
128+
state = np.arange(10)
129+
dup_state = np.tile(np.arange(10), [10, 1])
130+
self.assertAllEqual(
131+
predictor.batch_thm_scores(state, TEST_ARRAY),
132+
predictor._batch_thm_scores(dup_state, TEST_ARRAY))
133+
self.assertAllEqual(
134+
predictor.batch_thm_scores(state, TEST_ARRAY, tactic_id=4),
135+
predictor._batch_thm_scores(dup_state, TEST_ARRAY, tactic_id=4))
136+
137+
138+
if __name__ == '__main__':
139+
tf.test.main()

0 commit comments

Comments
 (0)