diff --git a/api.py b/api.py index 94a31c7..c6f1d4c 100644 --- a/api.py +++ b/api.py @@ -3,6 +3,7 @@ logger = get_logger() action_list = {} +ob11_api_list = {} ''' def register_action(name: str) -> Callable: @@ -17,4 +18,8 @@ def decorator(func: Callable) -> None: def register_action(func: Callable) -> None: action_list[func.__name__] = func - logger.debug(f"成功注册动作:{func.__name__}") \ No newline at end of file + logger.debug(f"成功注册动作:{func.__name__}") + +def register_ob11_api(func: Callable) -> None: + ob11_api_list[func.__name__] = func + logger.debug(f"成功注册接口:{func.__name__} (OneBot V11)") diff --git a/call_action.py b/call_action.py index f3c188c..09b0738 100644 --- a/call_action.py +++ b/call_action.py @@ -1,4 +1,4 @@ -from api import action_list +from api import action_list, ob11_api_list from typing import Callable import inspect import return_object @@ -38,19 +38,40 @@ def check_params(func: Callable, params: dict) -> tuple[bool, dict]: return True, {} - +def get_action_function(action: str, protocol_version: int) -> Callable | None: + """ + 获取动作函数 + + Args: + action (str): 动作/接口名 + protocol_version (int): 协议版本 11/12 + + Returns: + Callable: 动作执行函数 + """ + if protocol_version == 11 and action not in ob11_api_list.keys() and config["system"].get("allow_v12_actions", True): + logger.warning(f"接口 {action} (V11) 不存在,尝试使用 V12") + return action_list.get(action) + elif protocol_version == 11: + logger.error(f"接口 {action} (V11) 不存在") + return ob11_api_list.get(action) + elif protocol_version == 12 and action not in action_list.keys() and config["system"].get("allow_v11_actions", False): + logger.warning(f"动作 {action} 不存在,尝试使用 V11") + return ob11_api_list.get(action) + else: + return action_list.get(action) -async def on_call_action(action: str, params: dict, echo: str | None = None, **_) -> dict: - logger.debug(f"请求执行动作:{action} ({params=}, {echo=})") +async def on_call_action(action: str, params: dict, echo: str | None = None, protocol_version: int = 12, **_) -> dict: + logger.debug(f"请求执行动作:{action} ({params=}, {echo=}, {protocol_version=})") if config['system'].get("allow_strike") and random.random() <= 0.1: return return_object.get(36000, "I am tried.") - if action not in action_list.keys(): + if not (action_function := get_action_function(action, protocol_version)): return return_object.get(10002, "action not found") - if not (params_checking_result := check_params(action_list[action], params))[0]: + if not (params_checking_result := check_params(action_function, params))[0]: return params_checking_result[1] try: - return_data = await action_list[action](**params) + return_data = await action_function(**params) except UnsupportedSegment as e: return return_object.get(10005, str(e)) except BadSegmentData as e: