Skip to content

Commit 9748d59

Browse files
committed
Merge pull request #19 from mogproject/topic-option-type
add Option to types
2 parents b77af0c + 286518e commit 9748d59

File tree

3 files changed

+46
-8
lines changed

3 files changed

+46
-8
lines changed

src/mog_commons/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
__version__ = '0.1.13'
1+
__version__ = '0.1.14'

src/mog_commons/types.py

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
__all__ = [
1313
'String',
1414
'Unicode',
15+
'Option',
1516
'ListOf',
1617
'TupleOf',
1718
'SetOf',
@@ -101,6 +102,11 @@ def KwArg(cls):
101102
return DictOf(String, cls)
102103

103104

105+
def Option(cls):
106+
"""Shorthand description for a type allowing NoneType"""
107+
return cls + (type(None),) if isinstance(cls, tuple) else (cls, type(None))
108+
109+
104110
#
105111
# Helper functions
106112
#
@@ -114,7 +120,12 @@ def _get_name(cls):
114120

115121

116122
def _check_type(obj, cls):
117-
return cls.check(obj) if isinstance(cls, ComposableType) else isinstance(obj, cls)
123+
if isinstance(cls, ComposableType):
124+
return cls.check(obj)
125+
elif isinstance(cls, tuple):
126+
return any(_check_type(obj, t) for t in cls)
127+
else:
128+
return isinstance(obj, cls)
118129

119130

120131
#
@@ -147,12 +158,15 @@ def wrapper(*args, **kwargs):
147158
for arg_name, expect in arg_types.items():
148159
assert arg_name in call_args, 'Not found argument: %s' % arg_name
149160
actual = call_args[arg_name]
150-
assert _check_type(actual, expect), arg_msg % (arg_name, _get_name(expect), type(actual).__name__)
161+
if not _check_type(actual, expect):
162+
raise TypeError(arg_msg % (arg_name, _get_name(expect), type(actual).__name__))
151163

152164
ret = func(*args, **kwargs)
153165
if return_type:
154-
assert _check_type(ret, return_type[0]), return_msg % (_get_name(return_type[0]), type(ret).__name__)
166+
if not _check_type(ret, return_type[0]):
167+
raise TypeError(return_msg % (_get_name(return_type[0]), type(ret).__name__))
155168
return ret
169+
156170
return wrapper
157171

158172
return f

tests/mog_commons/test_types.py

Lines changed: 28 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,49 @@ def err_func2():
2929
@types(bool, int)
3030
def f():
3131
pass
32+
3233
return 1
3334

3435
@staticmethod
3536
@types(bool)
3637
def predicate():
3738
return 1
3839

40+
@types(x=Option((int, float)))
41+
def optional_func1(self, x):
42+
return x
43+
44+
@types(int, xs=Option(ListOf(int)))
45+
def optional_func2(self, xs=None):
46+
return len(xs or [])
47+
48+
class Foo(object):
49+
pass
50+
3951
def test_types(self):
4052
str_type = '(basestring|str)' if six.PY2 else '(str|bytes)'
4153

4254
self.assertEqual(self.bin_func(10, 20), 30)
43-
self.assertRaisesMessage(AssertionError, 'x must be int, not dict.', self.bin_func, {}, 20)
44-
self.assertRaisesMessage(AssertionError, 'y must be int, not list.', self.bin_func, 10, [])
55+
self.assertRaisesMessage(TypeError, 'x must be int, not dict.', self.bin_func, {}, 20)
56+
self.assertRaisesMessage(TypeError, 'y must be int, not list.', self.bin_func, 10, [])
4557

4658
self.assertEqual(self.complex_func(123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}]), 1)
47-
self.assertRaisesMessage(AssertionError, 'kw must be dict(%s->float), not dict.' % str_type,
59+
self.assertRaisesMessage(TypeError, 'kw must be dict(%s->float), not dict.' % str_type,
4860
self.complex_func, 123, [1, 2], 10, 'abc', 'def', [{'x': set([3, 4, 5])}], x='12.3')
4961

50-
self.assertRaisesMessage(AssertionError, 'must return bool, not int.', self.predicate)
62+
self.assertRaisesMessage(TypeError, 'must return bool, not int.', self.predicate)
63+
64+
self.assertEqual(self.optional_func1(None), None)
65+
self.assertEqual(self.optional_func1(123), 123)
66+
self.assertEqual(self.optional_func1(1.23), 1.23)
67+
self.assertRaisesMessage(TypeError, 'x must be (int|float|NoneType), not Foo.',
68+
self.optional_func1, self.Foo())
69+
70+
self.assertEqual(self.optional_func2(), 0)
71+
self.assertEqual(self.optional_func2(None), 0)
72+
self.assertEqual(self.optional_func2([1, 2, 3]), 3)
73+
self.assertRaisesMessage(TypeError, 'xs must be (list(int)|NoneType), not list.',
74+
self.optional_func2, [1, 2, 3.4])
5175

5276
def test_types_error(self):
5377
self.assertRaisesMessage(AssertionError, 'Not found argument: a', self.err_func1, 123)

0 commit comments

Comments
 (0)