 from __future__ import absolute_import, division, print_function

 from collections import namedtuple
-from six import with_metaclass
+# Use this to create parameterized test cases
+from parameterized import parameterized

 from sparkdl.param.converters import SparkDLTypeConverters

 from ..tests import PythonUnitTestCase

-
-class TestGenMeta(type):
-    """
-    This meta-class add test cases to the main unit-test class.
-    To add test cases, implement the test logic in a function
-    >>> `def _my_test_impl(): ...`
-    then call the following function
-    >>> `_register_test_case(fn_impl=_my_test_impl, name=..., doc=...)`
-    """
-    def __new__(mcs, name, bases, attrs):
-        _add_invalid_col2tnsr_mapping_tests()
-        attrs.update(_TEST_FUNCTIONS_REGISTRY)
-        return super(TestGenMeta, mcs).__new__(mcs, name, bases, attrs)
-
-
-# Stores test function name mapped to implementation body
-_TEST_FUNCTIONS_REGISTRY = {}
-
 TestCase = namedtuple('TestCase', ['data', 'description'])

-def _register_test_case(fn_impl, name, doc):
-    """ Add an individual test case """
-    fn_impl.__name__ = name
-    fn_impl.__doc__ = doc
-    _TEST_FUNCTIONS_REGISTRY[name] = fn_impl
-
-def _add_invalid_col2tnsr_mapping_tests():
-    """ Create a list of test cases and construct individual test functions for each case """
-    shared_test_cases = [
-        TestCase(data=['a1', 'b2'], description='required pair but get single element'),
-        TestCase(data=('c3', 'd4'), description='required pair but get single element'),
-        TestCase(data=[('a', 1), ('b', 2)], description='only accept dict, but get list'),
-        TestCase(data={1: 'a', 2.0: 'b'}, description='wrong mapping type'),
-        TestCase(data={'a': 1.0, 'b': 2}, description='wrong mapping type'),
-    ]
-
-    # Specify test cases for `asColumnToTensorNameMap`
-    # Add additional test cases specific to this one
-    col2tnsr_test_cases = shared_test_cases + [
-        TestCase(data={'colA': 'tnsrOpA', 'colB': 'tnsrOpB'},
-                 description='strict tensor name required'),
-    ]
-    _fn_name_template = 'test_invalid_col2tnsr_{idx}'
-    _fn_doc_template = 'Test invalid column => tensor name mapping: {description}'
-
-    for idx, test_case in enumerate(col2tnsr_test_cases):
-        # Add the actual test logic here
-        def test_fn_impl(self):
-            with self.assertRaises(TypeError, msg=test_case.description):
-                SparkDLTypeConverters.asColumnToTensorNameMap(test_case.data)
-
-        _name = _fn_name_template.format(idx=idx)
-        _doc = _fn_doc_template.format(description=test_case.description)
-        _register_test_case(fn_impl=test_fn_impl, name=_name, doc=_doc)
-
-
-    # Specify tests for `asTensorNameToColumnMap`
-    tnsr2col_test_cases = shared_test_cases + [
-        TestCase(data={'tnsrOpA': 'colA', 'tnsrOpB': 'colB'},
-                 description='strict tensor name required'),
-    ]
-    _fn_name_template = 'test_invalid_tnsr2col_{idx}'
-    _fn_doc_template = 'Test invalid tensor name => column mapping: {description}'
-
-    for idx, test_case in enumerate(tnsr2col_test_cases):
-        # Add the actual test logic here
-        def test_fn_impl(self):  # pylint: disable=function-redefined
-            with self.assertRaises(TypeError, msg=test_case.description):
-                SparkDLTypeConverters.asTensorNameToColumnMap(test_case.data)
-
-        _name = _fn_name_template.format(idx=idx)
-        _doc = _fn_doc_template.format(description=test_case.description)
-        _register_test_case(fn_impl=test_fn_impl, name=_name, doc=_doc)
-
-
-class ParamsConverterTest(with_metaclass(TestGenMeta, PythonUnitTestCase)):
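+# Invalid inputs shared by both converters; every entry below is expected
+# to be rejected with a TypeError.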
+_shared_invalid_test_cases = [
+    TestCase(data=['a1', 'b2'], description='required pair but get single element'),
+    TestCase(data=('c3', 'd4'), description='required pair but get single element'),
+    TestCase(data=[('a', 1), ('b', 2)], description='only accept dict, but get list'),
+    TestCase(data={1: 'a', 2.0: 'b'}, description='wrong mapping type'),
+    TestCase(data={'a': 1.0, 'b': 2}, description='wrong mapping type'),
+]
+_col2tnsr_test_cases = _shared_invalid_test_cases + [
+    TestCase(data={'colA': 'tnsrOpA', 'colB': 'tnsrOpB'},
+             description='strict tensor name required'),
+]
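+# The values above look like operation names; strict tensor names
+# (e.g. 'tnsrOpA:0') are required, so these mappings should also fail.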
+_tnsr2col_test_cases = _shared_invalid_test_cases + [
+    TestCase(data={'tnsrOpA': 'colA', 'tnsrOpB': 'colB'},
+             description='strict tensor name required'),
+]
+
+class ParamsConverterTest(PythonUnitTestCase):
     """
     Test MLlib Params introduced in Spark Deep Learning Pipeline
-    Additional test cases are attached via the meta class `TestGenMeta`.
+    Invalid-input test cases are generated via `parameterized.expand`.
     """
-    # pylint: disable=protected-access
-
-    @classmethod
-    def setUpClass(cls):
-        print(repr(cls), cls)

     def test_tf_input_mapping_converter(self):
         """ Test valid input mapping conversion """
@@ -122,3 +61,15 @@ def test_tf_output_mapping_converter(self):

         res = SparkDLTypeConverters.asTensorNameToColumnMap(valid_tnsr_output)
         self.assertEqual(valid_output_mapping_result, res)
+
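+    # `parameterized.expand` unpacks each TestCase namedtuple into the
+    # (data, description) arguments and generates one test method per case.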
+    @parameterized.expand(_col2tnsr_test_cases)
+    def test_invalid_input_mapping(self, data, description):
+        with self.assertRaises(TypeError, msg=description):
+            SparkDLTypeConverters.asColumnToTensorNameMap(data)
+
+    @parameterized.expand(_tnsr2col_test_cases)
+    def test_invalid_output_mapping(self, data, description):
+        with self.assertRaises(TypeError, msg=description):
+            SparkDLTypeConverters.asTensorNameToColumnMap(data)
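For reference, the pattern above can be exercised in isolation. The sketch below assumes only the `parameterized` package; `_to_mapping` and `ConverterSketchTest` are hypothetical stand-ins for the SparkDL converter and test class, not part of this change:

import unittest
from collections import namedtuple
from parameterized import parameterized

TestCase = namedtuple('TestCase', ['data', 'description'])

_invalid_cases = [
    TestCase(data=['a1', 'b2'], description='required pair but get single element'),
    TestCase(data={1: 'a'}, description='wrong mapping type'),
]

def _to_mapping(value):
    # Hypothetical stand-in for SparkDLTypeConverters.asColumnToTensorNameMap:
    # accept only dicts mapping str -> str, otherwise raise TypeError.
    if not isinstance(value, dict):
        raise TypeError('expected a dict, got %r' % type(value))
    for key, val in value.items():
        if not (isinstance(key, str) and isinstance(val, str)):
            raise TypeError('expected str -> str entries')
    return value

class ConverterSketchTest(unittest.TestCase):
    # expand() consumes the namedtuples as plain tuples, so each case becomes
    # one generated test method (e.g. test_invalid_mapping_0, _1, ...).
    @parameterized.expand(_invalid_cases)
    def test_invalid_mapping(self, data, description):
        with self.assertRaises(TypeError, msg=description):
            _to_mapping(data)

if __name__ == '__main__':
    unittest.main()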