Skip to content

Commit

Permalink
添加 OneBot V11 HTTP 支持
Browse files Browse the repository at this point in the history
  • Loading branch information
This-is-XiaoDeng committed Oct 14, 2023
1 parent 691138b commit 6fd98ed
Show file tree
Hide file tree
Showing 7 changed files with 146 additions and 42 deletions.
7 changes: 6 additions & 1 deletion api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
logger = get_logger()
action_list = {}


'''
def register_action(name: str) -> Callable:
"""
Register an action
Expand All @@ -13,3 +13,8 @@ def decorator(func: Callable) -> None:
action_list[name] = func
logger.debug(f"成功注册动作:{name}")
return decorator
'''

def register_action(func: Callable) -> None:
action_list[func.__name__] = func
logger.debug(f"成功注册动作:{func.__name__}")
48 changes: 24 additions & 24 deletions basic_actions_v12.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
logger = get_logger()


@register_action("send_message")
@register_action
async def send_message(
detail_type: str,
message: list,
Expand Down Expand Up @@ -58,7 +58,7 @@ async def send_message(
return return_object.get(0, message_id=message_id, time=time.time())


@register_action("get_supported_actions")
@register_action
async def get_supported_actions() -> dict:
"""
获取支持的动作列表
Expand Down Expand Up @@ -87,12 +87,12 @@ async def get_status() -> dict:
)


@register_action("get_version")
@register_action
async def get_version() -> dict:
return return_object.get(0, impl="onedisc", version=VERSION, onebot_version="12")


@register_action("delete_message")
@register_action
async def delete_message(message_id: str) -> dict:
for message in client.cached_messages[::-1]:
if str(message.id) == message_id:
Expand All @@ -108,14 +108,14 @@ async def delete_message(message_id: str) -> dict:
return return_object.get(0)


@register_action("get_self_info")
@register_action
async def get_self_info() -> dict:
return return_object.get(
0, user_id=str(client.user.id), user_name=client.user.name, user_displayname=""
)


@register_action("get_user_info")
@register_action
async def get_user_info(user_id: str) -> dict:
if not (user := client.get_user(int(user_id))):
return return_object.get(35003, "用户不存在")
Expand All @@ -128,27 +128,27 @@ async def get_user_info(user_id: str) -> dict:
)


@register_action("get_friend_list")
@register_action
async def get_friend_list() -> dict:
return return_object._get(0, [])


@register_action("get_group_info")
@register_action
async def get_group_info(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
return return_object.get(0, group_id=str(channel.id), group_name=channel.name)


@register_action("get_group_list")
@register_action
async def get_group_list() -> dict:
channel_list = []
for channel in client.get_all_channels():
channel_list.append({"group_id": str(channel.id), "group_name": channel.name})
return return_object._get(0, channel_list)


@register_action("get_group_member_info")
@register_action
async def get_group_member_info(group_id: str, user_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
Expand All @@ -163,7 +163,7 @@ async def get_group_member_info(group_id: str, user_id: str) -> dict:
)


@register_action("get_group_member_list")
@register_action
async def get_group_member_list(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
Expand All @@ -179,40 +179,40 @@ async def get_group_member_list(group_id: str) -> dict:
return return_object._get(0, member_list)


@register_action("set_group_name")
@register_action
async def set_group_name(group_id: str, group_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改频道名")


@register_action("leave_group")
@register_action
async def leave_group(group_id: str) -> dict:
if not (channel := client.get_channel(int(group_id))):
return return_object.get(35001, "频道不存在")
await channel.leave()
return return_object.get(0)


@register_action("get_guild_info")
@register_action
async def get_guild_info(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
return return_object.get(0, guild_id=str(guild.id), guild_name=guild.name)


@register_action("get_guild_list")
@register_action
async def get_guild_list() -> dict:
guild_list = []
for guild in client.guilds:
guild_list.append({"guild_id": str(guild.id), "guild_name": guild.name})
return return_object._get(0, guild_list)


@register_action("get_guild_member_list")
@register_action
async def set_guild_name(guild_id: str, guild_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改群组名")


@register_action("get_guild_member_list")
@register_action
async def get_guild_member_info(guild_id: str, user_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -227,7 +227,7 @@ async def get_guild_member_info(guild_id: str, user_id: str) -> dict:
)


@register_action("get_guild_member_list")
@register_action
async def get_guild_member_list(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -243,7 +243,7 @@ async def get_guild_member_list(guild_id: str) -> dict:
return return_object._get(0, member_list)


@register_action("get_guild_member_list")
@register_action
async def leave_guild(guild_id: str) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand Down Expand Up @@ -274,7 +274,7 @@ def _parse_channel_action_data(_data: dict) -> dict:
return data


@register_action("get_channel_list")
@register_action
async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict:
if not (guild := client.get_guild(int(guild_id))):
return return_object.get(35004, "服务器不存在")
Expand All @@ -291,21 +291,21 @@ async def get_channel_list(guild_id: str, joined_only: bool = False) -> dict:
return return_object._get(0, channel_list)


@register_action("set_channel_name")
@register_action
async def set_channel_name(guild_id: str, channel_id: str, channel_name: str) -> dict:
return return_object.get(10002, "不支持机器人修改频道名")


@register_action("get_channel_member_info")
@register_action
async def get_channel_member_info(guild_id: str, channel_id: str, user_id: str) -> dict:
return _parse_channel_action_data(await get_group_member_info(channel_id, user_id))


@register_action("get_channel_member_list")
@register_action
async def get_channel_member_list(guild_id: str, channel_id: str) -> dict:
return _parse_channel_action_data(await get_group_member_list(channel_id))


@register_action("leave_channel")
@register_action
async def leave_channel(guild_id: str, channel_id: str) -> dict:
return _parse_channel_action_data(await leave_group(channel_id))
46 changes: 40 additions & 6 deletions call_action.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
from api import action_list
from typing import Callable
import inspect
import return_object
from checker import BadParam
from logger import get_logger
Expand All @@ -9,12 +11,44 @@

logger = get_logger()

def check_params(func: Callable, params: dict) -> tuple[bool, dict]:
"""
检查参数及类型类型
Args:
func (Callable): 动作函数
params (dict): 实参列表
Returns:
tuple[bool] | tuple[bool, dict]: 检查结果
"""
arg_spec = inspect.getfullargspec(func)
for key in list(params.keys()):
if key not in arg_spec.args:
if config["system"].get("ignore_unneeded_args", True):
logger.warning(f"参数 {key} 未在 {func.__name__} 中定义,已忽略")
del params[key]
continue
else:
return False, return_object.get(10004, f"参数 {key} 未在 {func.__name__} 中定义")
if key in arg_spec.annotations.keys() and not isinstance(params[key], arg_spec.annotations[key]):
if not config["system"].get("ignore_error_types"):
return False, return_object.get(10001, f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确")
logger.warning(f"参数 {key} ({type(params[key])},应为 {arg_spec.annotations[key]}) 类型不正确,已忽略")
return True, {}





async def on_call_action(action: str, params: dict, echo: str | None = None, **_) -> dict:
logger.debug(f"请求执行动作:{action} ({params=}, {echo=})")
if config['system'].get("allow_strike") and random.random() <= 0.1:
return return_object.get(36000, "I am tried.")
if action not in action_list.keys():
return return_object.get(10002, "action not found")
if not (params_checking_result := check_params(action_list[action], params))[0]:
return params_checking_result[1]
try:
return_data = await action_list[action](**params)
except UnsupportedSegment as e:
Expand All @@ -23,12 +57,12 @@ async def on_call_action(action: str, params: dict, echo: str | None = None, **_
return return_object.get(10006, str(e))
except BadParam as e:
return return_object.get(10003, str(e))
except TypeError as e:
if "got an unexpected keyword argument" in str(e):
return return_object.get(10004, str(e))
else:
logger.error(traceback.format_exc())
return_data = return_object.get(20002, str(e))
# except TypeError as e:
# if "got an unexpected keyword argument" in str(e):
# return return_object.get(10004, str(e))
# else:
# logger.error(traceback.format_exc())
# return_data = return_object.get(20002, str(e))
except Exception as e:
logger.error(traceback.format_exc())
return_data = return_object.get(20002, str(e))
Expand Down
33 changes: 28 additions & 5 deletions connection.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from http_server import HTTPServer
import asyncio
from http_webhook import HttpWebhookConnect
from http_server_v11 import HTTPServer4OB11
from logger import get_logger
from ws import WebSocketServer
from ws_reverse import WebSocketClient
Expand All @@ -12,8 +13,13 @@
async def init_connections(connection_list: list[dict]) -> None:
for obc_config in connection_list:
logger.debug(obc_config)
match obc_config["type"]:
case "http":

if "type" not in obc_config:
logger.error(f"无效的连接配置:{obc_config}")

match obc_config["type"], obc_config.get("protocol_version", 12):

case "http", 12:
connections.append({
"type": "http",
"config": obc_config,
Expand All @@ -22,15 +28,27 @@ async def init_connections(connection_list: list[dict]) -> None:
})
await tmp.start_server()
del tmp
case "http-webhook":

case "http", 11:
connection_list.append({
"type": "http",
"config": obc_config,
"object": (tmp := HTTPServer4OB11(obc_config)),
"add_event_func": tmp.push_event
})
await tmp.start_server()
del tmp

case "http-webhook", 12:
connections.append({
"type": "http-webhook",
"config": obc_config,
"object": (tmp := HttpWebhookConnect(obc_config)),
"add_event_func": tmp.on_event
})
del tmp
case "ws":

case "ws", 12:
connections.append({
"type": "ws",
"config": obc_config,
Expand All @@ -39,7 +57,8 @@ async def init_connections(connection_list: list[dict]) -> None:
})
await tmp.start_server()
del tmp
case "ws-reverse":

case "ws-reverse", 12:
connections.append({
"type": "ws-reverse",
"config": obc_config,
Expand All @@ -48,3 +67,7 @@ async def init_connections(connection_list: list[dict]) -> None:
})
asyncio.create_task(tmp.reconnect())
del tmp


case _:
logger.warning(f"无效的连接类型或协议版本,已忽略: {obc_config['type']} (协议版本: {obc_config.get('protocol_version', 12)}")
8 changes: 4 additions & 4 deletions file.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def upload_file_from_path(name: str, path: str) -> tuple[bool, str]:
return False, str(e)


@register_action("get_file_fragmented")
@register_action
async def get_file_fragmented(
stage: str,
file_id: str,
Expand Down Expand Up @@ -120,7 +120,7 @@ async def get_file_fragmented(

uploading_files = {}

@register_action("upload_file_fragmented")
@register_action
async def upload_file_fragmented(
stage: str,
name: str | None = None,
Expand Down Expand Up @@ -161,7 +161,7 @@ async def upload_file_fragmented(
)
return return_object.get(10003, f"无效的 stage 参数:{stage}")

@register_action("upload_file")
@register_action
async def upload_file(
type: str,
name: str,
Expand Down Expand Up @@ -218,7 +218,7 @@ def get_file_name_by_id(file_id: str) -> str:
def get_file_path(file_name: str) -> str:
return os.path.abspath(f".cache/files/{file_name}")

@register_action("get_file")
@register_action
async def get_file(file_id: str, type: str) -> dict:
"""
获取文件
Expand Down
4 changes: 2 additions & 2 deletions http_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ def __init__(self, config: dict) -> None:
Args:
config (dict): 连接配置
"""
self.config = config
self.config.update(BASE_CONFIG)
self.config = BASE_CONFIG.copy()
self.config.update(config)
self.event_list = []

if self.config["event_enabled"] and self.config["event_buffer_size"] <= 0:
Expand Down
Loading

0 comments on commit 6fd98ed

Please sign in to comment.