Skip to content

Commit fcabcb6

Browse files
committed
using parameterized to simplify testing logic
https://github.com/wolever/parameterized
1 parent f7a7d38 commit fcabcb6

File tree

2 files changed

+30
-80
lines changed

2 files changed

+30
-80
lines changed

python/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ h5py>=2.7.0
44
keras==2.0.4 # NOTE: this package has only been tested with keras 2.0.4 and may not work with other releases
55
nose>=1.3.7 # for testing
66
numpy>=1.11.2
7+
parameterized>=0.6.1 # for testing
78
pillow>=4.1.1,<4.2
89
pygments>=2.2.0
910
tensorflow==1.3.0

python/tests/param/params_test.py

Lines changed: 29 additions & 80 deletions
Original file line numberDiff line numberDiff line change
@@ -15,97 +15,36 @@
1515
from __future__ import absolute_import, division, print_function
1616

1717
from collections import namedtuple
18-
from six import with_metaclass
18+
# `parameterized.expand` generates one test method per TestCase below
19+
from parameterized import parameterized
1920

2021
from sparkdl.param.converters import SparkDLTypeConverters
2122

2223
from ..tests import PythonUnitTestCase
2324

24-
25-
class TestGenMeta(type):
26-
"""
27-
This meta-class add test cases to the main unit-test class.
28-
To add test cases, implement the test logic in a function
29-
>>> `def _my_test_impl(): ...`
30-
then call the following function
31-
>>> `_register_test_case(fn_impl=_my_test_impl, name=..., doc=...)`
32-
"""
33-
def __new__(mcs, name, bases, attrs):
34-
_add_invalid_col2tnsr_mapping_tests()
35-
attrs.update(_TEST_FUNCTIONS_REGISTRY)
36-
return super(TestGenMeta, mcs).__new__(mcs, name, bases, attrs)
37-
38-
39-
# Stores test function name mapped to implementation body
40-
_TEST_FUNCTIONS_REGISTRY = {}
41-
4225
TestCase = namedtuple('TestCase', ['data', 'description'])
4326

44-
def _register_test_case(fn_impl, name, doc):
45-
""" Add an individual test case """
46-
fn_impl.__name__ = name
47-
fn_impl.__doc__ = doc
48-
_TEST_FUNCTIONS_REGISTRY[name] = fn_impl
49-
50-
def _add_invalid_col2tnsr_mapping_tests():
51-
""" Create a list of test cases and construct individual test functions for each case """
52-
shared_test_cases = [
53-
TestCase(data=['a1', 'b2'], description='required pair but get single element'),
54-
TestCase(data=('c3', 'd4'), description='required pair but get single element'),
55-
TestCase(data=[('a', 1), ('b', 2)], description='only accept dict, but get list'),
56-
TestCase(data={1: 'a', 2.0: 'b'}, description='wrong mapping type'),
57-
TestCase(data={'a': 1.0, 'b': 2}, description='wrong mapping type'),
58-
]
59-
60-
# Specify test cases for `asColumnToTensorNameMap`
61-
# Add additional test cases specific to this one
62-
col2tnsr_test_cases = shared_test_cases + [
63-
TestCase(data={'colA': 'tnsrOpA', 'colB': 'tnsrOpB'},
64-
description='strict tensor name required'),
65-
]
66-
_fn_name_template = 'test_invalid_col2tnsr_{idx}'
67-
_fn_doc_template = 'Test invalid column => tensor name mapping: {description}'
68-
69-
for idx, test_case in enumerate(col2tnsr_test_cases):
70-
# Add the actual test logic here
71-
def test_fn_impl(self):
72-
with self.assertRaises(TypeError, msg=test_case.description):
73-
SparkDLTypeConverters.asColumnToTensorNameMap(test_case.data)
74-
75-
_name = _fn_name_template.format(idx=idx)
76-
_doc = _fn_doc_template.format(description=test_case.description)
77-
_register_test_case(fn_impl=test_fn_impl, name=_name, doc=_doc)
78-
79-
80-
# Specify tests for `asTensorNameToColumnMap`
81-
tnsr2col_test_cases = shared_test_cases + [
82-
TestCase(data={'tnsrOpA': 'colA', 'tnsrOpB': 'colB'},
83-
description='strict tensor name required'),
84-
]
85-
_fn_name_template = 'test_invalid_tnsr2col_{idx}'
86-
_fn_doc_template = 'Test invalid tensor name => column mapping: {description}'
87-
88-
for idx, test_case in enumerate(tnsr2col_test_cases):
89-
# Add the actual test logic here
90-
def test_fn_impl(self): # pylint: disable=function-redefined
91-
with self.assertRaises(TypeError, msg=test_case.description):
92-
SparkDLTypeConverters.asTensorNameToColumnMap(test_case.data)
93-
94-
_name = _fn_name_template.format(idx=idx)
95-
_doc = _fn_doc_template.format(description=test_case.description)
96-
_register_test_case(fn_impl=test_fn_impl, name=_name, doc=_doc)
97-
98-
99-
class ParamsConverterTest(with_metaclass(TestGenMeta, PythonUnitTestCase)):
27+
_shared_invalid_test_cases = [
28+
TestCase(data=['a1', 'b2'], description='required pair but get single element'),
29+
TestCase(data=('c3', 'd4'), description='required pair but get single element'),
30+
TestCase(data=[('a', 1), ('b', 2)], description='only accept dict, but get list'),
31+
TestCase(data={1: 'a', 2.0: 'b'}, description='wrong mapping type'),
32+
TestCase(data={'a': 1.0, 'b': 2}, description='wrong mapping type'),
33+
]
34+
_col2tnsr_test_cases = _shared_invalid_test_cases + [
35+
TestCase(data={'colA': 'tnsrOpA', 'colB': 'tnsrOpB'},
36+
description='strict tensor name required'),
37+
]
38+
_tnsr2col_test_cases = _shared_invalid_test_cases + [
39+
TestCase(data={'tnsrOpA': 'colA', 'tnsrOpB': 'colB'},
40+
description='strict tensor name required'),
41+
]
42+
43+
class ParamsConverterTest(PythonUnitTestCase):
10044
"""
10145
Test MLlib Params introduced in Spark Deep Learning Pipeline
10246
Additional invalid-input test cases are attached via `parameterized.expand`.
10347
"""
104-
# pylint: disable=protected-access
105-
106-
@classmethod
107-
def setUpClass(cls):
108-
print(repr(cls), cls)
10948

11049
def test_tf_input_mapping_converter(self):
11150
""" Test valid input mapping conversion """
@@ -122,3 +61,13 @@ def test_tf_output_mapping_converter(self):
12261

12362
res = SparkDLTypeConverters.asTensorNameToColumnMap(valid_tnsr_output)
12463
self.assertEqual(valid_output_mapping_result, res)
64+
65+
@parameterized.expand(_col2tnsr_test_cases)
66+
def test_invalid_input_mapping(self, data, description):
67+
with self.assertRaises(TypeError, msg=description):
68+
SparkDLTypeConverters.asColumnToTensorNameMap(data)
69+
70+
@parameterized.expand(_tnsr2col_test_cases)
71+
def test_invalid_output_mapping(self, data, description):
72+
with self.assertRaises(TypeError, msg=description):
73+
SparkDLTypeConverters.asTensorNameToColumnMap(data)

0 commit comments

Comments
 (0)