27
27
TestCase = namedtuple ('TestCase' , ['data' , 'description' ])
28
28
29
29
30
- def _gen_graph_elems_names ():
30
+ def _gen_tensor_op_string_input_tests ():
31
31
op_name = 'someOp'
32
32
for tnsr_idx in range (17 ):
33
33
tnsr_name = '{}:{}' .format (op_name , tnsr_idx )
@@ -37,12 +37,12 @@ def _gen_graph_elems_names():
37
37
description = 'must get the tensor name from its operation' )
38
38
39
39
40
- def _gen_wrong_graph_elems_types ():
40
+ def _gen_invalid_tensor_op_input_with_wrong_types ():
41
41
for wrong_val in [7 , 1.2 , tf .Graph ()]:
42
42
yield TestCase (data = wrong_val , description = 'wrong type {}' .format (type (wrong_val )))
43
43
44
44
45
- def _gen_graph_elems ():
45
+ def _gen_valid_tensor_op_objects ():
46
46
op_name = 'someConstOp'
47
47
tnsr_name = '{}:0' .format (op_name )
48
48
tnsr = tf .constant (1427.08 , name = op_name )
@@ -67,25 +67,25 @@ def _gen_graph_elems():
67
67
68
68
69
69
class TFeXtensionGraphUtilsTest (PythonUnitTestCase ):
70
- @parameterized .expand (_gen_graph_elems_names )
70
+ @parameterized .expand (_gen_tensor_op_string_input_tests )
71
71
def test_valid_graph_element_names (self , data , description ):
72
72
""" Must get correct names from valid graph element names """
73
73
name_a , name_b = data
74
74
self .assertEqual (name_a , name_b , msg = description )
75
75
76
- @parameterized .expand (_gen_wrong_graph_elems_types )
77
- def test_wrong_op_types (self , data , description ):
76
+ @parameterized .expand (_gen_invalid_tensor_op_input_with_wrong_types )
77
+ def test_wrong_tensor_types (self , data , description ):
78
78
""" Must fail when provided wrong types """
79
79
with self .assertRaises (TypeError ):
80
- tfx .op_name (data , msg = description )
80
+ tfx .tensor_name (data , msg = description )
81
81
82
- @parameterized .expand (_gen_wrong_graph_elems_types )
82
+ @parameterized .expand (_gen_invalid_tensor_op_input_with_wrong_types )
83
83
def test_wrong_op_types (self , data , description ):
84
84
""" Must fail when provided wrong types """
85
85
with self .assertRaises (TypeError ):
86
86
tfx .op_name (data , msg = description )
87
87
88
- @parameterized .expand (_gen_graph_elems )
88
+ @parameterized .expand (_gen_valid_tensor_op_objects )
89
89
def test_get_graph_elements (self , data , description ):
90
90
""" Must get correct graph elements from valid graph elements or their names """
91
91
tfobj_or_name_a , tfobj_or_name_b = data
0 commit comments