@@ -42,7 +42,31 @@ def _gen_invalid_tensor_or_op_input_with_wrong_types():
42
42
yield TestCase (data = wrong_val , description = 'wrong type {}' .format (type (wrong_val )))
43
43
44
44
45
- def _gen_valid_tensor_op_objects ():
45
+ def _gen_invalid_tensor_or_op_with_graph_pairing ():
46
+ tnsr = tf .constant (1427.08 , name = 'someConstOp' )
47
+ other_graph = tf .Graph ()
48
+ op_name = tnsr .op .name
49
+
50
+ # Test get_tensor and get_op returns tensor or op contained in the same graph
51
+ yield TestCase (data = lambda : tfx .get_op (tnsr , other_graph ),
52
+ description = 'test graph from getting op fron tensor' )
53
+ yield TestCase (data = lambda : tfx .get_tensor (tnsr , other_graph ),
54
+ description = 'test graph from getting tensor from tensor' )
55
+ yield TestCase (data = lambda : tfx .get_op (tnsr .name , other_graph ),
56
+ description = 'test graph from getting op fron tensor name' )
57
+ yield TestCase (data = lambda : tfx .get_tensor (tnsr .name , other_graph ),
58
+ description = 'test graph from getting tensor from tensor name' )
59
+ yield TestCase (data = lambda : tfx .get_op (tnsr .op , other_graph ),
60
+ description = 'test graph from getting op from op' )
61
+ yield TestCase (data = lambda : tfx .get_tensor (tnsr .op , other_graph ),
62
+ description = 'test graph from getting tensor from op' )
63
+ yield TestCase (data = lambda : tfx .get_op (op_name , other_graph ),
64
+ description = 'test graph from getting op from op name' )
65
+ yield TestCase (data = lambda : tfx .get_tensor (op_name , other_graph ),
66
+ description = 'test graph from getting tensor from op name' )
67
+
68
+
69
+ def _gen_valid_tensor_op_input_combos ():
46
70
op_name = 'someConstOp'
47
71
tnsr_name = '{}:0' .format (op_name )
48
72
tnsr = tf .constant (1427.08 , name = op_name )
@@ -154,8 +178,14 @@ def test_invalid_tensor_inputs_with_wrong_types(self, data, description):
154
178
with self .assertRaises (TypeError , msg = description ):
155
179
tfx .get_tensor (data , tf .Graph ())
156
180
157
- @parameterized .expand (_gen_valid_tensor_op_objects )
181
+ @parameterized .expand (_gen_valid_tensor_op_input_combos )
158
182
def test_valid_tensor_op_object_inputs (self , data , description ):
159
183
""" Must get correct graph elements from valid graph elements or their names """
160
184
tfobj_or_name_a , tfobj_or_name_b = data
161
185
self .assertEqual (tfobj_or_name_a , tfobj_or_name_b , msg = description )
186
+
187
+ @parameterized .expand (_gen_invalid_tensor_or_op_with_graph_pairing )
188
+ def test_invalid_tensor_op_object_graph_pairing (self , data , description ):
189
+ """ Must fail when the graph element is from a different graph than the provided """
190
+ with self .assertRaises ((KeyError , AssertionError , TypeError ), msg = description ):
191
+ data ()
0 commit comments