Skip to content

Commit f0912fb

Browse files
committed
test update cont'd
1 parent 742cdaf commit f0912fb

File tree

1 file changed

+53
-24
lines changed

1 file changed

+53
-24
lines changed

python/tests/graph/test_utils.py

Lines changed: 53 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -29,15 +29,15 @@
2929

3030
def _gen_tensor_op_string_input_tests():
3131
op_name = 'someOp'
32-
for tnsr_idx in range(17):
32+
for tnsr_idx in [0, 1, 2, 3, 5, 8, 15, 17]:
3333
tnsr_name = '{}:{}'.format(op_name, tnsr_idx)
3434
yield TestCase(data=(op_name, tfx.op_name(tnsr_name)),
35-
description='must get the same op name from its tensor')
35+
description='test tensor name to op name')
3636
yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name)),
37-
description='must get the tensor name from its operation')
37+
description='test tensor name to tensor name')
3838

3939

40-
def _gen_invalid_tensor_op_input_with_wrong_types():
40+
def _gen_invalid_tensor_or_op_input_with_wrong_types():
4141
for wrong_val in [7, 1.2, tf.Graph()]:
4242
yield TestCase(data=wrong_val, description='wrong type {}'.format(type(wrong_val)))
4343

@@ -48,6 +48,7 @@ def _gen_valid_tensor_op_objects():
4848
tnsr = tf.constant(1427.08, name=op_name)
4949
graph = tnsr.graph
5050

51+
# Test for op_name
5152
yield TestCase(data=(op_name, tfx.op_name(tnsr)),
5253
description='get op name from tensor (no graph)')
5354
yield TestCase(data=(op_name, tfx.op_name(tnsr, graph)),
@@ -65,6 +66,7 @@ def _gen_valid_tensor_op_objects():
6566
yield TestCase(data=(op_name, tfx.op_name(op_name, graph)),
6667
description='get op name from op name (with graph)')
6768

69+
# Test for tensor_name
6870
yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr)),
6971
description='get tensor name from tensor (no graph)')
7072
yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr, graph)),
@@ -82,51 +84,78 @@ def _gen_valid_tensor_op_objects():
8284
yield TestCase(data=(tnsr_name, tfx.tensor_name(tnsr_name, graph)),
8385
description='get tensor name from op name (with graph)')
8486

87+
# Test for get_tensor
8588
yield TestCase(data=(tnsr, tfx.get_tensor(tnsr, graph)),
86-
description='get tensor from tensor (with graph)')
89+
description='get tensor from tensor')
8790
yield TestCase(data=(tnsr, tfx.get_tensor(tnsr_name, graph)),
88-
description='get tensor from tensor name (with graph)')
91+
description='get tensor from tensor name')
8992
yield TestCase(data=(tnsr, tfx.get_tensor(tnsr.op, graph)),
90-
description='get tensor from op (with graph)')
93+
description='get tensor from op')
9194
yield TestCase(data=(tnsr, tfx.get_tensor(op_name, graph)),
92-
description='get tensor from op name (with graph)')
95+
description='get tensor from op name')
9396

97+
# Test for get_op
9498
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr, graph)),
95-
description='get op from tensor (with graph)')
99+
description='get op from tensor')
96100
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr_name, graph)),
97-
description='get op from tensor name (with graph)')
101+
description='get op from tensor name')
98102
yield TestCase(data=(tnsr.op, tfx.get_op(tnsr.op, graph)),
99-
description='get op from op (with graph)')
103+
description='get op from op')
100104
yield TestCase(data=(tnsr.op, tfx.get_op(op_name, graph)),
101-
description='get op from op name (with graph)')
105+
description='test op from op name')
102106

107+
# Test get_tensor and get_op returns tensor or op contained in the same graph
103108
yield TestCase(data=(graph, tfx.get_op(tnsr, graph).graph),
104-
description='get graph from retrieved op (with graph)')
109+
description='test graph from getting op fron tensor')
105110
yield TestCase(data=(graph, tfx.get_tensor(tnsr, graph).graph),
106-
description='get graph from retrieved tensor (with graph)')
111+
description='test graph from getting tensor from tensor')
112+
yield TestCase(data=(graph, tfx.get_op(tnsr_name, graph).graph),
113+
description='test graph from getting op fron tensor name')
114+
yield TestCase(data=(graph, tfx.get_tensor(tnsr_name, graph).graph),
115+
description='test graph from getting tensor from tensor name')
116+
yield TestCase(data=(graph, tfx.get_op(tnsr.op, graph).graph),
117+
description='test graph from getting op from op')
118+
yield TestCase(data=(graph, tfx.get_tensor(tnsr.op, graph).graph),
119+
description='test graph from getting tensor from op')
120+
yield TestCase(data=(graph, tfx.get_op(op_name, graph).graph),
121+
description='test graph from getting op from op name')
122+
yield TestCase(data=(graph, tfx.get_tensor(op_name, graph).graph),
123+
description='test graph from getting tensor from op name')
107124

108125

109126
class TFeXtensionGraphUtilsTest(PythonUnitTestCase):
110127
@parameterized.expand(_gen_tensor_op_string_input_tests)
111-
def test_valid_graph_element_names(self, data, description):
128+
def test_valid_tensor_op_name_inputs(self, data, description):
112129
""" Must get correct names from valid graph element names """
113130
name_a, name_b = data
114131
self.assertEqual(name_a, name_b, msg=description)
115132

116-
@parameterized.expand(_gen_invalid_tensor_op_input_with_wrong_types)
117-
def test_wrong_tensor_types(self, data, description):
133+
@parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
134+
def test_invalid_tensor_name_inputs_with_wrong_types(self, data, description):
118135
""" Must fail when provided wrong types """
119-
with self.assertRaises(TypeError):
120-
tfx.tensor_name(data, msg=description)
136+
with self.assertRaises(TypeError, msg=description):
137+
tfx.tensor_name(data)
121138

122-
@parameterized.expand(_gen_invalid_tensor_op_input_with_wrong_types)
123-
def test_wrong_op_types(self, data, description):
139+
@parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
140+
def test_invalid_op_name_inputs_with_wrong_types(self, data, description):
124141
""" Must fail when provided wrong types """
125-
with self.assertRaises(TypeError):
126-
tfx.op_name(data, msg=description)
142+
with self.assertRaises(TypeError, msg=description):
143+
tfx.op_name(data)
144+
145+
@parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
146+
def test_invalid_op_inputs_with_wrong_types(self, data, description):
147+
""" Must fail when provided wrong types """
148+
with self.assertRaises(TypeError, msg=description):
149+
tfx.get_op(data, tf.Graph())
150+
151+
@parameterized.expand(_gen_invalid_tensor_or_op_input_with_wrong_types)
152+
def test_invalid_tensor_inputs_with_wrong_types(self, data, description):
153+
""" Must fail when provided wrong types """
154+
with self.assertRaises(TypeError, msg=description):
155+
tfx.get_tensor(data, tf.Graph())
127156

128157
@parameterized.expand(_gen_valid_tensor_op_objects)
129-
def test_get_graph_elements(self, data, description):
158+
def test_valid_tensor_op_object_inputs(self, data, description):
130159
""" Must get correct graph elements from valid graph elements or their names """
131160
tfobj_or_name_a, tfobj_or_name_b = data
132161
self.assertEqual(tfobj_or_name_a, tfobj_or_name_b, msg=description)

0 commit comments

Comments
 (0)