diff --git a/connection.py b/connection.py index 87d9cde..dcd03fd 100644 --- a/connection.py +++ b/connection.py @@ -1,4 +1,4 @@ -from http_server import create_http_server +from http_server import HTTPServer import asyncio from http_webhook import HttpWebhookConnect from logger import get_logger @@ -17,8 +17,11 @@ async def init_connections(connection_list: list[dict]) -> None: connections.append({ "type": "http", "config": obc_config, - "add_event_func": await create_http_server(obc_config) + "object": (tmp := HTTPServer(obc_config)), + "add_event_func": tmp.push_event }) + await tmp.start_server() + del tmp case "http-webhook": connections.append({ "type": "http-webhook", @@ -44,5 +47,4 @@ async def init_connections(connection_list: list[dict]) -> None: "add_event_func": tmp.push_event }) asyncio.create_task(tmp.reconnect()) - # asyncio.create_task(tmp.setup_receive_loop()) del tmp diff --git a/http_server.py b/http_server.py index 3f3d23b..21279b0 100644 --- a/http_server.py +++ b/http_server.py @@ -1,4 +1,4 @@ -from typing import Callable +import asyncio import uvicorn_server import call_action import json @@ -6,7 +6,7 @@ import return_object from logger import get_logger -BASE_CONNECTION_CONFIG = { +BASE_CONFIG = { "host": "0.0.0.0", "port": 8080, "access_token": None, @@ -15,7 +15,7 @@ } logger = get_logger() - +''' async def create_http_server(_config: dict) -> Callable: """ 创建 HTTP 服务器 @@ -86,8 +86,94 @@ async def add_event(data: dict) -> None: event_list = event_list[-config["event_buffer_size"]:] return handle_http_connection, add_event +''' + +class HTTPServer: + + def __init__(self, config: dict) -> None: + """ + 初始化 HTTP 服务器 + + Args: + config (dict): 连接配置 + """ + self.config = config + self.config.update(BASE_CONFIG) + self.event_list = [] + + if self.config["event_enabled"] and self.config["event_buffer_size"] <= 0: + logger.warning("警告: 事件缓冲区大小配置不正确,可能导致内存泄露!") + + self.app = fastapi.FastAPI() + self.app.add_route("/", self.handle_http_connection, ["post"]) + + async def start_server(self): + await uvicorn_server.run(self.app, self.config["port"], self.config["host"]) + + async def handle_http_connection(self, request: fastapi.Request) -> fastapi.responses.JSONResponse: + """ + 处理 HTTP 请求 + + Args: + request (fastapi.Request): 请求信息 + + Returns: + dict: 返回值 + """ + logger.debug(request) + if verify_access_token(request, self.config["access_token"]): + raise fastapi.HTTPException(fastapi.status.HTTP_401_UNAUTHORIZED) + logger.debug(await request.body()) + return fastapi.responses.JSONResponse(await self.on_call_action(await request.body())) + + async def on_call_action(self, body: bytes) -> dict: + """ + 处理动作请求 + + Args: + body (bytes): 请求载荷 + + Returns: + dict: 返回内容 + """ + try: + data = json.loads(body) + except json.JSONDecodeError as e: + return return_object.get(10001, str(e)) + if "action" not in data.keys(): + return return_object.get(10001, "action 字段不存在") + if data["action"] == "get_latest_events": + return await self.get_latest_events(**data["params"]) + return await call_action.on_call_action(**data) + + async def get_latest_events(self, limit: int = 0, timeout: int = 0, **_) -> dict: + """ + 获取最新事件 + + Args: + limit (int, optional): 最多获取条数,为 0 无限制. Defaults to 0. + timeout (int, optional): 没有新事件时最多的等待时间,为 0 不等待. Defaults to 0. + + Returns: + dict: 返回数据 + """ + retried = 0 + while not (events := self.event_list[-limit:]): + await asyncio.sleep(1) + retried += 1 + if retried >= timeout: + break + return return_object._get(0, events) + + async def push_event(self, event: dict) -> None: + if self.config["event_enabled"]: + self.event_list.append(event) + self.event_list = self.event_list[-self.config["event_buffer_size"]:] + + + -def verify_access_token(request: fastapi.Request | fastapi.WebSocket, access_token: str) -> bool: +def verify_access_token(request: fastapi.Request | fastapi.WebSocket, access_token: str | None) -> bool: """ 鉴权 @@ -98,6 +184,8 @@ def verify_access_token(request: fastapi.Request | fastapi.WebSocket, access_tok Returns: bool: 是否通过验证 """ + if access_token is None: + return True if "Authorization" in request.headers.keys(): return request.headers["Authorization"] == f"Bearer {access_token}" return request.query_params.get("access_token") == access_token diff --git a/uvicorn_server.py b/uvicorn_server.py index c51216a..fff832c 100644 --- a/uvicorn_server.py +++ b/uvicorn_server.py @@ -3,6 +3,9 @@ from typing import List import fastapi import uvicorn +from logger import get_logger + +logger = get_logger() class Server(uvicorn.Server): @@ -31,4 +34,5 @@ async def run(app: fastapi.FastAPI, port: int, host: str = "0.0.0.0", **params): **params ) server = Server(config) - await server.run() \ No newline at end of file + await server.run() + logger.info(f"成功在 {host}:{port} 上开启 Uvicorn 服务器")