Skip to content

Commit 391d0cc

Browse files
authored
fix(command): fix incorrect param check for Parser.register (#236)
Change the requirements from non-async and only one str param, to regardless of async and either (str) or (Message, Client, str) param types. Since this feature was broken, there would be no incompatibilities. Closes #235
1 parent eaaea9d commit 391d0cc

File tree

1 file changed

+22
-7
lines changed

1 file changed

+22
-7
lines changed

khl/command/parser.py

Lines changed: 22 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,12 @@ def _get_param_type(param: Union[inspect.Parameter, None]):
2828
return param.annotation
2929

3030

31+
def _wrap_one_param_func(func: Callable) -> Callable:
32+
def wrapper(msg: Message, client: Client, token: str):
33+
return func(token)
34+
return wrapper
35+
36+
3137
async def _parse_user(_, client, token) -> User:
3238
if not (token.startswith("(met)") and token.endswith("(met)")):
3339
raise ValueError(f"wrong format: expected: '(met)`user_id`(met)', actual: {token}")
@@ -105,16 +111,25 @@ def register(self, func):
105111
decorator, register the func into object restricted _parse_funcs()
106112
107113
checks if parse func for that type exists, and insert if not
108-
:param func: parse func
114+
:param func: parse func, which owns either (str) or (Message, Client, str) param types
109115
"""
110116
s = inspect.signature(func)
111117

112-
# check: 1. not coroutine, 2. len matches
113-
if asyncio.iscoroutinefunction(func):
114-
raise TypeError('parse function should not be async')
115-
if len(s.parameters) != 1 or list(s.parameters.values())[0].annotation != str:
116-
raise TypeError('parse function should own only one param, and the param type is str')
118+
params = list(s.parameters.values())
119+
parse_func = None
120+
121+
# first type of function: (str)
122+
if len(s.parameters) == 1 and params[0].annotation == str:
123+
parse_func = _wrap_one_param_func(func)
124+
# second type of function: (Message, Client, str)
125+
elif len(s.parameters) == 3 \
126+
and params[0].annotation == Message \
127+
and params[1].annotation == Client \
128+
and params[2].annotation == str:
129+
parse_func = func
130+
else:
131+
raise TypeError('parse function should own either (str) or (Message, Client, str) param types')
117132

118133
# insert, remember this is a replacement
119-
self._parse_funcs[s.return_annotation] = func
134+
self._parse_funcs[s.return_annotation] = parse_func
120135
return func

0 commit comments

Comments
 (0)