29
29
30
30
def _gen_tensor_op_string_input_tests ():
31
31
op_name = 'someOp'
32
- for tnsr_idx in range ( 17 ) :
32
+ for tnsr_idx in [ 0 , 1 , 2 , 3 , 5 , 8 , 15 , 17 ] :
33
33
tnsr_name = '{}:{}' .format (op_name , tnsr_idx )
34
34
yield TestCase (data = (op_name , tfx .op_name (tnsr_name )),
35
- description = 'must get the same op name from its tensor ' )
35
+ description = 'test tensor name to op name' )
36
36
yield TestCase (data = (tnsr_name , tfx .tensor_name (tnsr_name )),
37
- description = 'must get the tensor name from its operation ' )
37
+ description = 'test tensor name to tensor name ' )
38
38
39
39
40
- def _gen_invalid_tensor_op_input_with_wrong_types ():
40
+ def _gen_invalid_tensor_or_op_input_with_wrong_types ():
41
41
for wrong_val in [7 , 1.2 , tf .Graph ()]:
42
42
yield TestCase (data = wrong_val , description = 'wrong type {}' .format (type (wrong_val )))
43
43
@@ -48,6 +48,7 @@ def _gen_valid_tensor_op_objects():
48
48
tnsr = tf .constant (1427.08 , name = op_name )
49
49
graph = tnsr .graph
50
50
51
+ # Test for op_name
51
52
yield TestCase (data = (op_name , tfx .op_name (tnsr )),
52
53
description = 'get op name from tensor (no graph)' )
53
54
yield TestCase (data = (op_name , tfx .op_name (tnsr , graph )),
@@ -65,6 +66,7 @@ def _gen_valid_tensor_op_objects():
65
66
yield TestCase (data = (op_name , tfx .op_name (op_name , graph )),
66
67
description = 'get op name from op name (with graph)' )
67
68
69
+ # Test for tensor_name
68
70
yield TestCase (data = (tnsr_name , tfx .tensor_name (tnsr )),
69
71
description = 'get tensor name from tensor (no graph)' )
70
72
yield TestCase (data = (tnsr_name , tfx .tensor_name (tnsr , graph )),
@@ -82,51 +84,78 @@ def _gen_valid_tensor_op_objects():
82
84
yield TestCase (data = (tnsr_name , tfx .tensor_name (tnsr_name , graph )),
83
85
description = 'get tensor name from op name (with graph)' )
84
86
87
+ # Test for get_tensor
85
88
yield TestCase (data = (tnsr , tfx .get_tensor (tnsr , graph )),
86
- description = 'get tensor from tensor (with graph) ' )
89
+ description = 'get tensor from tensor' )
87
90
yield TestCase (data = (tnsr , tfx .get_tensor (tnsr_name , graph )),
88
- description = 'get tensor from tensor name (with graph) ' )
91
+ description = 'get tensor from tensor name' )
89
92
yield TestCase (data = (tnsr , tfx .get_tensor (tnsr .op , graph )),
90
- description = 'get tensor from op (with graph) ' )
93
+ description = 'get tensor from op' )
91
94
yield TestCase (data = (tnsr , tfx .get_tensor (op_name , graph )),
92
- description = 'get tensor from op name (with graph) ' )
95
+ description = 'get tensor from op name' )
93
96
97
+ # Test for get_op
94
98
yield TestCase (data = (tnsr .op , tfx .get_op (tnsr , graph )),
95
- description = 'get op from tensor (with graph) ' )
99
+ description = 'get op from tensor' )
96
100
yield TestCase (data = (tnsr .op , tfx .get_op (tnsr_name , graph )),
97
- description = 'get op from tensor name (with graph) ' )
101
+ description = 'get op from tensor name' )
98
102
yield TestCase (data = (tnsr .op , tfx .get_op (tnsr .op , graph )),
99
- description = 'get op from op (with graph) ' )
103
+ description = 'get op from op' )
100
104
yield TestCase (data = (tnsr .op , tfx .get_op (op_name , graph )),
101
- description = 'get op from op name (with graph) ' )
105
+ description = 'test op from op name' )
102
106
107
+ # Test get_tensor and get_op returns tensor or op contained in the same graph
103
108
yield TestCase (data = (graph , tfx .get_op (tnsr , graph ).graph ),
104
- description = 'get graph from retrieved op (with graph) ' )
109
+ description = 'test graph from getting op fron tensor ' )
105
110
yield TestCase (data = (graph , tfx .get_tensor (tnsr , graph ).graph ),
106
- description = 'get graph from retrieved tensor (with graph)' )
111
+ description = 'test graph from getting tensor from tensor' )
112
+ yield TestCase (data = (graph , tfx .get_op (tnsr_name , graph ).graph ),
113
+ description = 'test graph from getting op fron tensor name' )
114
+ yield TestCase (data = (graph , tfx .get_tensor (tnsr_name , graph ).graph ),
115
+ description = 'test graph from getting tensor from tensor name' )
116
+ yield TestCase (data = (graph , tfx .get_op (tnsr .op , graph ).graph ),
117
+ description = 'test graph from getting op from op' )
118
+ yield TestCase (data = (graph , tfx .get_tensor (tnsr .op , graph ).graph ),
119
+ description = 'test graph from getting tensor from op' )
120
+ yield TestCase (data = (graph , tfx .get_op (op_name , graph ).graph ),
121
+ description = 'test graph from getting op from op name' )
122
+ yield TestCase (data = (graph , tfx .get_tensor (op_name , graph ).graph ),
123
+ description = 'test graph from getting tensor from op name' )
107
124
108
125
109
126
class TFeXtensionGraphUtilsTest (PythonUnitTestCase ):
110
127
@parameterized .expand (_gen_tensor_op_string_input_tests )
111
- def test_valid_graph_element_names (self , data , description ):
128
+ def test_valid_tensor_op_name_inputs (self , data , description ):
112
129
""" Must get correct names from valid graph element names """
113
130
name_a , name_b = data
114
131
self .assertEqual (name_a , name_b , msg = description )
115
132
116
- @parameterized .expand (_gen_invalid_tensor_op_input_with_wrong_types )
117
- def test_wrong_tensor_types (self , data , description ):
133
+ @parameterized .expand (_gen_invalid_tensor_or_op_input_with_wrong_types )
134
+ def test_invalid_tensor_name_inputs_with_wrong_types (self , data , description ):
118
135
""" Must fail when provided wrong types """
119
- with self .assertRaises (TypeError ):
120
- tfx .tensor_name (data , msg = description )
136
+ with self .assertRaises (TypeError , msg = description ):
137
+ tfx .tensor_name (data )
121
138
122
- @parameterized .expand (_gen_invalid_tensor_op_input_with_wrong_types )
123
- def test_wrong_op_types (self , data , description ):
139
+ @parameterized .expand (_gen_invalid_tensor_or_op_input_with_wrong_types )
140
+ def test_invalid_op_name_inputs_with_wrong_types (self , data , description ):
124
141
""" Must fail when provided wrong types """
125
- with self .assertRaises (TypeError ):
126
- tfx .op_name (data , msg = description )
142
+ with self .assertRaises (TypeError , msg = description ):
143
+ tfx .op_name (data )
144
+
145
+ @parameterized .expand (_gen_invalid_tensor_or_op_input_with_wrong_types )
146
+ def test_invalid_op_inputs_with_wrong_types (self , data , description ):
147
+ """ Must fail when provided wrong types """
148
+ with self .assertRaises (TypeError , msg = description ):
149
+ tfx .get_op (data , tf .Graph ())
150
+
151
+ @parameterized .expand (_gen_invalid_tensor_or_op_input_with_wrong_types )
152
+ def test_invalid_tensor_inputs_with_wrong_types (self , data , description ):
153
+ """ Must fail when provided wrong types """
154
+ with self .assertRaises (TypeError , msg = description ):
155
+ tfx .get_tensor (data , tf .Graph ())
127
156
128
157
@parameterized .expand (_gen_valid_tensor_op_objects )
129
- def test_get_graph_elements (self , data , description ):
158
+ def test_valid_tensor_op_object_inputs (self , data , description ):
130
159
""" Must get correct graph elements from valid graph elements or their names """
131
160
tfobj_or_name_a , tfobj_or_name_b = data
132
161
self .assertEqual (tfobj_or_name_a , tfobj_or_name_b , msg = description )
0 commit comments