diff --git a/.env.example b/.env.example index d59cbc85..5e6da1ce 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,12 @@ # debug 开关 DEBUG=false +AUTO_RELOAD=false +RELOAD_DELAY=0.25 +RELOAD_DIRS=[] +RELOAD_INCLUDE=[] +RELOAD_EXCLUDE=[] + # MySQL DB_HOST=127.0.0.1 DB_PORT=3306 @@ -17,14 +23,14 @@ REDIS_PASSWORD="" # 联系 https://t.me/BotFather 使用 /newbot 命令创建机器人并获取 token BOT_TOKEN="xxxxxxx" -# bot 管理员 -ADMINS=[{ "username": "", "user_id": -1 }] +# bot 所有者 +OWNER=0 # 记录错误并发送消息通知开发人员 可选配置项 # ERROR_NOTIFICATION_CHAT_ID=chat_id # 文章推送群组 可选配置项 -# CHANNELS=[{ "name": "", "chat_id": 1}] +# CHANNELS=[] # 是否允许机器人邀请到其他群 默认不允许 如果允许 可以允许全部人或有认证选项 可选配置项 # JOIN_GROUPS = "NO_ALLOW" @@ -33,20 +39,20 @@ ADMINS=[{ "username": "", "user_id": -1 }] # VERIFY_GROUPS=[] # logger 配置 可选配置项 -LOGGER_NAME="TGPaimon" +# LOGGER_NAME="TGPaimon" # 打印时的宽度 -LOGGER_WIDTH=180 +# LOGGER_WIDTH=180 # log 文件存放目录 -LOGGER_LOG_PATH="logs" +# LOGGER_LOG_PATH="logs" # log 时间格式,参考 datetime.strftime -LOGGER_TIME_FORMAT="[%Y-%m-%d %X]" +# LOGGER_TIME_FORMAT="[%Y-%m-%d %X]" # log 高亮关键词 -LOGGER_RENDER_KEYWORDS=["BOT"] +# LOGGER_RENDER_KEYWORDS=["BOT"] # traceback 相关配置 -LOGGER_TRACEBACK_MAX_FRAMES=20 -LOGGER_LOCALS_MAX_DEPTH=0 -LOGGER_LOCALS_MAX_LENGTH=10 -LOGGER_LOCALS_MAX_STRING=80 +# LOGGER_TRACEBACK_MAX_FRAMES=20 +# LOGGER_LOCALS_MAX_DEPTH=0 +# LOGGER_LOCALS_MAX_LENGTH=10 +# LOGGER_LOCALS_MAX_STRING=80 # 可被 logger 打印的 record 的名称(默认包含了 LOGGER_NAME ) LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"] @@ -77,7 +83,7 @@ LOGGER_FILTERED_NAMES=["uvicorn","ErrorPush","ApiHelper"] # ENKA_NETWORK_API_AGENT="" # Web Server -# 目前只用于预览模板,仅开发环境启动 +# WEB_SWITCH=False # 是否开启 # WEB_URL=http://localhost:8080/ # WEB_HOST=localhost # WEB_PORT=8080 diff --git a/.github/workflows/integration-test.yml b/.github/workflows/integration-test.yml new file mode 100644 index 00000000..a51d080f --- /dev/null +++ b/.github/workflows/integration-test.yml @@ -0,0 +1,54 @@ +name: Integration Test + +on: + push: + branches: + - main + paths: + - 'tests/integration/**' + pull_request: + types: [ opened, synchronize ] + paths: + - 'core/services/**' + - 'core/dependence/**' + - 'tests/integration/**' + +jobs: + pytest: + name: pytest + runs-on: ubuntu-latest + services: + mysql: + image: mysql:5.7 + env: + MYSQL_DATABASE: integration_test + MYSQL_ROOT_PASSWORD: 123456test + ports: + - 3306:3306 + options: --health-cmd="mysqladmin ping" --health-interval=10s --health-timeout=5s --health-retries=3 + redis: + image: redis + ports: + - 6379:6379 + steps: + - name: Checkout code + uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v2 + with: + python-version: 3.11 + - name: Setup integration test environment + run: cp tests/integration/.env.example .env && cp tests/integration/.env.example tests/integration/.env + - name: Create venv + run: | + pip install --upgrade pip + python3 -m venv venv + - name: Install requirements + run: | + source venv/bin/activate + python3 -m pip install --upgrade poetry + python3 -m poetry install --extras all + - name: Run test + run: | + source venv/bin/activate + python3 -m pytest tests/integration \ No newline at end of file diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dfd38955..74b53878 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,19 +1,17 @@ -name: test +name: Test modules on: push: branches: - main paths: - - 'tests/**' + - 'tests/unit/**' pull_request: types: [ opened, synchronize ] paths: - 'modules/apihelper/**' - 'modules/wiki/**' - - 'tests/**' - schedule: - - cron: '0 4 * * 3' + - 'tests/unit/**' jobs: pytest: @@ -22,16 +20,15 @@ jobs: continue-on-error: ${{ matrix.experimental }} strategy: matrix: - python-version: [ '3.10' ] os: [ ubuntu-latest, windows-latest ] - experimental: [ false ] fail-fast: False steps: - - uses: actions/checkout@v3 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + - name: Checkout code + uses: actions/checkout@v3 + - name: Set up Python 3.11 + uses: actions/setup-python@v2 with: - python-version: ${{ matrix.python-version }} + python-version: 3.11 - name: restore or create a python virtualenv id: cache uses: syphar/restore-virtualenv@v1.2 @@ -45,4 +42,4 @@ jobs: poetry install --extras test - name: Test with pytest run: | - python -m pytest \ No newline at end of file + python -m pytest tests/unit \ No newline at end of file diff --git a/.gitignore b/.gitignore index 78c6a5f7..7ee81cbc 100644 --- a/.gitignore +++ b/.gitignore @@ -58,6 +58,5 @@ plugins/private .pytest_cache ### mtp ### -paimon.session -PaimonBot.session -PaimonBot.session-journal +paigram.session +paigram.session-journal diff --git a/README.md b/README.md index 18a96749..d0a15349 100644 --- a/README.md +++ b/README.md @@ -1,7 +1,6 @@

PaiGram

-
- +
· code_style @@ -19,7 +18,7 @@ ## 环境需求 -- Python 3.8+ +- Python 3.11+ - MySQL - Redis diff --git a/alembic/env.py b/alembic/env.py index 08006a66..66fe0860 100644 --- a/alembic/env.py +++ b/alembic/env.py @@ -6,19 +6,13 @@ from typing import Iterator from alembic import context -from sqlalchemy import ( - engine_from_config, - pool, -) +from sqlalchemy import engine_from_config, pool from sqlalchemy.engine import Connection from sqlalchemy.ext.asyncio import AsyncEngine from sqlmodel import SQLModel -from utils.const import ( - CORE_DIR, - PLUGIN_DIR, - PROJECT_ROOT, -) +from core.config import config as BotConfig +from utils.const import CORE_DIR, PLUGIN_DIR, PROJECT_ROOT from utils.log import logger # this is the Alembic Config object, which provides @@ -28,7 +22,7 @@ # Interpret the config file for Python logging. # This line sets up loggers basically. if config.config_file_name is not None: - fileConfig(config.config_file_name) + fileConfig(config.config_file_name) # skipcq: PY-A6006 def scan_models() -> Iterator[str]: @@ -46,7 +40,7 @@ def import_models(): try: import_module(pkg) # 导入 models except Exception as e: # pylint: disable=W0703 - logger.error(f'在导入文件 "{pkg}" 的过程中遇到了错误: \n[red bold]{type(e).__name__}: {e}[/]') + logger.error("在导入文件 %s 的过程中遇到了错误: \n[red bold]%s: %s[/]", pkg, type(e).__name__, e, extra={"markup": True}) # register our models for alembic to auto-generate migrations @@ -61,14 +55,13 @@ def import_models(): # here we allow ourselves to pass interpolation vars to alembic.ini # from the application config module -from core.config import config as botConfig section = config.config_ini_section -config.set_section_option(section, "DB_HOST", botConfig.mysql.host) -config.set_section_option(section, "DB_PORT", str(botConfig.mysql.port)) -config.set_section_option(section, "DB_USERNAME", botConfig.mysql.username) -config.set_section_option(section, "DB_PASSWORD", botConfig.mysql.password) -config.set_section_option(section, "DB_DATABASE", botConfig.mysql.database) +config.set_section_option(section, "DB_HOST", BotConfig.mysql.host) +config.set_section_option(section, "DB_PORT", str(BotConfig.mysql.port)) +config.set_section_option(section, "DB_USERNAME", BotConfig.mysql.username) +config.set_section_option(section, "DB_PASSWORD", BotConfig.mysql.password) +config.set_section_option(section, "DB_DATABASE", BotConfig.mysql.database) def run_migrations_offline() -> None: diff --git a/alembic/versions/9e9a36470cd5_init.py b/alembic/versions/9e9a36470cd5_init.py index b9106625..b72315c5 100644 --- a/alembic/versions/9e9a36470cd5_init.py +++ b/alembic/versions/9e9a36470cd5_init.py @@ -5,16 +5,19 @@ Create Date: 2022-09-01 16:55:20.372560 """ -from alembic import op +from base64 import b64decode + import sqlalchemy as sa import sqlmodel - +from alembic import op # revision identifiers, used by Alembic. revision = "9e9a36470cd5" down_revision = None branch_labels = None depends_on = None +old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode() +old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode() def upgrade() -> None: @@ -22,7 +25,7 @@ def upgrade() -> None: op.create_table( "question", sa.Column("id", sa.Integer(), nullable=False), - sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("text", sqlmodel.AutoString(), nullable=True), sa.PrimaryKeyConstraint("id"), mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci", @@ -35,7 +38,7 @@ def upgrade() -> None: nullable=True, ), sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), sa.Column("yuanshen_uid", sa.Integer(), nullable=True), sa.Column("genshin_uid", sa.Integer(), nullable=True), sa.PrimaryKeyConstraint("id"), @@ -46,7 +49,7 @@ def upgrade() -> None: op.create_table( "admin", sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.user_id"], @@ -60,7 +63,7 @@ def upgrade() -> None: sa.Column("question_id", sa.Integer(), nullable=True), sa.Column("id", sa.Integer(), nullable=False), sa.Column("is_correct", sa.Boolean(), nullable=True), - sa.Column("text", sqlmodel.sql.sqltypes.AutoString(), nullable=True), + sa.Column("text", sqlmodel.AutoString(), nullable=True), sa.ForeignKeyConstraint( ["question_id"], ["question.id"], @@ -72,7 +75,7 @@ def upgrade() -> None: mysql_collate="utf8mb4_general_ci", ) op.create_table( - "hoyoverse_cookies", + old_cookies_database_name2, sa.Column("cookies", sa.JSON(), nullable=True), sa.Column( "status", @@ -85,7 +88,7 @@ def upgrade() -> None: nullable=True, ), sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.BigInteger(), nullable=True), sa.ForeignKeyConstraint( ["user_id"], ["user.user_id"], @@ -95,7 +98,7 @@ def upgrade() -> None: mysql_collate="utf8mb4_general_ci", ) op.create_table( - "mihoyo_cookies", + old_cookies_database_name1, sa.Column("cookies", sa.JSON(), nullable=True), sa.Column( "status", @@ -108,7 +111,7 @@ def upgrade() -> None: nullable=True, ), sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=True), + sa.Column("user_id", sa.BigInteger(), nullable=True), sa.ForeignKeyConstraint( ["user_id"], ["user.user_id"], @@ -119,6 +122,9 @@ def upgrade() -> None: ) op.create_table( "sign", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("chat_id", sa.BigInteger(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), @@ -140,14 +146,11 @@ def upgrade() -> None: ), nullable=True, ), - sa.Column("id", sa.Integer(), nullable=False), - sa.Column("user_id", sa.Integer(), nullable=False), - sa.Column("chat_id", sa.Integer(), nullable=True), sa.ForeignKeyConstraint( ["user_id"], ["user.user_id"], ), - sa.PrimaryKeyConstraint("id"), + sa.PrimaryKeyConstraint("id", "user_id"), mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci", ) @@ -157,8 +160,8 @@ def upgrade() -> None: def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_table("sign") - op.drop_table("mihoyo_cookies") - op.drop_table("hoyoverse_cookies") + op.drop_table(old_cookies_database_name1) + op.drop_table(old_cookies_database_name2) op.drop_table("answer") op.drop_table("admin") op.drop_table("user") diff --git a/alembic/versions/ddcfba3c7d5c_v4.py b/alembic/versions/ddcfba3c7d5c_v4.py new file mode 100644 index 00000000..22fcc5fe --- /dev/null +++ b/alembic/versions/ddcfba3c7d5c_v4.py @@ -0,0 +1,301 @@ +"""v4 + +Revision ID: ddcfba3c7d5c +Revises: 9e9a36470cd5 +Create Date: 2023-02-11 17:07:18.170175 + +""" +import json +import logging +from base64 import b64decode + +import sqlalchemy as sa +import sqlmodel +from alembic import op +from sqlalchemy import text +from sqlalchemy.exc import NoSuchTableError + +# revision identifiers, used by Alembic. +revision = "ddcfba3c7d5c" +down_revision = "9e9a36470cd5" +branch_labels = None +depends_on = None + +old_cookies_database_name1 = b64decode("bWlob3lvX2Nvb2tpZXM=").decode() +old_cookies_database_name2 = b64decode("aG95b3ZlcnNlX2Nvb2tpZXM=").decode() +logger = logging.getLogger(__name__) + + +def upgrade() -> None: + connection = op.get_bind() + # ### commands auto generated by Alembic - please adjust! ### + cookies_table = op.create_table( + "cookies", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("account_id", sa.BigInteger(), nullable=False), + sa.Column("data", sa.JSON(), nullable=True), + sa.Column( + "status", + sa.Enum( + "STATUS_SUCCESS", + "INVALID_COOKIES", + "TOO_MANY_REQUESTS", + name="cookiesstatusenum", + ), + nullable=True, + ), + sa.Column( + "region", + sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), + nullable=True, + ), + sa.Column("is_share", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.Index("index_user_account", "user_id", "account_id", unique=True), + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_general_ci", + ) + for old_cookies_database_name in (old_cookies_database_name1, old_cookies_database_name2): + try: + statement = f"SELECT * FROM {old_cookies_database_name};" # skipcq: BAN-B608 + old_cookies_table_data = connection.execute(text(statement)) + except NoSuchTableError: + logger.warning("Table '%s' doesn't exist", old_cookies_database_name) + continue + if old_cookies_table_data is None: + logger.warning("Old Cookies Database is None") + continue + for row in old_cookies_table_data: + try: + user_id = row["user_id"] + status = row["status"] + cookies_row = row["cookies"] + cookies_data = json.loads(cookies_row) + account_id = cookies_data.get("account_id") + if account_id is None: # Cleaning Data 清洗数据 + account_id = cookies_data.get("ltuid") + else: + account_mid_v2 = cookies_data.get("account_mid_v2") + if account_mid_v2 is not None: + cookies_data.pop("account_id") + cookies_data.setdefault("account_uid_v2", account_id) + if old_cookies_database_name == old_cookies_database_name1: + region = "HYPERION" + else: + region = "HOYOLAB" + if account_id is None: + logger.warning("Can not get user account_id, user_id :%s", user_id) + continue + insert = cookies_table.insert().values( + user_id=int(user_id), + account_id=int(account_id), + status=status, + data=cookies_data, + region=region, + is_share=True, + ) + with op.get_context().autocommit_block(): + connection.execute(insert) + except Exception as exc: # pylint: disable=W0703 + logger.error( + "Process %s->cookies Exception", old_cookies_database_name, exc_info=exc + ) # pylint: disable=W0703 + players_table = op.create_table( + "players", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("account_id", sa.BigInteger(), nullable=True), + sa.Column("player_id", sa.BigInteger(), nullable=False), + sa.Column( + "region", + sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), + nullable=True, + ), + sa.Column("is_chosen", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True), + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_general_ci", + ) + + try: + statement = "SELECT * FROM user;" + old_user_table_data = connection.execute(text(statement)) + except NoSuchTableError: + logger.warning("Table 'user' doesn't exist") + return # should not happen + if old_user_table_data is not None: + for row in old_user_table_data: + try: + user_id = row["user_id"] + y_uid = row["yuanshen_uid"] + g_uid = row["genshin_uid"] + region = row["region"] + account_id = None + cookies_row = connection.execute( + cookies_table.select().where(cookies_table.c.user_id == user_id) + ).first() + if cookies_row is not None: + account_id = cookies_row["account_id"] + if y_uid: + insert = players_table.insert().values( + user_id=int(user_id), + player_id=int(y_uid), + is_chosen=(region == "HYPERION"), + region="HYPERION", + account_id=account_id, + ) + with op.get_context().autocommit_block(): + connection.execute(insert) + if g_uid: + insert = players_table.insert().values( + user_id=int(user_id), + player_id=int(g_uid), + is_chosen=(region == "HOYOLAB"), + region="HOYOLAB", + account_id=account_id, + ) + with op.get_context().autocommit_block(): + connection.execute(insert) + except Exception as exc: # pylint: disable=W0703 + logger.error("Process user->player Exception", exc_info=exc) + else: + logger.warning("Old User Database is None") + + users_table = op.create_table( + "users", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False, primary_key=True), + sa.Column( + "permissions", + sa.Enum("OWNER", "ADMIN", "PUBLIC", name="permissionsenum"), + nullable=True, + ), + sa.Column("locale", sqlmodel.AutoString(), nullable=True), + sa.Column("is_banned", sa.BigInteger(), nullable=True), + sa.Column("ban_end_time", sa.DateTime(timezone=True), nullable=True), + sa.Column("ban_start_time", sa.DateTime(timezone=True), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.UniqueConstraint("user_id"), + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_general_ci", + ) + + try: + statement = "SELECT * FROM admin;" + old_user_table_data = connection.execute(text(statement)) + except NoSuchTableError: + logger.warning("Table 'admin' doesn't exist") + return # should not happen + if old_user_table_data is not None: + for row in old_user_table_data: + try: + user_id = row["user_id"] + insert = users_table.insert().values( + user_id=int(user_id), + permissions="ADMIN", + ) + with op.get_context().autocommit_block(): + connection.execute(insert) + except Exception as exc: # pylint: disable=W0703 + logger.error("Process admin->users Exception", exc_info=exc) + else: + logger.warning("Old User Database is None") + + op.create_table( + "players_info", + sa.Column("id", sa.Integer(), nullable=False), + sa.Column("user_id", sa.BigInteger(), nullable=False), + sa.Column("player_id", sa.BigInteger(), nullable=False), + sa.Column("nickname", sqlmodel.AutoString(length=128), nullable=True), + sa.Column("signature", sqlmodel.AutoString(length=255), nullable=True), + sa.Column("hand_image", sa.Integer(), nullable=True), + sa.Column("name_card", sa.Integer(), nullable=True), + sa.Column("extra_data", sa.VARCHAR(length=512), nullable=True), + sa.Column( + "create_time", + sa.DateTime(timezone=True), + server_default=sa.text("now()"), + nullable=True, + ), + sa.Column("last_save_time", sa.DateTime(timezone=True), nullable=True), + sa.Column("is_update", sa.Boolean(), nullable=True), + sa.PrimaryKeyConstraint("id"), + sa.Index("index_user_player", "user_id", "player_id", unique=True), + mysql_charset="utf8mb4", + mysql_collate="utf8mb4_general_ci", + ) + + op.drop_table(old_cookies_database_name1) + op.drop_table(old_cookies_database_name2) + op.drop_table("admin") + op.drop_constraint("sign_ibfk_1", "sign", type_="foreignkey") + op.drop_index("user_id", table_name="sign") + op.drop_table("user") + # ### end Alembic commands ### + + +def downgrade() -> None: + op.create_table( + "user", + sa.Column("region", sa.Enum("NULL", "HYPERION", "HOYOLAB", name="regionenum"), nullable=True), + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False), + sa.Column("yuanshen_uid", sa.INTEGER(), autoincrement=False, nullable=True), + sa.Column("genshin_uid", sa.INTEGER(), autoincrement=False, nullable=True), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_general_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", + ) + op.create_index("user_id", "user", ["user_id"], unique=False) + op.create_table( + "admin", + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=False), + sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="admin_ibfk_1"), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_general_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", + ) + op.create_table( + old_cookies_database_name1, + sa.Column("cookies", sa.JSON(), nullable=True), + sa.Column( + "status", + sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"), + nullable=True, + ), + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="mihoyo_cookies_ibfk_1"), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_general_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", + ) + op.create_table( + old_cookies_database_name2, + sa.Column("cookies", sa.JSON(), nullable=True), + sa.Column( + "status", + sa.Enum("STATUS_SUCCESS", "INVALID_COOKIES", "TOO_MANY_REQUESTS", name="cookiesstatusenum"), + nullable=True, + ), + sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), + sa.Column("user_id", sa.BigInteger(), autoincrement=False, nullable=True), + sa.ForeignKeyConstraint(["user_id"], ["user.user_id"], name="hoyoverse_cookies_ibfk_1"), + sa.PrimaryKeyConstraint("id"), + mysql_collate="utf8mb4_general_ci", + mysql_default_charset="utf8mb4", + mysql_engine="InnoDB", + ) + op.create_foreign_key("sign_ibfk_1", "sign", "user", ["user_id"], ["user_id"]) + op.create_index("user_id", "sign", ["user_id"], unique=False) + op.drop_table("users") + op.drop_table("players") + op.drop_table("cookies") + op.drop_table("players_info") + # ### end Alembic commands ### diff --git a/core/base/__init__.py b/core/__init__.py similarity index 100% rename from core/base/__init__.py rename to core/__init__.py diff --git a/core/admin/__init__.py b/core/admin/__init__.py deleted file mode 100644 index 5b1d97a4..00000000 --- a/core/admin/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from core.service import init_service -from core.base.mysql import MySQL -from core.base.redisdb import RedisDB -from core.admin.cache import BotAdminCache -from core.admin.repositories import BotAdminRepository -from core.admin.services import BotAdminService - - -@init_service -def create_bot_admin_service(mysql: MySQL, redis: RedisDB): - _cache = BotAdminCache(redis) - _repository = BotAdminRepository(mysql) - _service = BotAdminService(_repository, _cache) - return _service diff --git a/core/admin/cache.py b/core/admin/cache.py deleted file mode 100644 index ddce28ea..00000000 --- a/core/admin/cache.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import List - -from core.base.redisdb import RedisDB - - -class BotAdminCache: - def __init__(self, redis: RedisDB): - self.client = redis.client - self.qname = "bot:admin" - - async def get_list(self): - return [int(str_data) for str_data in await self.client.lrange(self.qname, 0, -1)] - - async def set_list(self, str_list: List[int], ttl: int = -1): - await self.client.ltrim(self.qname, 1, 0) - await self.client.lpush(self.qname, *str_list) - if ttl != -1: - await self.client.expire(self.qname, ttl) - count = await self.client.llen(self.qname) - return count - - -class GroupAdminCache: - def __init__(self, redis: RedisDB): - self.client = redis.client - self.qname = "group:admin_list" - - async def get_chat_admin(self, chat_id: int): - qname = f"{self.qname}:{chat_id}" - return [int(str_id) for str_id in await self.client.lrange(qname, 0, -1)] - - async def set_chat_admin(self, chat_id: int, admin_list: List[int]): - qname = f"{self.qname}:{chat_id}" - await self.client.ltrim(qname, 1, 0) - await self.client.lpush(qname, *admin_list) - await self.client.expire(qname, 60) - count = await self.client.llen(qname) - return count diff --git a/core/admin/models.py b/core/admin/models.py deleted file mode 100644 index e4bf7c16..00000000 --- a/core/admin/models.py +++ /dev/null @@ -1,8 +0,0 @@ -from sqlmodel import SQLModel, Field - - -class Admin(SQLModel, table=True): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - user_id: int = Field(foreign_key="user.user_id") diff --git a/core/admin/repositories.py b/core/admin/repositories.py deleted file mode 100644 index e1001fe9..00000000 --- a/core/admin/repositories.py +++ /dev/null @@ -1,33 +0,0 @@ -from typing import List, cast - -from sqlalchemy import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from core.admin.models import Admin -from core.base.mysql import MySQL - - -class BotAdminRepository: - def __init__(self, mysql: MySQL): - self.mysql = mysql - - async def delete_by_user_id(self, user_id: int): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - statement = select(Admin).where(Admin.user_id == user_id) - results = await session.exec(statement) - admin = results.one() - await session.delete(admin) - - async def add_by_user_id(self, user_id: int): - async with self.mysql.Session() as session: - admin = Admin(user_id=user_id) - session.add(admin) - await session.commit() - - async def get_all_user_id(self) -> List[int]: - async with self.mysql.Session() as session: - query = select(Admin) - results = await session.exec(query) - admins = results.all() - return [admin[0].user_id for admin in admins] diff --git a/core/admin/services.py b/core/admin/services.py deleted file mode 100644 index f1a332a2..00000000 --- a/core/admin/services.py +++ /dev/null @@ -1,60 +0,0 @@ -from typing import List - -from asyncmy.errors import IntegrityError -from telegram import Bot - -from core.admin.cache import BotAdminCache, GroupAdminCache -from core.admin.repositories import BotAdminRepository -from core.config import config -from utils.log import logger - - -class BotAdminService: - def __init__(self, repository: BotAdminRepository, cache: BotAdminCache): - self._repository = repository - self._cache = cache - - async def get_admin_list(self) -> List[int]: - admin_list = await self._cache.get_list() - if len(admin_list) == 0: - admin_list = await self._repository.get_all_user_id() - for config_admin in config.admins: - admin_list.append(config_admin.user_id) - await self._cache.set_list(admin_list) - return admin_list - - async def add_admin(self, user_id: int) -> bool: - try: - await self._repository.add_by_user_id(user_id) - except IntegrityError: - logger.warning("用户 %s 已经存在 Admin 数据库", user_id) - admin_list = await self._repository.get_all_user_id() - for config_admin in config.admins: - admin_list.append(config_admin.user_id) - await self._cache.set_list(admin_list) - return True - - async def delete_admin(self, user_id: int) -> bool: - try: - await self._repository.delete_by_user_id(user_id) - except ValueError: - return False - admin_list = await self._repository.get_all_user_id() - for config_admin in config.admins: - admin_list.append(config_admin.user_id) - await self._cache.set_list(admin_list) - return True - - -class GroupAdminService: - def __init__(self, cache: GroupAdminCache): - self._cache = cache - - async def get_admins(self, bot: Bot, chat_id: int, extra_user: List[int]) -> List[int]: - admin_id_list = await self._cache.get_chat_admin(chat_id) - if len(admin_id_list) == 0: - admin_list = await bot.get_chat_administrators(chat_id) - admin_id_list = [admin.user.id for admin in admin_list] - await self._cache.set_chat_admin(chat_id, admin_id_list) - admin_id_list += extra_user - return admin_id_list diff --git a/core/application.py b/core/application.py new file mode 100644 index 00000000..768d4f6a --- /dev/null +++ b/core/application.py @@ -0,0 +1,287 @@ +"""BOT""" +import asyncio +import signal +from functools import wraps +from signal import SIGABRT, SIGINT, SIGTERM, signal as signal_func +from ssl import SSLZeroReturnError +from typing import Callable, List, Optional, TYPE_CHECKING, TypeVar + +import pytz +import uvicorn +from fastapi import FastAPI +from telegram import Bot, Update +from telegram.error import NetworkError, TelegramError, TimedOut +from telegram.ext import ( + Application as TelegramApplication, + ApplicationBuilder as TelegramApplicationBuilder, + Defaults, + JobQueue, +) +from typing_extensions import ParamSpec +from uvicorn import Server + +from core.config import config as application_config +from core.handler.limiterhandler import LimiterHandler +from core.manager import Managers +from core.override.telegram import HTTPXRequest +from utils.const import WRAPPER_ASSIGNMENTS +from utils.log import logger +from utils.models.signal import Singleton + +if TYPE_CHECKING: + from asyncio import Task + from types import FrameType + +__all__ = ("Application",) + +R = TypeVar("R") +T = TypeVar("T") +P = ParamSpec("P") + + +class Application(Singleton): + """Application""" + + _web_server_task: Optional["Task"] = None + + _startup_funcs: List[Callable] = [] + _shutdown_funcs: List[Callable] = [] + + def __init__(self, managers: "Managers", telegram: "TelegramApplication", web_server: "Server") -> None: + self._running = False + self.managers = managers + self.telegram = telegram + self.web_server = web_server + self.managers.set_application(application=self) # 给 managers 设置 application + self.managers.build_executor("Application") + + @classmethod + def build(cls): + managers = Managers() + telegram = ( + TelegramApplicationBuilder() + .get_updates_read_timeout(application_config.update_read_timeout) + .get_updates_write_timeout(application_config.update_write_timeout) + .get_updates_connect_timeout(application_config.update_connect_timeout) + .get_updates_pool_timeout(application_config.update_pool_timeout) + .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai"))) + .token(application_config.bot_token) + .request( + HTTPXRequest( + connection_pool_size=application_config.connection_pool_size, + proxy_url=application_config.proxy_url, + read_timeout=application_config.read_timeout, + write_timeout=application_config.write_timeout, + connect_timeout=application_config.connect_timeout, + pool_timeout=application_config.pool_timeout, + ) + ) + .build() + ) + web_server = Server( + uvicorn.Config( + app=FastAPI(debug=application_config.debug), + port=application_config.webserver.port, + host=application_config.webserver.host, + log_config=None, + ) + ) + return cls(managers, telegram, web_server) + + @property + def running(self) -> bool: + """bot 是否正在运行""" + with self._lock: + return self._running + + @property + def web_app(self) -> FastAPI: + """fastapi app""" + return self.web_server.config.app + + @property + def bot(self) -> Optional[Bot]: + return self.telegram.bot + + @property + def job_queue(self) -> Optional[JobQueue]: + return self.telegram.job_queue + + async def _on_startup(self) -> None: + for func in self._startup_funcs: + await self.managers.executor(func, block=getattr(func, "block", False)) + + async def _on_shutdown(self) -> None: + for func in self._shutdown_funcs: + await self.managers.executor(func, block=getattr(func, "block", False)) + + async def initialize(self): + """BOT 初始化""" + self.telegram.add_handler(LimiterHandler(limit_time=10), group=-1) # 启用入口洪水限制 + await self.managers.start_dependency() # 启动基础服务 + await self.managers.init_components() # 实例化组件 + await self.managers.start_services() # 启动其他服务 + await self.managers.install_plugins() # 安装插件 + + async def shutdown(self): + """BOT 关闭""" + await self.managers.uninstall_plugins() # 卸载插件 + await self.managers.stop_services() # 终止其他服务 + await self.managers.stop_dependency() # 终止基础服务 + + async def start(self) -> None: + """启动 BOT""" + logger.info("正在启动 BOT 中...") + + def error_callback(exc: TelegramError) -> None: + """错误信息回调""" + self.telegram.create_task(self.telegram.process_error(error=exc, update=None)) + + await self.telegram.initialize() + logger.info("[blue]Telegram[/] 初始化成功", extra={"markup": True}) + + if application_config.webserver.enable: # 如果使用 web app + server_config = self.web_server.config + server_config.setup_event_loop() + if not server_config.loaded: + server_config.load() + self.web_server.lifespan = server_config.lifespan_class(server_config) + try: + await self.web_server.startup() + except OSError as e: + if e.errno == 10048: + logger.error("Web Server 端口被占用:%s", e) + logger.error("Web Server 启动失败,正在退出") + raise SystemExit from None + + if self.web_server.should_exit: + logger.error("Web Server 启动失败,正在退出") + raise SystemExit from None + logger.success("Web Server 启动成功") + + self._web_server_task = asyncio.create_task(self.web_server.main_loop()) + + for _ in range(5): # 连接至 telegram 服务器 + try: + await self.telegram.updater.start_polling( + error_callback=error_callback, allowed_updates=Update.ALL_TYPES + ) + break + except TimedOut: + logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True}) + continue + except NetworkError as e: + logger.exception() + if isinstance(e, SSLZeroReturnError): + logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.") + else: + logger.error("网络连接出现问题, 请检查您的网络状况.") + raise SystemExit from e + + await self.initialize() + logger.success("BOT 初始化成功") + logger.debug("BOT 开始启动") + + await self._on_startup() + await self.telegram.start() + self._running = True + logger.success("BOT 启动成功") + + def stop_signal_handler(self, signum: int): + """终止信号处理""" + signals = {k: v for v, k in signal.__dict__.items() if v.startswith("SIG") and not v.startswith("SIG_")} + logger.debug("接收到了终止信号 %s 正在退出...", signals[signum]) + if self._web_server_task: + self._web_server_task.cancel() + + async def idle(self) -> None: + """在接收到中止信号之前,堵塞loop""" + + task = None + + def stop_handler(signum: int, _: "FrameType") -> None: + self.stop_signal_handler(signum) + task.cancel() + + for s in (SIGINT, SIGTERM, SIGABRT): + signal_func(s, stop_handler) + + while True: + task = asyncio.create_task(asyncio.sleep(600)) + + try: + await task + except asyncio.CancelledError: + break + + async def stop(self) -> None: + """关闭""" + logger.info("BOT 正在关闭") + self._running = False + + await self._on_shutdown() + + if self.telegram.updater.running: + await self.telegram.updater.stop() + + await self.shutdown() + + if self.telegram.running: + await self.telegram.stop() + + await self.telegram.shutdown() + if self.web_server is not None: + try: + await self.web_server.shutdown() + logger.info("Web Server 已经关闭") + except AttributeError: + pass + + logger.success("BOT 关闭成功") + + def launch(self) -> None: + """启动""" + loop = asyncio.get_event_loop() + try: + loop.run_until_complete(self.start()) + loop.run_until_complete(self.idle()) + except (SystemExit, KeyboardInterrupt) as exc: + logger.debug("接收到了终止信号,BOT 即将关闭", exc_info=exc) # 接收到了终止信号 + except NetworkError as e: + if isinstance(e, SSLZeroReturnError): + logger.critical("代理服务出现异常, 请检查您的代理服务是否配置成功.") + else: + logger.critical("网络连接出现问题, 请检查您的网络状况.") + except Exception as e: + logger.critical("遇到了未知错误: %s", {type(e)}, exc_info=e) + finally: + loop.run_until_complete(self.stop()) + + if application_config.reload: + raise SystemExit from None + + def on_startup(self, func: Callable[P, R]) -> Callable[P, R]: + """注册一个在 BOT 启动时执行的函数""" + + if func not in self._startup_funcs: + self._startup_funcs.append(func) + + # noinspection PyTypeChecker + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper + + def on_shutdown(self, func: Callable[P, R]) -> Callable[P, R]: + """注册一个在 BOT 停止时执行的函数""" + + if func not in self._shutdown_funcs: + self._shutdown_funcs.append(func) + + # noinspection PyTypeChecker + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper diff --git a/core/base/mysql.py b/core/base/mysql.py deleted file mode 100644 index 02f28bd0..00000000 --- a/core/base/mysql.py +++ /dev/null @@ -1,31 +0,0 @@ -from sqlalchemy.ext.asyncio import create_async_engine -from sqlalchemy.orm import sessionmaker -from sqlmodel.ext.asyncio.session import AsyncSession -from typing_extensions import Self - -from core.config import BotConfig -from core.service import Service - - -class MySQL(Service): - @classmethod - def from_config(cls, config: BotConfig) -> Self: - return cls(**config.mysql.dict()) - - def __init__(self, host: str, port: int, username: str, password: str, database: str): - self.database = database - self.password = password - self.user = username - self.port = port - self.host = host - self.url = f"mysql+asyncmy://{self.user}:{self.password}@{self.host}:{self.port}/{self.database}" - self.engine = create_async_engine(self.url) - self.Session = sessionmaker(bind=self.engine, class_=AsyncSession) - - async def get_session(self): - """获取会话""" - async with self.Session() as session: - yield session - - async def stop(self): - self.Session.close_all() diff --git a/core/base/webserver.py b/core/base/webserver.py deleted file mode 100644 index 8f57cbe8..00000000 --- a/core/base/webserver.py +++ /dev/null @@ -1,64 +0,0 @@ -import asyncio - -import uvicorn -from fastapi import FastAPI - -from core.config import ( - BotConfig, - config as botConfig, -) -from core.service import Service - -__all__ = ["webapp", "WebServer"] - -webapp = FastAPI(debug=botConfig.debug) - - -@webapp.get("/") -def index(): - return {"Hello": "Paimon"} - - -class WebServer(Service): - debug: bool - - host: str - port: int - - server: uvicorn.Server - - _server_task: asyncio.Task - - @classmethod - def from_config(cls, config: BotConfig) -> Service: - return cls(debug=config.debug, host=config.webserver.host, port=config.webserver.port) - - def __init__(self, debug: bool, host: str, port: int): - self.debug = debug - self.host = host - self.port = port - - self.server = uvicorn.Server( - uvicorn.Config(app=webapp, port=port, use_colors=False, host=host, log_config=None) - ) - - async def start(self): - """启动 service""" - - # 暂时只在开发环境启动 webserver 用于开发调试 - if not self.debug: - return - - # 防止 uvicorn server 拦截 signals - self.server.install_signal_handlers = lambda: None - self._server_task = asyncio.create_task(self.server.serve()) - - async def stop(self): - """关闭 service""" - if not self.debug: - return - - self.server.should_exit = True - - # 等待 task 结束 - await self._server_task diff --git a/core/base_service.py b/core/base_service.py new file mode 100644 index 00000000..c61a6e8b --- /dev/null +++ b/core/base_service.py @@ -0,0 +1,60 @@ +from abc import ABC +from itertools import chain +from typing import ClassVar, Iterable, Type, TypeVar + +from typing_extensions import Self + +from utils.helpers import isabstract + +__all__ = ("BaseService", "BaseServiceType", "DependenceType", "ComponentType", "get_all_services") + + +class _BaseService: + """服务基类""" + + _is_component: ClassVar[bool] = False + _is_dependence: ClassVar[bool] = False + + def __init_subclass__(cls, load: bool = True, **kwargs): + cls.is_dependence = cls._is_dependence + cls.is_component = cls._is_component + cls.load = load + + async def __aenter__(self) -> Self: + await self.initialize() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.shutdown() + + async def initialize(self) -> None: + """Initialize resources used by this service""" + + async def shutdown(self) -> None: + """Stop & clear resources used by this service""" + + +class _Dependence(_BaseService, ABC): + _is_dependence: ClassVar[bool] = True + + +class _Component(_BaseService, ABC): + _is_component: ClassVar[bool] = True + + +class BaseService(_BaseService, ABC): + Dependence: Type[_BaseService] = _Dependence + Component: Type[_BaseService] = _Component + + +BaseServiceType = TypeVar("BaseServiceType", bound=_BaseService) +DependenceType = TypeVar("DependenceType", bound=_Dependence) +ComponentType = TypeVar("ComponentType", bound=_Component) + + +# noinspection PyProtectedMember +def get_all_services() -> Iterable[Type[_BaseService]]: + return filter( + lambda x: x.__name__[0] != "_" and x.load and not isabstract(x), + chain(BaseService.__subclasses__(), _Dependence.__subclasses__(), _Component.__subclasses__()), + ) diff --git a/core/basemodel.py b/core/basemodel.py new file mode 100644 index 00000000..c65f58a6 --- /dev/null +++ b/core/basemodel.py @@ -0,0 +1,29 @@ +import enum + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +from pydantic import BaseSettings + +__all__ = ("RegionEnum", "Settings") + + +class RegionEnum(int, enum.Enum): + """账号数据所在服务器""" + + NULL = 0 + HYPERION = 1 # 米忽悠国服 hyperion + HOYOLAB = 2 # 米忽悠国际服 hoyolab + + +class Settings(BaseSettings): + def __new__(cls, *args, **kwargs): + cls.update_forward_refs() + return super(Settings, cls).__new__(cls) # pylint: disable=E1120 + + class Config(BaseSettings.Config): + case_sensitive = False + json_loads = jsonlib.loads + json_dumps = jsonlib.dumps diff --git a/core/baseplugin.py b/core/baseplugin.py deleted file mode 100644 index 33478560..00000000 --- a/core/baseplugin.py +++ /dev/null @@ -1,69 +0,0 @@ -from telegram import Update, ReplyKeyboardRemove -from telegram.error import BadRequest, Forbidden -from telegram.ext import CallbackContext, ConversationHandler - -from core.plugin import handler, conversation -from utils.bot import get_chat -from utils.log import logger - - -async def clean_message(context: CallbackContext): - job = context.job - message_id = job.data - chat_info = f"chat_id[{job.chat_id}]" - try: - chat = await get_chat(job.chat_id) - full_name = chat.full_name - if full_name: - chat_info = f"{full_name}[{chat.id}]" - else: - chat_info = f"{chat.title}[{chat.id}]" - except (BadRequest, Forbidden) as exc: - logger.warning("获取 chat info 失败 %s", exc.message) - except Exception as exc: - logger.warning("获取 chat info 消息失败 %s", str(exc)) - logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) - try: - # noinspection PyTypeChecker - await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) - except BadRequest as exc: - if "not found" in exc.message: - logger.warning("删除消息 %s message_id[%s] 失败 消息不存在", chat_info, message_id) - elif "Message can't be deleted" in exc.message: - logger.warning("删除消息 %s message_id[%s] 失败 消息无法删除 可能是没有授权", chat_info, message_id) - else: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - except Forbidden as exc: - if "bot was kicked" in exc.message: - logger.warning("删除消息 %s message_id[%s] 失败 已经被踢出群", chat_info, message_id) - else: - logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) - - -def add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int): - context.job_queue.run_once( - callback=clean_message, - when=delete_seconds, - data=message_id, - name=f"{chat_id}|{message_id}|clean_message", - chat_id=chat_id, - job_kwargs={"replace_existing": True, "id": f"{chat_id}|{message_id}|clean_message"}, - ) - - -class _BasePlugin: - @staticmethod - def _add_delete_message_job(context: CallbackContext, chat_id: int, message_id: int, delete_seconds: int = 60): - return add_delete_message_job(context, chat_id, message_id, delete_seconds) - - -class _Conversation(_BasePlugin): - @conversation.fallback - @handler.command(command="cancel", block=True) - async def cancel(self, update: Update, _: CallbackContext) -> int: - await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - - -class BasePlugin(_BasePlugin): - Conversation = _Conversation diff --git a/core/bot.py b/core/bot.py deleted file mode 100644 index 3321f6bc..00000000 --- a/core/bot.py +++ /dev/null @@ -1,345 +0,0 @@ -import asyncio -import inspect -import os -from asyncio import CancelledError -from importlib import import_module -from multiprocessing import RLock as Lock -from pathlib import Path -from typing import Any, Callable, ClassVar, Dict, Iterator, List, NoReturn, Optional, TYPE_CHECKING, Type, TypeVar - -import genshin -import pytz -from async_timeout import timeout -from telegram import Update -from telegram import __version__ as tg_version -from telegram.error import NetworkError, TimedOut -from telegram.ext import ( - AIORateLimiter, - Application as TgApplication, - CallbackContext, - Defaults, - JobQueue, - MessageHandler, - filters, - TypeHandler, -) -from telegram.ext.filters import StatusUpdate - -from core.config import BotConfig, config # pylint: disable=W0611 -from core.error import ServiceNotFoundError - -# noinspection PyProtectedMember -from core.plugin import Plugin, _Plugin -from core.service import Service -from metadata.scripts.metadatas import make_github_fast -from utils.const import PLUGIN_DIR, PROJECT_ROOT -from utils.log import logger - - -__all__ = ["bot"] - -T = TypeVar("T") -PluginType = TypeVar("PluginType", bound=_Plugin) - -try: - from telegram import __version_info__ as tg_version_info -except ImportError: - tg_version_info = (0, 0, 0, 0, 0) # type: ignore[assignment] - -if tg_version_info < (20, 0, 0, "alpha", 6): - logger.warning( - "Bot与当前PTB版本 [cyan bold]%s[/] [red bold]不兼容[/],请更新到最新版本后使用 [blue bold]poetry install[/] 重新安装依赖", - tg_version, - extra={"markup": True}, - ) - - -class Bot: - _lock: ClassVar[Lock] = Lock() - _instance: ClassVar[Optional["Bot"]] = None - - def __new__(cls, *args, **kwargs) -> "Bot": - """实现单例""" - with cls._lock: # 使线程、进程安全 - if cls._instance is None: - cls._instance = object.__new__(cls) - return cls._instance - - app: Optional[TgApplication] = None - _config: BotConfig = config - _services: Dict[Type[T], T] = {} - _running: bool = False - - def _inject(self, signature: inspect.Signature, target: Callable[..., T]) -> T: - kwargs = {} - for name, parameter in signature.parameters.items(): - if name != "self" and parameter.annotation != inspect.Parameter.empty: - if value := self._services.get(parameter.annotation): - kwargs[name] = value - return target(**kwargs) - - def init_inject(self, target: Callable[..., T]) -> T: - """用于实例化Plugin的方法。用于给插件传入一些必要组件,如 MySQL、Redis等""" - if isinstance(target, type): - signature = inspect.signature(target.__init__) - else: - signature = inspect.signature(target) - return self._inject(signature, target) - - async def async_inject(self, target: Callable[..., T]) -> T: - return await self._inject(inspect.signature(target), target) - - def _gen_pkg(self, root: Path) -> Iterator[str]: - """生成可以用于 import_module 导入的字符串""" - for path in root.iterdir(): - if not path.name.startswith("_"): - if path.is_dir(): - yield from self._gen_pkg(path) - elif path.suffix == ".py": - yield str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".") - - async def install_plugins(self): - """安装插件""" - for pkg in self._gen_pkg(PLUGIN_DIR): - try: - import_module(pkg) # 导入插件 - except Exception as e: # pylint: disable=W0703 - logger.exception( - '在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True} - ) - continue # 如有错误则继续 - callback_dict: Dict[int, List[Callable]] = {} - for plugin_cls in {*Plugin.__subclasses__(), *Plugin.Conversation.__subclasses__()}: - path = f"{plugin_cls.__module__}.{plugin_cls.__name__}" - try: - plugin: PluginType = self.init_inject(plugin_cls) - if hasattr(plugin, "__async_init__"): - await self.async_inject(plugin.__async_init__) - handlers = plugin.handlers - for index, handler in enumerate(handlers): - if isinstance(handler, TypeHandler): # 对 TypeHandler 进行特殊处理,优先级必须设置 -1,否则无用 - handlers.pop(index) - self.app.add_handler(handler, group=-1) - self.app.add_handlers(handlers) - if handlers: - logger.debug('插件 "%s" 添加了 %s 个 handler ', path, len(handlers)) - - # noinspection PyProtectedMember - for priority, callback in plugin._new_chat_members_handler_funcs(): # pylint: disable=W0212 - if not callback_dict.get(priority): - callback_dict[priority] = [] - callback_dict[priority].append(callback) - - error_handlers = plugin.error_handlers - for callback, block in error_handlers.items(): - self.app.add_error_handler(callback, block) - if error_handlers: - logger.debug('插件 "%s" 添加了 %s 个 error handler ', path, len(error_handlers)) - - if jobs := plugin.jobs: - logger.debug('插件 "%s" 添加了 %s 个 jobs ', path, len(jobs)) - logger.success('插件 "%s" 载入成功', path) - except Exception as e: # pylint: disable=W0703 - logger.exception( - '在安装插件 "%s" 的过程中遇到了错误 [red bold]%s[/]', path, type(e).__name__, exc_info=e, extra={"markup": True} - ) - if callback_dict: - num = sum(len(callback_dict[i]) for i in callback_dict) - - async def _new_chat_member_callback(update: "Update", context: "CallbackContext"): - nonlocal callback - for _, value in callback_dict.items(): - for callback in value: - await callback(update, context) - - self.app.add_handler( - MessageHandler(callback=_new_chat_member_callback, filters=StatusUpdate.NEW_CHAT_MEMBERS, block=False) - ) - logger.success( - "成功添加了 %s 个针对 [blue]%s[/] 的 [blue]MessageHandler[/]", - num, - StatusUpdate.NEW_CHAT_MEMBERS, - extra={"markup": True}, - ) - # special handler - from plugins.system.start import StartPlugin - - self.app.add_handler( - MessageHandler( - callback=StartPlugin.unknown_command, filters=filters.COMMAND & filters.ChatType.PRIVATE, block=False - ) - ) - - async def _start_base_services(self): - for pkg in self._gen_pkg(PROJECT_ROOT / "core/base"): - try: - import_module(pkg) - except Exception as e: # pylint: disable=W0703 - logger.exception( - '在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True} - ) - raise SystemExit from e - for base_service_cls in Service.__subclasses__(): - try: - if hasattr(base_service_cls, "from_config"): - instance = base_service_cls.from_config(self._config) - else: - instance = self.init_inject(base_service_cls) - await instance.start() - logger.success('服务 "%s" 初始化成功', base_service_cls.__name__) - self._services.update({base_service_cls: instance}) - except Exception as e: - logger.error('服务 "%s" 初始化失败', base_service_cls.__name__) - raise SystemExit from e - - async def start_services(self): - """启动服务""" - await self._start_base_services() - for path in (PROJECT_ROOT / "core").iterdir(): - if not path.name.startswith("_") and path.is_dir() and path.name != "base": - pkg = str(path.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".") - try: - import_module(pkg) - except Exception as e: # pylint: disable=W0703 - logger.exception( - '在导入文件 "%s" 的过程中遇到了错误 [red bold]%s[/]', - pkg, - type(e).__name__, - exc_info=e, - extra={"markup": True}, - ) - continue - - async def stop_services(self): - """关闭服务""" - if not self._services: - return - logger.info("正在关闭服务") - for _, service in filter(lambda x: not isinstance(x[1], TgApplication), self._services.items()): - async with timeout(5): - try: - if hasattr(service, "stop"): - if inspect.iscoroutinefunction(service.stop): - await service.stop() - else: - service.stop() - logger.success('服务 "%s" 关闭成功', service.__class__.__name__) - except CancelledError: - logger.warning('服务 "%s" 关闭超时', service.__class__.__name__) - except Exception as e: # pylint: disable=W0703 - logger.exception('服务 "%s" 关闭失败', service.__class__.__name__, exc_info=e) - - async def _post_init(self, context: CallbackContext) -> NoReturn: - logger.info("开始初始化 genshin.py 相关资源") - try: - # 替换为 fastgit 镜像源 - for i in dir(genshin.utility.extdb): - if "_URL" in i: - setattr( - genshin.utility.extdb, - i, - make_github_fast(getattr(genshin.utility.extdb, i)), - ) - await genshin.utility.update_characters_enka() - except Exception as exc: # pylint: disable=W0703 - logger.error("初始化 genshin.py 相关资源失败") - logger.exception(exc) - else: - logger.success("初始化 genshin.py 相关资源成功") - self._services.update({CallbackContext: context}) - logger.info("开始初始化服务") - await self.start_services() - logger.info("开始安装插件") - await self.install_plugins() - logger.info("BOT 初始化成功") - - def launch(self) -> NoReturn: - """启动机器人""" - self._running = True - logger.info("正在初始化BOT") - self.app = ( - TgApplication.builder() - .read_timeout(self.config.read_timeout) - .write_timeout(self.config.write_timeout) - .connect_timeout(self.config.connect_timeout) - .pool_timeout(self.config.pool_timeout) - .get_updates_read_timeout(self.config.update_read_timeout) - .get_updates_write_timeout(self.config.update_write_timeout) - .get_updates_connect_timeout(self.config.update_connect_timeout) - .get_updates_pool_timeout(self.config.update_pool_timeout) - .rate_limiter(AIORateLimiter()) - .defaults(Defaults(tzinfo=pytz.timezone("Asia/Shanghai"))) - .token(self._config.bot_token) - .post_init(self._post_init) - .build() - ) - try: - for _ in range(5): - try: - self.app.run_polling( - close_loop=False, - timeout=self.config.timeout, - allowed_updates=Update.ALL_TYPES, - ) - break - except TimedOut: - logger.warning("连接至 [blue]telegram[/] 服务器失败,正在重试", extra={"markup": True}) - continue - except NetworkError as e: - if "SSLZeroReturnError" in str(e): - logger.error("代理服务出现异常, 请检查您的代理服务是否配置成功.") - else: - logger.error("网络连接出现问题, 请检查您的网络状况.") - break - except (SystemExit, KeyboardInterrupt): - pass - except Exception as e: # pylint: disable=W0703 - logger.exception("BOT 执行过程中出现错误", exc_info=e) - finally: - loop = asyncio.get_event_loop() - loop.run_until_complete(self.stop_services()) - loop.close() - logger.info("BOT 已经关闭") - self._running = False - - def find_service(self, target: Type[T]) -> T: - """查找服务。若没找到则抛出 ServiceNotFoundError""" - if (result := self._services.get(target)) is None: - raise ServiceNotFoundError(target) - return result - - def add_service(self, service: T) -> NoReturn: - """添加服务。若已经有同类型的服务,则会抛出异常""" - if type(service) in self._services: - raise ValueError(f'Service "{type(service)}" is already existed.') - self.update_service(service) - - def update_service(self, service: T): - """更新服务。若服务不存在,则添加;若存在,则更新""" - self._services.update({type(service): service}) - - def contain_service(self, service: Any) -> bool: - """判断服务是否存在""" - if isinstance(service, type): - return service in self._services - else: - return service in self._services.values() - - @property - def job_queue(self) -> JobQueue: - return self.app.job_queue - - @property - def services(self) -> Dict[Type[T], T]: - return self._services - - @property - def config(self) -> BotConfig: - return self._config - - @property - def is_running(self) -> bool: - return self._running - - -bot = Bot() diff --git a/core/builtins/__init__.py b/core/builtins/__init__.py new file mode 100644 index 00000000..4f296665 --- /dev/null +++ b/core/builtins/__init__.py @@ -0,0 +1 @@ +"""bot builtins""" diff --git a/core/builtins/contexts.py b/core/builtins/contexts.py new file mode 100644 index 00000000..832c9788 --- /dev/null +++ b/core/builtins/contexts.py @@ -0,0 +1,38 @@ +"""上下文管理""" +from contextlib import contextmanager +from contextvars import ContextVar +from typing import TYPE_CHECKING + +if TYPE_CHECKING: + from telegram.ext import CallbackContext + from telegram import Update + +__all__ = [ + "CallbackContextCV", + "UpdateCV", + "handler_contexts", + "job_contexts", +] + +CallbackContextCV: ContextVar["CallbackContext"] = ContextVar("TelegramContextCallback") +UpdateCV: ContextVar["Update"] = ContextVar("TelegramUpdate") + + +@contextmanager +def handler_contexts(update: "Update", context: "CallbackContext") -> None: + context_token = CallbackContextCV.set(context) + update_token = UpdateCV.set(update) + try: + yield + finally: + CallbackContextCV.reset(context_token) + UpdateCV.reset(update_token) + + +@contextmanager +def job_contexts(context: "CallbackContext") -> None: + token = CallbackContextCV.set(context) + try: + yield + finally: + CallbackContextCV.reset(token) diff --git a/core/builtins/dispatcher.py b/core/builtins/dispatcher.py new file mode 100644 index 00000000..a51cfcd7 --- /dev/null +++ b/core/builtins/dispatcher.py @@ -0,0 +1,309 @@ +"""参数分发器""" +import asyncio +import inspect +from abc import ABC, abstractmethod +from asyncio import AbstractEventLoop +from functools import cached_property, lru_cache, partial, wraps +from inspect import Parameter, Signature +from itertools import chain +from types import GenericAlias, MethodType +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Sequence, + Type, + Union, +) + +from arkowrapper import ArkoWrapper +from fastapi import FastAPI +from telegram import Bot as TelegramBot, Chat, Message, Update, User +from telegram.ext import Application as TelegramApplication, CallbackContext, Job +from typing_extensions import ParamSpec +from uvicorn import Server + +from core.application import Application +from utils.const import WRAPPER_ASSIGNMENTS +from utils.typedefs import R, T + +__all__ = ( + "catch", + "AbstractDispatcher", + "BaseDispatcher", + "HandlerDispatcher", + "JobDispatcher", + "dispatched", +) + +P = ParamSpec("P") + +TargetType = Union[Type, str, Callable[[Any], bool]] + +_CATCH_TARGET_ATTR = "_catch_targets" + + +def catch(*targets: Union[str, Type]) -> Callable[[Callable[P, R]], Callable[P, R]]: + def decorate(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, _CATCH_TARGET_ATTR, targets) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapper + + return decorate + + +@lru_cache(64) +def get_signature(func: Union[type, Callable]) -> Signature: + if isinstance(func, type): + return inspect.signature(func.__init__) + return inspect.signature(func) + + +class AbstractDispatcher(ABC): + """参数分发器""" + + IGNORED_ATTRS = [] + + _args: List[Any] = [] + _kwargs: Dict[Union[str, Type], Any] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") + return self._application + + def __init__(self, *args, **kwargs) -> None: + self._args = list(args) + self._kwargs = dict(kwargs) + + for _, value in kwargs.items(): + type_arg = type(value) + if type_arg != str: + self._kwargs[type_arg] = value + + for arg in args: + type_arg = type(arg) + if type_arg != str: + self._kwargs[type_arg] = arg + + @cached_property + def catch_funcs(self) -> List[MethodType]: + # noinspection PyTypeChecker + return list( + ArkoWrapper(dir(self)) + .filter(lambda x: not x.startswith("_")) + .filter( + lambda x: x not in self.IGNORED_ATTRS + ["dispatch", "catch_funcs", "catch_func_map", "dispatch_funcs"] + ) + .map(lambda x: getattr(self, x)) + .filter(lambda x: isinstance(x, MethodType)) + .filter(lambda x: hasattr(x, "_catch_targets")) + ) + + @cached_property + def catch_func_map(self) -> Dict[Union[str, Type[T]], Callable[..., T]]: + result = {} + for catch_func in self.catch_funcs: + catch_targets = getattr(catch_func, _CATCH_TARGET_ATTR) + for catch_target in catch_targets: + result[catch_target] = catch_func + return result + + @cached_property + def dispatch_funcs(self) -> List[MethodType]: + return list( + ArkoWrapper(dir(self)) + .filter(lambda x: x.startswith("dispatch_by_")) + .map(lambda x: getattr(self, x)) + .filter(lambda x: isinstance(x, MethodType)) + ) + + @abstractmethod + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + """默认的 dispatch 方法""" + + @abstractmethod + def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: + """使用 catch_func 获取并分配参数""" + + def dispatch(self, func: Callable[P, R]) -> Callable[..., R]: + """将参数分配给函数,从而合成一个无需参数即可执行的函数""" + params = {} + signature = get_signature(func) + parameters: Dict[str, Parameter] = dict(signature.parameters) + + for name, parameter in list(parameters.items()): + parameter: Parameter + if any( + [ + name == "self" and isinstance(func, (type, MethodType)), + parameter.kind in [Parameter.VAR_KEYWORD, Parameter.VAR_POSITIONAL], + ] + ): + del parameters[name] + continue + + for dispatch_func in self.dispatch_funcs: + parameters[name] = dispatch_func(parameter) + + for name, parameter in parameters.items(): + if parameter.default != Parameter.empty: + params[name] = parameter.default + else: + params[name] = None + + return partial(func, **params) + + @catch(Application) + def catch_application(self) -> Application: + return self.application + + +class BaseDispatcher(AbstractDispatcher): + """默认参数分发器""" + + _instances: Sequence[Any] + + def _get_kwargs(self) -> Dict[Type[T], T]: + result = self._get_default_kwargs() + result[AbstractDispatcher] = self + result.update(self._kwargs) + return result + + def _get_default_kwargs(self) -> Dict[Type[T], T]: + application = self.application + _default_kwargs = { + FastAPI: application.web_app, + Server: application.web_server, + TelegramApplication: application.telegram, + TelegramBot: application.telegram.bot, + } + if not application.running: + for obj in chain( + application.managers.dependency, + application.managers.components, + application.managers.services, + application.managers.plugins, + ): + _default_kwargs[type(obj)] = obj + return {k: v for k, v in _default_kwargs.items() if v is not None} + + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + annotation = parameter.annotation + # noinspection PyTypeChecker + if isinstance(annotation, type) and (value := self._get_kwargs().get(annotation, None)) is not None: + parameter._default = value # pylint: disable=W0212 + return parameter + + def dispatch_by_catch_funcs(self, parameter: Parameter) -> Parameter: + annotation = parameter.annotation + if annotation != Any and isinstance(annotation, GenericAlias): + return parameter + + catch_func = self.catch_func_map.get(annotation) or self.catch_func_map.get(parameter.name) + if catch_func is not None: + # noinspection PyUnresolvedReferences,PyProtectedMember + parameter._default = catch_func() # pylint: disable=W0212 + return parameter + + @catch(AbstractEventLoop) + def catch_loop(self) -> AbstractEventLoop: + return asyncio.get_event_loop() + + +class HandlerDispatcher(BaseDispatcher): + """Handler 参数分发器""" + + def __init__(self, update: Optional[Update] = None, context: Optional[CallbackContext] = None, **kwargs) -> None: + super().__init__(update=update, context=context, **kwargs) + self._update = update + self._context = context + + def dispatch( + self, func: Callable[P, R], *, update: Optional[Update] = None, context: Optional[CallbackContext] = None + ) -> Callable[..., R]: + self._update = update or self._update + self._context = context or self._context + if self._update is None: + from core.builtins.contexts import UpdateCV + + self._update = UpdateCV.get() + if self._context is None: + from core.builtins.contexts import CallbackContextCV + + self._context = CallbackContextCV.get() + return super().dispatch(func) + + def dispatch_by_default(self, parameter: Parameter) -> Parameter: + """HandlerDispatcher 默认不使用 dispatch_by_default""" + return parameter + + @catch(Update) + def catch_update(self) -> Update: + return self._update + + @catch(CallbackContext) + def catch_context(self) -> CallbackContext: + return self._context + + @catch(Message) + def catch_message(self) -> Message: + return self._update.effective_message + + @catch(User) + def catch_user(self) -> User: + return self._update.effective_user + + @catch(Chat) + def catch_chat(self) -> Chat: + return self._update.effective_chat + + +class JobDispatcher(BaseDispatcher): + """Job 参数分发器""" + + def __init__(self, context: Optional[CallbackContext] = None, **kwargs) -> None: + super().__init__(context=context, **kwargs) + self._context = context + + def dispatch(self, func: Callable[P, R], *, context: Optional[CallbackContext] = None) -> Callable[..., R]: + self._context = context or self._context + if self._context is None: + from core.builtins.contexts import CallbackContextCV + + self._context = CallbackContextCV.get() + return super().dispatch(func) + + @catch("data") + def catch_data(self) -> Any: + return self._context.job.data + + @catch(Job) + def catch_job(self) -> Job: + return self._context.job + + @catch(CallbackContext) + def catch_context(self) -> CallbackContext: + return self._context + + +def dispatched(dispatcher: Type[AbstractDispatcher] = BaseDispatcher): + def decorate(func: Callable[P, R]) -> Callable[P, R]: + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapper(*args: P.args, **kwargs: P.kwargs) -> R: + return dispatcher().dispatch(func)(*args, **kwargs) + + return wrapper + + return decorate diff --git a/core/builtins/executor.py b/core/builtins/executor.py new file mode 100644 index 00000000..7fd1a547 --- /dev/null +++ b/core/builtins/executor.py @@ -0,0 +1,131 @@ +"""执行器""" +import inspect +from functools import cached_property +from multiprocessing import RLock as Lock +from typing import Callable, ClassVar, Dict, Generic, Optional, TYPE_CHECKING, Type, TypeVar + +from telegram import Update +from telegram.ext import CallbackContext +from typing_extensions import ParamSpec, Self + +from core.builtins.contexts import handler_contexts, job_contexts + +if TYPE_CHECKING: + from core.application import Application + from core.builtins.dispatcher import AbstractDispatcher, HandlerDispatcher + from multiprocessing.synchronize import RLock as LockType + +__all__ = ("BaseExecutor", "Executor", "HandlerExecutor", "JobExecutor") + +T = TypeVar("T") +R = TypeVar("R") +P = ParamSpec("P") + + +class BaseExecutor: + """执行器 + Args: + name(str): 该执行器的名称。执行器的名称是唯一的。 + + 只支持执行只拥有 POSITIONAL_OR_KEYWORD 和 KEYWORD_ONLY 两种参数类型的函数 + """ + + _lock: ClassVar["LockType"] = Lock() + _instances: ClassVar[Dict[str, Self]] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") + return self._application + + def __new__(cls: Type[T], name: str, *args, **kwargs) -> T: + with cls._lock: + if (instance := cls._instances.get(name)) is None: + instance = object.__new__(cls) + instance.__init__(name, *args, **kwargs) + cls._instances.update({name: instance}) + return instance + + @cached_property + def name(self) -> str: + """当前执行器的名称""" + return self._name + + def __init__(self, name: str, dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: + self._name = name + self._dispatcher = dispatcher + + +class Executor(BaseExecutor, Generic[P, R]): + async def __call__( + self, + target: Callable[P, R], + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + **kwargs, + ) -> R: + dispatcher = self._dispatcher or dispatcher + dispatcher_instance = dispatcher(**kwargs) + dispatcher_instance.set_application(application=self.application) + dispatched_func = dispatcher_instance.dispatch(target) # 分发参数,组成新函数 + + # 执行 + if inspect.iscoroutinefunction(target): + result = await dispatched_func() + else: + result = dispatched_func() + + return result + + +class HandlerExecutor(BaseExecutor, Generic[P, R]): + """Handler专用执行器""" + + _callback: Callable[P, R] + _dispatcher: "HandlerDispatcher" + + def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["HandlerDispatcher"]] = None) -> None: + if dispatcher is None: + from core.builtins.dispatcher import HandlerDispatcher + + dispatcher = HandlerDispatcher + super().__init__("handler", dispatcher) + self._callback = func + self._dispatcher = dispatcher() + + def set_application(self, application: "Application") -> None: + self._application = application + if self._dispatcher is not None: + self._dispatcher.set_application(application) + + async def __call__(self, update: Update, context: CallbackContext) -> R: + with handler_contexts(update, context): + dispatched_func = self._dispatcher.dispatch(self._callback, update=update, context=context) + return await dispatched_func() + + +class JobExecutor(BaseExecutor): + """Job 专用执行器""" + + def __init__(self, func: Callable[P, R], dispatcher: Optional[Type["AbstractDispatcher"]] = None) -> None: + if dispatcher is None: + from core.builtins.dispatcher import JobDispatcher + + dispatcher = JobDispatcher + super().__init__("job", dispatcher) + self._callback = func + self._dispatcher = dispatcher() + + def set_application(self, application: "Application") -> None: + self._application = application + if self._dispatcher is not None: + self._dispatcher.set_application(application) + + async def __call__(self, context: CallbackContext) -> R: + with job_contexts(context): + dispatched_func = self._dispatcher.dispatch(self._callback, context=context) + return await dispatched_func() diff --git a/core/builtins/reloader.py b/core/builtins/reloader.py new file mode 100644 index 00000000..6b09f075 --- /dev/null +++ b/core/builtins/reloader.py @@ -0,0 +1,185 @@ +import inspect +import multiprocessing +import os +import signal +import threading +from pathlib import Path +from typing import Callable, Iterator, List, Optional, TYPE_CHECKING + +from watchfiles import watch + +from utils.const import HANDLED_SIGNALS, PROJECT_ROOT +from utils.log import logger +from utils.typedefs import StrOrPath + +if TYPE_CHECKING: + from multiprocessing.process import BaseProcess + +__all__ = ("Reloader",) + +multiprocessing.allow_connection_pickling() +spawn = multiprocessing.get_context("spawn") + + +class FileFilter: + """监控文件过滤""" + + def __init__(self, includes: List[str], excludes: List[str]) -> None: + default_includes = ["*.py"] + self.includes = [default for default in default_includes if default not in excludes] + self.includes.extend(includes) + self.includes = list(set(self.includes)) + + default_excludes = [".*", ".py[cod]", ".sw.*", "~*", __file__] + self.excludes = [default for default in default_excludes if default not in includes] + self.exclude_dirs = [] + for e in excludes: + p = Path(e) + try: + is_dir = p.is_dir() + except OSError: + is_dir = False + + if is_dir: + self.exclude_dirs.append(p) + else: + self.excludes.append(e) + self.excludes = list(set(self.excludes)) + + def __call__(self, path: Path) -> bool: + for include_pattern in self.includes: + if path.match(include_pattern): + for exclude_dir in self.exclude_dirs: + if exclude_dir in path.parents: + return False + + for exclude_pattern in self.excludes: + if path.match(exclude_pattern): + return False + + return True + return False + + +class Reloader: + _target: Callable[..., None] + _process: "BaseProcess" + + @property + def process(self) -> "BaseProcess": + return self._process + + @property + def target(self) -> Callable[..., None]: + return self._target + + def __init__( + self, + target: Callable[..., None], + *, + reload_delay: float = 0.25, + reload_dirs: List[StrOrPath] = None, + reload_includes: List[str] = None, + reload_excludes: List[str] = None, + ): + if inspect.iscoroutinefunction(target): + raise ValueError("不支持异步函数") + self._target = target + + self.reload_delay = reload_delay + + _reload_dirs = [] + for reload_dir in reload_dirs or []: + _reload_dirs.append(PROJECT_ROOT.joinpath(Path(reload_dir))) + + self.reload_dirs = [] + for reload_dir in _reload_dirs: + append = True + for parent in reload_dir.parents: + if parent in _reload_dirs: + append = False + break + if append: + self.reload_dirs.append(reload_dir) + + if not self.reload_dirs: + logger.warning("需要检测的目标文件夹列表为空", extra={"tag": "Reloader"}) + + self._should_exit = threading.Event() + + frame = inspect.currentframe().f_back + + self.watch_filter = FileFilter(reload_includes or [], (reload_excludes or []) + [frame.f_globals["__file__"]]) + self.watcher = watch( + *self.reload_dirs, + watch_filter=None, + stop_event=self._should_exit, + yield_on_timeout=True, + ) + + def get_changes(self) -> Optional[List[Path]]: + if not self._process.is_alive(): + logger.info("目标进程已经关闭", extra={"tag": "Reloader"}) + self._should_exit.set() + try: + changes = next(self.watcher) + except StopIteration: + return None + if changes: + unique_paths = {Path(c[1]) for c in changes} + return [p for p in unique_paths if self.watch_filter(p)] + return None + + def __iter__(self) -> Iterator[Optional[List[Path]]]: + return self + + def __next__(self) -> Optional[List[Path]]: + return self.get_changes() + + def run(self) -> None: + self.startup() + for changes in self: + if changes: + logger.warning( + "检测到文件 %s 发生改变, 正在重载...", + [str(c.relative_to(PROJECT_ROOT)).replace(os.sep, "/") for c in changes], + extra={"tag": "Reloader"}, + ) + self.restart() + + self.shutdown() + + def signal_handler(self, *_) -> None: + """当接收到结束信号量时""" + self._process.join(3) + if self._process.is_alive(): + self._process.terminate() + self._process.join() + self._should_exit.set() + + def startup(self) -> None: + """启动进程""" + logger.info("目标进程正在启动", extra={"tag": "Reloader"}) + + for sig in HANDLED_SIGNALS: + signal.signal(sig, self.signal_handler) + + self._process = spawn.Process(target=self._target) + self._process.start() + logger.success("目标进程启动成功", extra={"tag": "Reloader"}) + + def restart(self) -> None: + """重启进程""" + self._process.terminate() + self._process.join(10) + + self._process = spawn.Process(target=self._target) + self._process.start() + logger.info("目标进程已经重载", extra={"tag": "Reloader"}) + + def shutdown(self) -> None: + """关闭进程""" + self._process.terminate() + self._process.join(10) + + logger.info("重载器已经关闭", extra={"tag": "Reloader"}) diff --git a/core/config.py b/core/config.py index f6e0cbe7..39872bdb 100644 --- a/core/config.py +++ b/core/config.py @@ -1,19 +1,15 @@ from enum import Enum from pathlib import Path -from typing import ( - List, - Optional, - Union, -) +from typing import List, Optional, Union import dotenv -from pydantic import AnyUrl, BaseModel, Field +from pydantic import AnyUrl, Field +from core.basemodel import Settings from utils.const import PROJECT_ROOT -from utils.models.base import Settings from utils.typedefs import NaturalNumber -__all__ = ["BotConfig", "config", "JoinGroups"] +__all__ = ("ApplicationConfig", "config", "JoinGroups") dotenv.load_dotenv() @@ -25,22 +21,12 @@ class JoinGroups(str, Enum): ALLOW_ALL = "ALLOW_ALL" -class ConfigChannel(BaseModel): - name: str - chat_id: int - - -class ConfigUser(BaseModel): - username: Optional[str] - user_id: int - - class MySqlConfig(Settings): host: str = "127.0.0.1" port: int = 3306 - username: str = None - password: str = None - database: str = None + username: Optional[str] = None + password: Optional[str] = None + database: Optional[str] = None class Config(Settings.Config): env_prefix = "db_" @@ -58,7 +44,7 @@ class Config(Settings.Config): class LoggerConfig(Settings): name: str = "TGPaimon" - width: int = 180 + width: Optional[int] = None time_format: str = "[%Y-%m-%d %X]" traceback_max_frames: int = 20 path: Path = PROJECT_ROOT / "logs" @@ -78,6 +64,9 @@ class MTProtoConfig(Settings): class WebServerConfig(Settings): + enable: bool = False + """是否启用WebServer""" + url: AnyUrl = "http://localhost:8080" host: str = "localhost" port: int = 8080 @@ -97,31 +86,49 @@ class Config(Settings.Config): env_prefix = "error_" -class NoticeConfig(Settings): - user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!" +class ReloadConfig(Settings): + delay: float = 0.25 + dirs: List[str] = [] + include: List[str] = [] + exclude: List[str] = [] class Config(Settings.Config): - env_prefix = "notice_" + env_prefix = "reload_" -class PluginConfig(Settings): - download_file_max_size: int = 5 +class NoticeConfig(Settings): + user_mismatch: str = "再乱点我叫西风骑士团、千岩军、天领奉行、三十人团和风纪官了!" class Config(Settings.Config): - env_prefix = "plugin_" + env_prefix = "notice_" -class BotConfig(Settings): +class ApplicationConfig(Settings): debug: bool = False + """debug 开关""" + retry: int = 5 + """重试次数""" + auto_reload: bool = False + """自动重载""" + + proxy_url: Optional[AnyUrl] = None + """代理链接""" bot_token: str = "" + """BOT的token""" + + owner: Optional[int] = None + + channels: List[int] = [] + """文章推送群组""" - channels: List["ConfigChannel"] = [] - admins: List["ConfigUser"] = [] verify_groups: List[Union[int, str]] = [] + """启用群验证功能的群组""" join_groups: Optional[JoinGroups] = JoinGroups.NO_ALLOW + """是否允许机器人被邀请到其它群组""" timeout: int = 10 + connection_pool_size: int = 256 read_timeout: Optional[float] = None write_timeout: Optional[float] = None connect_timeout: Optional[float] = None @@ -138,6 +145,7 @@ class BotConfig(Settings): pass_challenge_app_key: str = "" pass_challenge_user_web: str = "" + reload: ReloadConfig = ReloadConfig() mysql: MySqlConfig = MySqlConfig() logger: LoggerConfig = LoggerConfig() webserver: WebServerConfig = WebServerConfig() @@ -145,8 +153,7 @@ class BotConfig(Settings): mtproto: MTProtoConfig = MTProtoConfig() error: ErrorConfig = ErrorConfig() notice: NoticeConfig = NoticeConfig() - plugin: PluginConfig = PluginConfig() -BotConfig.update_forward_refs() -config = BotConfig() +ApplicationConfig.update_forward_refs() +config = ApplicationConfig() diff --git a/core/cookies/__init__.py b/core/cookies/__init__.py deleted file mode 100644 index a3c15cd4..00000000 --- a/core/cookies/__init__.py +++ /dev/null @@ -1,21 +0,0 @@ -from core.base.mysql import MySQL -from core.base.redisdb import RedisDB -from core.cookies.cache import PublicCookiesCache -from core.cookies.repositories import CookiesRepository -from core.cookies.services import CookiesService, PublicCookiesService -from core.service import init_service - - -@init_service -def create_cookie_service(mysql: MySQL): - _repository = CookiesRepository(mysql) - _service = CookiesService(_repository) - return _service - - -@init_service -def create_public_cookie_service(mysql: MySQL, redis: RedisDB): - _repository = CookiesRepository(mysql) - _cache = PublicCookiesCache(redis) - _service = PublicCookiesService(_repository, _cache) - return _service diff --git a/core/cookies/models.py b/core/cookies/models.py deleted file mode 100644 index 93010da5..00000000 --- a/core/cookies/models.py +++ /dev/null @@ -1,27 +0,0 @@ -import enum -from typing import Optional, Dict - -from sqlmodel import SQLModel, Field, JSON, Enum, Column - - -class CookiesStatusEnum(int, enum.Enum): - STATUS_SUCCESS = 0 - INVALID_COOKIES = 1 - TOO_MANY_REQUESTS = 2 - - -class Cookies(SQLModel): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - user_id: Optional[int] = Field(foreign_key="user.user_id") - cookies: Optional[Dict[str, str]] = Field(sa_column=Column(JSON)) - status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum))) - - -class HyperionCookie(Cookies, table=True): - __tablename__ = "mihoyo_cookies" - - -class HoyolabCookie(Cookies, table=True): - __tablename__ = "hoyoverse_cookies" diff --git a/core/cookies/repositories.py b/core/cookies/repositories.py deleted file mode 100644 index 9067e26b..00000000 --- a/core/cookies/repositories.py +++ /dev/null @@ -1,109 +0,0 @@ -from typing import cast, List - -from sqlalchemy import select -from sqlalchemy.exc import NoResultFound -from sqlmodel.ext.asyncio.session import AsyncSession - -from core.base.mysql import MySQL -from utils.error import RegionNotFoundError -from utils.models.base import RegionEnum -from .error import CookiesNotFoundError -from .models import HyperionCookie, HoyolabCookie, Cookies - - -class CookiesRepository: - def __init__(self, mysql: MySQL): - self.mysql = mysql - - async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region == RegionEnum.HYPERION: - db_data = HyperionCookie(user_id=user_id, cookies=cookies) - elif region == RegionEnum.HOYOLAB: - db_data = HoyolabCookie(user_id=user_id, cookies=cookies) - else: - raise RegionNotFoundError(region.name) - session.add(db_data) - await session.commit() - - async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region == RegionEnum.HYPERION: - statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) - elif region == RegionEnum.HOYOLAB: - statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id) - else: - raise RegionNotFoundError(region.name) - results = await session.exec(statement) - db_cookies = results.first() - if db_cookies is None: - raise CookiesNotFoundError(user_id) - db_cookies = db_cookies[0] - db_cookies.cookies = cookies - session.add(db_cookies) - await session.commit() - await session.refresh(db_cookies) - - async def update_cookies_ex(self, cookies: Cookies, region: RegionEnum): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region not in [RegionEnum.HYPERION, RegionEnum.HOYOLAB]: - raise RegionNotFoundError(region.name) - session.add(cookies) - await session.commit() - await session.refresh(cookies) - - async def get_cookies(self, user_id, region: RegionEnum) -> Cookies: - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region == RegionEnum.HYPERION: - statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) - results = await session.exec(statement) - db_cookies = results.first() - if db_cookies is None: - raise CookiesNotFoundError(user_id) - return db_cookies[0] - elif region == RegionEnum.HOYOLAB: - statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id) - results = await session.exec(statement) - db_cookies = results.first() - if db_cookies is None: - raise CookiesNotFoundError(user_id) - return db_cookies[0] - else: - raise RegionNotFoundError(region.name) - - async def get_all_cookies(self, region: RegionEnum) -> List[Cookies]: - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region == RegionEnum.HYPERION: - statement = select(HyperionCookie) - results = await session.exec(statement) - db_cookies = results.all() - return [cookies[0] for cookies in db_cookies] - elif region == RegionEnum.HOYOLAB: - statement = select(HoyolabCookie) - results = await session.exec(statement) - db_cookies = results.all() - return [cookies[0] for cookies in db_cookies] - else: - raise RegionNotFoundError(region.name) - - async def del_cookies(self, user_id, region: RegionEnum): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - if region == RegionEnum.HYPERION: - statement = select(HyperionCookie).where(HyperionCookie.user_id == user_id) - elif region == RegionEnum.HOYOLAB: - statement = select(HoyolabCookie).where(HoyolabCookie.user_id == user_id) - else: - raise RegionNotFoundError(region.name) - results = await session.execute(statement) - try: - db_cookies = results.unique().scalar_one() - except NoResultFound as exc: - raise CookiesNotFoundError(user_id) from exc - await session.delete(db_cookies) - await session.commit() diff --git a/core/dependence/__init__.py b/core/dependence/__init__.py new file mode 100644 index 00000000..4ac55c27 --- /dev/null +++ b/core/dependence/__init__.py @@ -0,0 +1 @@ +"""基础服务""" diff --git a/core/base/aiobrowser.py b/core/dependence/aiobrowser.py similarity index 56% rename from core/base/aiobrowser.py rename to core/dependence/aiobrowser.py index 7a467137..50c4037d 100644 --- a/core/base/aiobrowser.py +++ b/core/dependence/aiobrowser.py @@ -1,26 +1,40 @@ -from typing import Optional +from typing import Optional, TYPE_CHECKING -from playwright.async_api import Browser, Playwright, async_playwright, Error +from playwright.async_api import Error, async_playwright -from core.service import Service +from core.base_service import BaseService from utils.log import logger +if TYPE_CHECKING: + from playwright.async_api import Playwright as AsyncPlaywright, Browser + +__all__ = ("AioBrowser",) + + +class AioBrowser(BaseService.Dependence): + @property + def browser(self): + return self._browser -class AioBrowser(Service): def __init__(self, loop=None): - self.browser: Optional[Browser] = None - self._playwright: Optional[Playwright] = None + self._browser: Optional["Browser"] = None + self._playwright: Optional["AsyncPlaywright"] = None self._loop = loop - async def start(self): + async def get_browser(self): + if self._browser is None: + await self.initialize() + return self._browser + + async def initialize(self): if self._playwright is None: logger.info("正在尝试启动 [blue]Playwright[/]", extra={"markup": True}) self._playwright = await async_playwright().start() logger.success("[blue]Playwright[/] 启动成功", extra={"markup": True}) - if self.browser is None: + if self._browser is None: logger.info("正在尝试启动 [blue]Browser[/]", extra={"markup": True}) try: - self.browser = await self._playwright.chromium.launch(timeout=5000) + self._browser = await self._playwright.chromium.launch(timeout=5000) logger.success("[blue]Browser[/] 启动成功", extra={"markup": True}) except Error as err: if "playwright install" in str(err): @@ -33,15 +47,10 @@ async def start(self): raise RuntimeError("检查到 playwright 刚刚安装或者未升级\n请运行以下命令下载新浏览器\nplaywright install chromium") raise err - return self.browser + return self._browser - async def stop(self): - if self.browser is not None: - await self.browser.close() + async def shutdown(self): + if self._browser is not None: + await self._browser.close() if self._playwright is not None: - await self._playwright.stop() - - async def get_browser(self) -> Browser: - if self.browser is None: - await self.start() - return self.browser + self._playwright.stop() diff --git a/core/dependence/aiobrowser.pyi b/core/dependence/aiobrowser.pyi new file mode 100644 index 00000000..b823a61e --- /dev/null +++ b/core/dependence/aiobrowser.pyi @@ -0,0 +1,16 @@ +from asyncio import AbstractEventLoop + +from playwright.async_api import Browser, Playwright as AsyncPlaywright + +from core.base_service import BaseService + +__all__ = ("AioBrowser",) + +class AioBrowser(BaseService.Dependence): + _browser: Browser | None + _playwright: AsyncPlaywright | None + _loop: AbstractEventLoop + + @property + def browser(self) -> Browser | None: ... + async def get_browser(self) -> Browser: ... diff --git a/core/base/assets.py b/core/dependence/assets.py similarity index 97% rename from core/base/assets.py rename to core/dependence/assets.py index 68210b1d..5316b89b 100644 --- a/core/base/assets.py +++ b/core/dependence/assets.py @@ -17,7 +17,7 @@ from httpx import AsyncClient, HTTPError, HTTPStatusError, TransportError, URL from typing_extensions import Self -from core.service import Service +from core.base_service import BaseService from metadata.genshin import AVATAR_DATA, HONEY_DATA, MATERIAL_DATA, NAMECARD_DATA, WEAPON_DATA from metadata.scripts.honey import update_honey_metadata from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github @@ -31,6 +31,8 @@ from httpx import Response from multiprocessing.synchronize import RLock +__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets") + ICON_TYPE = Union[Callable[[bool], Awaitable[Optional[Path]]], Callable[..., Awaitable[Optional[Path]]]] NAME_MAP_TYPE = Dict[str, StrOrURL] @@ -127,7 +129,7 @@ async def _request(self, url: str, interval: float = 0.2) -> "Response": async def _download(self, url: StrOrURL, path: Path, retry: int = 5) -> Path | None: """从 url 下载图标至 path""" - logger.debug(f"正在从 {url} 下载图标至 {path}") + logger.debug("正在从 %s 下载图标至 %s", url, path) headers = {"user-agent": "TGPaimonBot/3.0"} if URL(url).host == "enka.network" else None for time in range(retry): try: @@ -204,8 +206,8 @@ def __getattr__(self, item: str): """魔法""" if item in self.icon_types: return partial(self._get_img, item=item) - else: - object.__getattribute__(self, item) + object.__getattribute__(self, item) + return None @abstractmethod @cached_property @@ -498,7 +500,7 @@ def honey_name_map(self) -> dict[str, str]: } -class AssetsService(Service): +class AssetsService(BaseService.Dependence): """asset服务 用于储存和管理 asset : @@ -527,8 +529,10 @@ def __init__(self): ): setattr(self, attr, globals()[assets_type_name]()) - async def start(self): # pylint: disable=R0201 + async def initialize(self) -> None: # pylint: disable=R0201 + """启动 AssetsService 服务,刷新元数据""" logger.info("正在刷新元数据") + # todo 这3个任务同时异步下载 await update_metadata_from_github(False) await update_metadata_from_ambr(False) await update_honey_metadata(False) diff --git a/core/dependence/assets.pyi b/core/dependence/assets.pyi new file mode 100644 index 00000000..ba1b2503 --- /dev/null +++ b/core/dependence/assets.pyi @@ -0,0 +1,167 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from functools import partial +from pathlib import Path +from typing import Awaitable, Callable, ClassVar, TypeVar + +from enkanetwork import Assets as EnkaAssets +from enkanetwork.model.assets import CharacterAsset as EnkaCharacterAsset +from httpx import AsyncClient +from typing_extensions import Self + +from core.base_service import BaseService +from utils.typedefs import StrOrInt + +__all__ = ("AssetsServiceType", "AssetsService", "AssetsServiceError", "AssetsCouldNotFound", "DEFAULT_EnkaAssets") + +ICON_TYPE = Callable[[bool], Awaitable[Path | None]] | Callable[..., Awaitable[Path | None]] +DEFAULT_EnkaAssets: EnkaAssets +_GET_TYPE = partial | list[str] | int | str | ICON_TYPE | Path | AsyncClient | None | Self | dict[str, str] + +class AssetsServiceError(Exception): ... + +class AssetsCouldNotFound(AssetsServiceError): + message: str + target: str + def __init__(self, message: str, target: str): ... + +class _AssetsService(ABC): + icon_types: ClassVar[list[str]] + id: int + type: str + + icon: ICON_TYPE + """图标""" + + @abstractmethod + @property + def game_name(self) -> str: + """游戏数据中的名称""" + @property + def honey_id(self) -> str: + """当前资源在 Honey Impact 所对应的 ID""" + @property + def path(self) -> Path: + """当前资源的文件夹""" + @property + def client(self) -> AsyncClient: + """当前的 http client""" + def __init__(self, client: AsyncClient | None = None) -> None: ... + def __call__(self, target: int) -> Self: + """用于生成与 target 对应的 assets""" + def __getattr__(self, item: str) -> _GET_TYPE: + """魔法""" + async def get_link(self, item: str) -> str | None: + """获取相应图标链接""" + @abstractmethod + @property + def game_name_map(self) -> dict[str, str]: + """游戏中的图标名""" + @abstractmethod + @property + def honey_name_map(self) -> dict[str, str]: + """来自honey的图标名""" + +class _AvatarAssets(_AssetsService): + enka: EnkaCharacterAsset | None + + side: ICON_TYPE + """侧视图图标""" + + card: ICON_TYPE + """卡片图标""" + + gacha: ICON_TYPE + """抽卡立绘""" + + gacha_card: ICON_TYPE + """抽卡卡片""" + + @property + def honey_name_map(self) -> dict[str, str]: ... + @property + def game_name_map(self) -> dict[str, str]: ... + @property + def enka(self) -> EnkaCharacterAsset | None: ... + def __init__(self, client: AsyncClient | None = None, enka: EnkaAssets | None = None) -> None: ... + def __call__(self, target: StrOrInt) -> Self: ... + def __getitem__(self, item: str) -> _GET_TYPE | EnkaCharacterAsset: ... + def game_name(self) -> str: ... + +class _WeaponAssets(_AssetsService): + awaken: ICON_TYPE + """突破后图标""" + + gacha: ICON_TYPE + """抽卡立绘""" + + @property + def honey_name_map(self) -> dict[str, str]: ... + @property + def game_name_map(self) -> dict[str, str]: ... + def __call__(self, target: StrOrInt) -> Self: ... + def game_name(self) -> str: ... + +class _MaterialAssets(_AssetsService): + @property + def honey_name_map(self) -> dict[str, str]: ... + @property + def game_name_map(self) -> dict[str, str]: ... + def __call__(self, target: StrOrInt) -> Self: ... + def game_name(self) -> str: ... + +class _ArtifactAssets(_AssetsService): + flower: ICON_TYPE + """生之花""" + + plume: ICON_TYPE + """死之羽""" + + sands: ICON_TYPE + """时之沙""" + + goblet: ICON_TYPE + """空之杯""" + + circlet: ICON_TYPE + """理之冠""" + + @property + def honey_name_map(self) -> dict[str, str]: ... + @property + def game_name_map(self) -> dict[str, str]: ... + def game_name(self) -> str: ... + +class _NamecardAssets(_AssetsService): + enka: EnkaCharacterAsset | None + + navbar: ICON_TYPE + """好友名片背景""" + + profile: ICON_TYPE + """个人资料名片背景""" + + @property + def honey_name_map(self) -> dict[str, str]: ... + @property + def game_name_map(self) -> dict[str, str]: ... + def game_name(self) -> str: ... + +class AssetsService(BaseService.Dependence): + avatar: _AvatarAssets + """角色""" + + weapon: _WeaponAssets + """武器""" + + material: _MaterialAssets + """素材""" + + artifact: _ArtifactAssets + """圣遗物""" + + namecard: _NamecardAssets + """名片""" + +AssetsServiceType = TypeVar("AssetsServiceType", bound=_AssetsService) diff --git a/core/base/mtproto.py b/core/dependence/mtproto.py similarity index 73% rename from core/base/mtproto.py rename to core/dependence/mtproto.py index d491aac2..b6f9ddcf 100644 --- a/core/base/mtproto.py +++ b/core/dependence/mtproto.py @@ -4,6 +4,8 @@ import aiofiles +from core.base_service import BaseService +from core.config import config as bot_config from utils.log import logger try: @@ -13,13 +15,12 @@ session.log.debug = lambda *args, **kwargs: None # 关闭日记 PYROGRAM_AVAILABLE = True except ImportError: + Client = None + session = None PYROGRAM_AVAILABLE = False -from core.bot import bot -from core.service import Service - -class MTProto(Service): +class MTProto(BaseService.Dependence): async def get_session(self): async with aiofiles.open(self.session_path, mode="r") as f: return await f.read() @@ -32,9 +33,9 @@ def session_exists(self): return os.path.exists(self.session_path) def __init__(self): - self.name = "PaimonBot" + self.name = "paigram" current_dir = os.getcwd() - self.session_path = os.path.join(current_dir, "paimon.session") + self.session_path = os.path.join(current_dir, "paigram.session") self.client: Optional[Client] = None self.proxy: Optional[dict] = None http_proxy = os.environ.get("HTTP_PROXY") @@ -42,25 +43,25 @@ def __init__(self): http_proxy_url = urlparse(http_proxy) self.proxy = {"scheme": "http", "hostname": http_proxy_url.hostname, "port": http_proxy_url.port} - async def start(self): # pylint: disable=W0221 + async def initialize(self): # pylint: disable=W0221 if not PYROGRAM_AVAILABLE: logger.info("MTProto 服务需要的 pyrogram 模块未导入 本次服务 client 为 None") return - if bot.config.mtproto.api_id is None: + if bot_config.mtproto.api_id is None: logger.info("MTProto 服务需要的 api_id 未配置 本次服务 client 为 None") return - if bot.config.mtproto.api_hash is None: + if bot_config.mtproto.api_hash is None: logger.info("MTProto 服务需要的 api_hash 未配置 本次服务 client 为 None") return self.client = Client( - api_id=bot.config.mtproto.api_id, - api_hash=bot.config.mtproto.api_hash, + api_id=bot_config.mtproto.api_id, + api_hash=bot_config.mtproto.api_hash, name=self.name, - bot_token=bot.config.bot_token, + bot_token=bot_config.bot_token, proxy=self.proxy, ) await self.client.start() - async def stop(self): # pylint: disable=W0221 + async def shutdown(self): # pylint: disable=W0221 if self.client is not None: await self.client.stop(block=False) diff --git a/core/dependence/mtproto.pyi b/core/dependence/mtproto.pyi new file mode 100644 index 00000000..a5f69a14 --- /dev/null +++ b/core/dependence/mtproto.pyi @@ -0,0 +1,31 @@ +from __future__ import annotations +from typing import TypedDict + +from core.base_service import BaseService + +try: + from pyrogram import Client + from pyrogram.session import session + + PYROGRAM_AVAILABLE = True +except ImportError: + Client = None + session = None + PYROGRAM_AVAILABLE = False + +__all__ = ("MTProto",) + +class _ProxyType(TypedDict): + scheme: str + hostname: str | None + port: int | None + +class MTProto(BaseService.Dependence): + name: str + session_path: str + client: Client | None + proxy: _ProxyType | None + + async def get_session(self) -> str: ... + async def set_session(self, b: str) -> None: ... + def session_exists(self) -> bool: ... diff --git a/core/dependence/mysql.py b/core/dependence/mysql.py new file mode 100644 index 00000000..a1b71b30 --- /dev/null +++ b/core/dependence/mysql.py @@ -0,0 +1,50 @@ +import contextlib +from typing import Optional + +from sqlalchemy.engine import URL +from sqlalchemy.ext.asyncio import create_async_engine +from sqlalchemy.orm import sessionmaker +from typing_extensions import Self + +from core.base_service import BaseService +from core.config import ApplicationConfig +from core.sqlmodel.session import AsyncSession + +__all__ = ("MySQL",) + + +class MySQL(BaseService.Dependence): + @classmethod + def from_config(cls, config: ApplicationConfig) -> Self: + return cls(**config.mysql.dict()) + + def __init__( + self, + host: Optional[str] = None, + port: Optional[int] = None, + username: Optional[str] = None, + password: Optional[str] = None, + database: Optional[str] = None, + ): + self.database = database + self.password = password + self.username = username + self.port = port + self.host = host + self.url = URL.create( + "mysql+asyncmy", + username=self.username, + password=self.password, + host=self.host, + port=self.port, + database=self.database, + ) + self.engine = create_async_engine(self.url) + self.Session = sessionmaker(bind=self.engine, class_=AsyncSession) + + @contextlib.asynccontextmanager + async def session(self) -> AsyncSession: + yield self.Session() + + async def shutdown(self): + self.Session.close_all() diff --git a/core/base/redisdb.py b/core/dependence/redisdb.py similarity index 84% rename from core/base/redisdb.py rename to core/dependence/redisdb.py index 1c171b20..d53e02d0 100644 --- a/core/base/redisdb.py +++ b/core/dependence/redisdb.py @@ -1,4 +1,3 @@ -import asyncio from typing import Optional, Union import fakeredis.aioredis @@ -6,14 +5,16 @@ from redis.exceptions import ConnectionError as RedisConnectionError, TimeoutError as RedisTimeoutError from typing_extensions import Self -from core.config import BotConfig -from core.service import Service +from core.base_service import BaseService +from core.config import ApplicationConfig from utils.log import logger +__all__ = ["RedisDB"] -class RedisDB(Service): + +class RedisDB(BaseService.Dependence): @classmethod - def from_config(cls, config: BotConfig) -> Self: + def from_config(cls, config: ApplicationConfig) -> Self: return cls(**config.redis.dict()) def __init__( @@ -24,6 +25,7 @@ def __init__( self.key_prefix = "paimon_bot" async def ping(self): + # noinspection PyUnresolvedReferences if await self.client.ping(): logger.info("连接 [red]Redis[/] 成功", extra={"markup": True}) else: @@ -34,7 +36,7 @@ async def start_fake_redis(self): self.client = fakeredis.aioredis.FakeRedis() await self.ping() - async def start(self): # pylint: disable=W0221 + async def initialize(self): logger.info("正在尝试建立与 [red]Redis[/] 连接", extra={"markup": True}) try: await self.ping() @@ -45,5 +47,5 @@ async def start(self): # pylint: disable=W0221 logger.warning("连接 [red]Redis[/] 失败,使用 [red]fakeredis[/] 模拟", extra={"markup": True}) await self.start_fake_redis() - async def stop(self): # pylint: disable=W0221 + async def shutdown(self): await self.client.close() diff --git a/core/game/__init__.py b/core/game/__init__.py deleted file mode 100644 index cb4e271f..00000000 --- a/core/game/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -from core.base.redisdb import RedisDB -from core.service import init_service -from .cache import GameCache -from .services import GameMaterialService, GameStrategyService - - -@init_service -def create_game_strategy_service(redis: RedisDB): - _cache = GameCache(redis, "game:strategy") - return GameStrategyService(_cache) - - -@init_service -def create_game_material_service(redis: RedisDB): - _cache = GameCache(redis, "game:material") - return GameMaterialService(_cache) diff --git a/plugins/genshin/daily/__init__.py b/core/handler/__init__.py similarity index 100% rename from plugins/genshin/daily/__init__.py rename to core/handler/__init__.py diff --git a/core/handler/adminhandler.py b/core/handler/adminhandler.py new file mode 100644 index 00000000..64648a33 --- /dev/null +++ b/core/handler/adminhandler.py @@ -0,0 +1,59 @@ +import asyncio +from typing import TypeVar, TYPE_CHECKING, Any, Optional + +from telegram import Update +from telegram.ext import ApplicationHandlerStop, BaseHandler + +from core.error import ServiceNotFoundError +from core.services.users.services import UserAdminService +from utils.log import logger + +if TYPE_CHECKING: + from core.application import Application + from telegram.ext import Application as TelegramApplication + +RT = TypeVar("RT") +UT = TypeVar("UT") + +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class AdminHandler(BaseHandler[Update, CCT]): + _lock = asyncio.Lock() + + def __init__(self, handler: BaseHandler[Update, CCT], application: "Application") -> None: + self.handler = handler + self.application = application + self.user_service: Optional["UserAdminService"] = None + super().__init__(self.handler.callback) + + def check_update(self, update: object) -> bool: + if not isinstance(update, Update): + return False + return self.handler.check_update(update) + + async def _user_service(self) -> "UserAdminService": + async with self._lock: + if self.user_service is not None: + return self.user_service + user_service: UserAdminService = self.application.managers.services_map.get(UserAdminService, None) + if user_service is None: + raise ServiceNotFoundError("UserAdminService") + self.user_service = user_service + return self.user_service + + async def handle_update( + self, + update: "UT", + application: "TelegramApplication[Any, CCT, Any, Any, Any, Any]", + check_result: Any, + context: "CCT", + ) -> RT: + user_service = await self._user_service() + user = update.effective_user + if await user_service.is_admin(user.id): + return await self.handler.handle_update(update, application, check_result, context) + message = update.effective_message + logger.warning("用户 %s[%s] 触发尝试调用Admin命令但权限不足", user.full_name, user.id) + await message.reply_text("权限不足") + raise ApplicationHandlerStop diff --git a/core/handler/callbackqueryhandler.py b/core/handler/callbackqueryhandler.py new file mode 100644 index 00000000..f931e4e3 --- /dev/null +++ b/core/handler/callbackqueryhandler.py @@ -0,0 +1,62 @@ +import asyncio +from contextlib import AbstractAsyncContextManager +from types import TracebackType +from typing import TypeVar, TYPE_CHECKING, Any, Optional, Type + +from telegram.ext import CallbackQueryHandler as BaseCallbackQueryHandler, ApplicationHandlerStop + +from utils.log import logger + +if TYPE_CHECKING: + from telegram.ext import Application + +RT = TypeVar("RT") +UT = TypeVar("UT") +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class OverlappingException(Exception): + pass + + +class OverlappingContext(AbstractAsyncContextManager): + _lock = asyncio.Lock() + + def __init__(self, context: "CCT"): + self.context = context + + async def __aenter__(self) -> None: + async with self._lock: + flag = self.context.user_data.get("overlapping", False) + if flag: + raise OverlappingException + self.context.user_data["overlapping"] = True + return None + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: + async with self._lock: + del self.context.user_data["overlapping"] + return None + + +class CallbackQueryHandler(BaseCallbackQueryHandler): + async def handle_update( + self, + update: "UT", + application: "Application[Any, CCT, Any, Any, Any, Any]", + check_result: Any, + context: "CCT", + ) -> RT: + self.collect_additional_context(context, update, application, check_result) + try: + async with OverlappingContext(context): + return await self.callback(update, context) + except OverlappingException as exc: + user = update.effective_user + logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id) + raise ApplicationHandlerStop from exc diff --git a/core/handler/limiterhandler.py b/core/handler/limiterhandler.py new file mode 100644 index 00000000..53bc4c0c --- /dev/null +++ b/core/handler/limiterhandler.py @@ -0,0 +1,71 @@ +import asyncio +from typing import TypeVar, Optional + +from telegram import Update +from telegram.ext import ContextTypes, ApplicationHandlerStop, TypeHandler + +from utils.log import logger + +UT = TypeVar("UT") +CCT = TypeVar("CCT", bound="CallbackContext[Any, Any, Any, Any]") + + +class LimiterHandler(TypeHandler[UT, CCT]): + _lock = asyncio.Lock() + + def __init__( + self, max_rate: float = 5, time_period: float = 10, amount: float = 1, limit_time: Optional[float] = None + ): + """Limiter Handler 通过 + `Leaky bucket algorithm `_ + 实现对用户的输入的精确控制 + + 输入超过一定速率后,代码会抛出 + :class:`telegram.ext.ApplicationHandlerStop` + 异常并在一段时间内防止用户执行任何其他操作 + + :param max_rate: 在抛出异常之前最多允许 频率/秒 的速度 + :param time_period: 在限制速率的时间段的持续时间 + :param amount: 提供的容量 + :param limit_time: 限制时间 如果不提供限制时间为 max_rate / time_period * amount + """ + self.max_rate = max_rate + self.amount = amount + self._rate_per_sec = max_rate / time_period + self.limit_time = limit_time + super().__init__(Update, self.limiter_callback) + + async def limiter_callback(self, update: Update, context: ContextTypes.DEFAULT_TYPE): + if update.inline_query is not None: + return + loop = asyncio.get_running_loop() + async with self._lock: + time = loop.time() + user_data = context.user_data + if user_data is None: + return + user_limit_time = user_data.get("limit_time") + if user_limit_time is not None: + if time >= user_limit_time: + del user_data["limit_time"] + else: + raise ApplicationHandlerStop + last_task_time = user_data.get("last_task_time", 0) + if last_task_time: + task_level = user_data.get("task_level", 0) + elapsed = time - last_task_time + decrement = elapsed * self._rate_per_sec + task_level = max(task_level - decrement, 0) + user_data["task_level"] = task_level + if not task_level + self.amount <= self.max_rate: + if self.limit_time: + limit_time = self.limit_time + else: + limit_time = 1 / self._rate_per_sec * self.amount + user_data["limit_time"] = time + limit_time + user = update.effective_user + logger.warning("用户 %s[%s] 触发洪水限制 已被限制 %s 秒", user.full_name, user.id, limit_time) + raise ApplicationHandlerStop + user_data["last_task_time"] = time + task_level = user_data.get("task_level", 0) + user_data["task_level"] = task_level + self.amount diff --git a/core/manager.py b/core/manager.py new file mode 100644 index 00000000..aad1512a --- /dev/null +++ b/core/manager.py @@ -0,0 +1,286 @@ +import asyncio +from importlib import import_module +from pathlib import Path +from typing import Dict, Generic, List, Optional, TYPE_CHECKING, Type, TypeVar + +from arkowrapper import ArkoWrapper +from async_timeout import timeout +from typing_extensions import ParamSpec + +from core.base_service import BaseServiceType, ComponentType, DependenceType, get_all_services +from core.config import config as bot_config +from utils.const import PLUGIN_DIR, PROJECT_ROOT +from utils.helpers import gen_pkg +from utils.log import logger + +if TYPE_CHECKING: + from core.application import Application + from core.plugin import PluginType + from core.builtins.executor import Executor + +__all__ = ("DependenceManager", "PluginManager", "ComponentManager", "ServiceManager", "Managers") + +R = TypeVar("R") +T = TypeVar("T") +P = ParamSpec("P") + + +def _load_module(path: Path) -> None: + for pkg in gen_pkg(path): + try: + logger.debug('正在导入 "%s"', pkg) + import_module(pkg) + except Exception as e: + logger.exception( + '在导入 "%s" 的过程中遇到了错误 [red bold]%s[/]', pkg, type(e).__name__, exc_info=e, extra={"markup": True} + ) + raise SystemExit from e + + +class Manager(Generic[T]): + """生命周期控制基类""" + + _executor: Optional["Executor"] = None + _lib: Dict[Type[T], T] = {} + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError(f"No application was set for this {self.__class__.__name__}.") + return self._application + + @property + def executor(self) -> "Executor": + """执行器""" + if self._executor is None: + raise RuntimeError(f"No executor was set for this {self.__class__.__name__}.") + return self._executor + + def build_executor(self, name: str): + from core.builtins.executor import Executor + from core.builtins.dispatcher import BaseDispatcher + + self._executor = Executor(name, dispatcher=BaseDispatcher) + self._executor.set_application(self.application) + + +class DependenceManager(Manager[DependenceType]): + """基础依赖管理""" + + _dependency: Dict[Type[DependenceType], DependenceType] = {} + + @property + def dependency(self) -> List[DependenceType]: + return list(self._dependency.values()) + + @property + def dependency_map(self) -> Dict[Type[DependenceType], DependenceType]: + return self._dependency + + async def start_dependency(self) -> None: + _load_module(PROJECT_ROOT / "core/dependence") + + for dependence in filter(lambda x: x.is_dependence, get_all_services()): + dependence: Type[DependenceType] + instance: DependenceType + try: + if hasattr(dependence, "from_config"): # 如果有 from_config 方法 + instance = dependence.from_config(bot_config) # 用 from_config 实例化服务 + else: + instance = await self.executor(dependence) + + await instance.initialize() + logger.success('基础服务 "%s" 启动成功', dependence.__name__) + + self._lib[dependence] = instance + self._dependency[dependence] = instance + + except Exception as e: + logger.exception('基础服务 "%s" 初始化失败,BOT 将自动关闭', dependence.__name__) + raise SystemExit from e + + async def stop_dependency(self) -> None: + async def task(d): + try: + async with timeout(5): + await d.shutdown() + logger.debug('基础服务 "%s" 关闭成功', d.__class__.__name__) + except asyncio.TimeoutError: + logger.warning('基础服务 "%s" 关闭超时', d.__class__.__name__) + except Exception as e: + logger.error('基础服务 "%s" 关闭错误', d.__class__.__name__, exc_info=e) + + tasks = [] + for dependence in self._dependency.values(): + tasks.append(asyncio.create_task(task(dependence))) + + await asyncio.gather(*tasks) + + +class ComponentManager(Manager[ComponentType]): + """组件管理""" + + _components: Dict[Type[ComponentType], ComponentType] = {} + + @property + def components(self) -> List[ComponentType]: + return list(self._components.values()) + + @property + def components_map(self) -> Dict[Type[ComponentType], ComponentType]: + return self._components + + async def init_components(self): + for path in filter( + lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() + ): + _load_module(path) + components = ArkoWrapper(get_all_services()).filter(lambda x: x.is_component) + retry_times = 0 + max_retry_times = len(components) + while components: + start_len = len(components) + for component in list(components): + component: Type[ComponentType] + instance: ComponentType + try: + instance = await self.executor(component) + self._lib[component] = instance + self._components[component] = instance + components = components.remove(component) + except Exception as e: # pylint: disable=W0703 + logger.debug('组件 "%s" 初始化失败: [red]%s[/]', component.__name__, e, extra={"markup": True}) + end_len = len(list(components)) + if start_len == end_len: + retry_times += 1 + + if retry_times == max_retry_times and components: + for component in components: + logger.error('组件 "%s" 初始化失败', component.__name__) + raise SystemExit + + +class ServiceManager(Manager[BaseServiceType]): + """服务控制类""" + + _services: Dict[Type[BaseServiceType], BaseServiceType] = {} + + @property + def services(self) -> List[BaseServiceType]: + return list(self._services.values()) + + @property + def services_map(self) -> Dict[Type[BaseServiceType], BaseServiceType]: + return self._services + + async def _initialize_service(self, target: Type[BaseServiceType]) -> BaseServiceType: + instance: BaseServiceType + try: + if hasattr(target, "from_config"): # 如果有 from_config 方法 + instance = target.from_config(bot_config) # 用 from_config 实例化服务 + else: + instance = await self.executor(target) + + await instance.initialize() + logger.success('服务 "%s" 启动成功', target.__name__) + + return instance + + except Exception as e: # pylint: disable=W0703 + logger.exception('服务 "%s" 初始化失败,BOT 将自动关闭', target.__name__) + raise SystemExit from e + + async def start_services(self) -> None: + for path in filter( + lambda x: x.is_dir() and not x.name.startswith("_"), PROJECT_ROOT.joinpath("core/services").iterdir() + ): + _load_module(path) + + for service in filter(lambda x: not x.is_component and not x.is_dependence, get_all_services()): # 遍历所有服务类 + instance = await self._initialize_service(service) + + self._lib[service] = instance + self._services[service] = instance + + async def stop_services(self) -> None: + """关闭服务""" + if not self._services: + return + + async def task(s): + try: + async with timeout(5): + await s.shutdown() + logger.success('服务 "%s" 关闭成功', s.__class__.__name__) + except asyncio.TimeoutError: + logger.warning('服务 "%s" 关闭超时', s.__class__.__name__) + except Exception as e: + logger.warning('服务 "%s" 关闭失败', s.__class__.__name__, exc_info=e) + + logger.info("正在关闭服务") + tasks = [] + for service in self._services.values(): + tasks.append(asyncio.create_task(task(service))) + + await asyncio.gather(*tasks) + + +class PluginManager(Manager["PluginType"]): + """插件管理""" + + _plugins: Dict[Type["PluginType"], "PluginType"] = {} + + @property + def plugins(self) -> List["PluginType"]: + """所有已经加载的插件""" + return list(self._plugins.values()) + + @property + def plugins_map(self) -> Dict[Type["PluginType"], "PluginType"]: + return self._plugins + + async def install_plugins(self) -> None: + """安装所有插件""" + from core.plugin import get_all_plugins + + for path in filter(lambda x: x.is_dir(), PLUGIN_DIR.iterdir()): + _load_module(path) + + for plugin in get_all_plugins(): + plugin: Type["PluginType"] + + try: + instance: "PluginType" = await self.executor(plugin) + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 初始化失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + continue + + self._plugins[plugin] = instance + + if self._application is not None: + instance.set_application(self._application) + + await asyncio.create_task(self.plugin_install_task(plugin, instance)) + + @staticmethod + async def plugin_install_task(plugin: Type["PluginType"], instance: "PluginType"): + try: + await instance.install() + logger.success('插件 "%s" 安装成功', f"{plugin.__module__}.{plugin.__name__}") + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 安装失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + + async def uninstall_plugins(self) -> None: + for plugin in self._plugins.values(): + try: + await plugin.uninstall() + except Exception as e: # pylint: disable=W0703 + logger.error('插件 "%s" 卸载失败', f"{plugin.__module__}.{plugin.__name__}", exc_info=e) + + +class Managers(DependenceManager, ComponentManager, ServiceManager, PluginManager): + """BOT 除自身外的生命周期管理类""" diff --git a/plugins/genshin/gacha/__init__.py b/core/override/__init__.py similarity index 100% rename from plugins/genshin/gacha/__init__.py rename to core/override/__init__.py diff --git a/core/override/telegram.py b/core/override/telegram.py new file mode 100644 index 00000000..18698b14 --- /dev/null +++ b/core/override/telegram.py @@ -0,0 +1,106 @@ +"""重写 telegram.request.HTTPXRequest 使其使用 ujson 库进行 json 序列化""" +from typing import Any, AsyncIterable, Optional + +import httpcore +from httpx import ( + AsyncByteStream, + AsyncHTTPTransport as DefaultAsyncHTTPTransport, + Limits, + Response as DefaultResponse, + Timeout, +) +from telegram.request import HTTPXRequest as DefaultHTTPXRequest + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +__all__ = ("HTTPXRequest",) + + +class Response(DefaultResponse): + def json(self, **kwargs: Any) -> Any: + # noinspection PyProtectedMember + from httpx._utils import guess_json_utf + + if self.charset_encoding is None and self.content and len(self.content) > 3: + encoding = guess_json_utf(self.content) + if encoding is not None: + return jsonlib.loads(self.content.decode(encoding), **kwargs) + return jsonlib.loads(self.text, **kwargs) + + +# noinspection PyProtectedMember +class AsyncHTTPTransport(DefaultAsyncHTTPTransport): + async def handle_async_request(self, request) -> Response: + from httpx._transports.default import ( + map_httpcore_exceptions, + AsyncResponseStream, + ) + + if not isinstance(request.stream, AsyncByteStream): + raise AssertionError + + req = httpcore.Request( + method=request.method, + url=httpcore.URL( + scheme=request.url.raw_scheme, + host=request.url.raw_host, + port=request.url.port, + target=request.url.raw_path, + ), + headers=request.headers.raw, + content=request.stream, + extensions=request.extensions, + ) + with map_httpcore_exceptions(): + resp = await self._pool.handle_async_request(req) + + if not isinstance(resp.stream, AsyncIterable): + raise AssertionError + + return Response( + status_code=resp.status, + headers=resp.headers, + stream=AsyncResponseStream(resp.stream), + extensions=resp.extensions, + ) + + +class HTTPXRequest(DefaultHTTPXRequest): + def __init__( # pylint: disable=W0231 + self, + connection_pool_size: int = 1, + proxy_url: str = None, + read_timeout: Optional[float] = 5.0, + write_timeout: Optional[float] = 5.0, + connect_timeout: Optional[float] = 5.0, + pool_timeout: Optional[float] = 1.0, + ): + timeout = Timeout( + connect=connect_timeout, + read=read_timeout, + write=write_timeout, + pool=pool_timeout, + ) + limits = Limits( + max_connections=connection_pool_size, + max_keepalive_connections=connection_pool_size, + ) + self._client_kwargs = dict( + timeout=timeout, + proxies=proxy_url, + limits=limits, + transport=AsyncHTTPTransport(limits=limits), + ) + + try: + self._client = self._build_client() + except ImportError as exc: + if "httpx[socks]" not in str(exc): + raise exc + + raise RuntimeError( + "To use Socks5 proxies, PTB must be installed via `pip install python-telegram-bot[socks]`." + ) from exc diff --git a/core/plugin.py b/core/plugin.py deleted file mode 100644 index fc906191..00000000 --- a/core/plugin.py +++ /dev/null @@ -1,483 +0,0 @@ -import copy -import datetime -import re -from importlib import import_module -from re import Pattern -from types import MethodType -from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union - -# noinspection PyProtectedMember -from telegram._utils.defaultvalue import DEFAULT_TRUE - -# noinspection PyProtectedMember -from telegram._utils.types import DVInput, JSONDict -from telegram.ext import BaseHandler, ConversationHandler, Job - -# noinspection PyProtectedMember -from telegram.ext._utils.types import JobCallback -from telegram.ext.filters import BaseFilter -from typing_extensions import ParamSpec - -__all__ = ["Plugin", "handler", "conversation", "job", "error_handler"] - -P = ParamSpec("P") -T = TypeVar("T") -HandlerType = TypeVar("HandlerType", bound=BaseHandler) -TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time] - -_Module = import_module("telegram.ext") - -_NORMAL_HANDLER_ATTR_NAME = "_handler_data" -_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_data" -_JOB_ATTR_NAME = "_job_data" - -_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"] - - -class _Plugin: - def _make_handler(self, datas: Union[List[Dict], Dict]) -> List[HandlerType]: - result = [] - if isinstance(datas, list): - for data in filter(lambda x: x, datas): - func = getattr(self, data.pop("func")) - result.append(data.pop("type")(callback=func, **data.pop("kwargs"))) - else: - func = getattr(self, datas.pop("func")) - result.append(datas.pop("type")(callback=func, **datas.pop("kwargs"))) - return result - - @property - def handlers(self) -> List[HandlerType]: - result = [] - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None)) - ): - for data in datas: - if data["type"] not in ["error", "new_chat_member"]: - result.extend(self._make_handler(data)) - return result - - def _new_chat_members_handler_funcs(self) -> List[Tuple[int, Callable]]: - result = [] - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None)) - ): - for data in datas: - if data and data["type"] == "new_chat_member": - result.append((data["priority"], func)) - - return result - - @property - def error_handlers(self) -> Dict[Callable, bool]: - result = {} - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None)) - ): - for data in datas: - if data and data["type"] == "error": - result.update({func: data["block"]}) - return result - - @property - def jobs(self) -> List[Job]: - from core.bot import bot - - result = [] - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) - and isinstance(func := getattr(self, attr), MethodType) - and (datas := getattr(func, _JOB_ATTR_NAME, None)) - ): - for data in datas: - _job = getattr(bot.job_queue, data.pop("type"))( - callback=func, **data.pop("kwargs"), **{key: data.pop(key) for key in list(data.keys())} - ) - result.append(_job) - return result - - -class _Conversation(_Plugin): - _conversation_kwargs: Dict - - def __init_subclass__(cls, **kwargs): - cls._conversation_kwargs = kwargs - super(_Conversation, cls).__init_subclass__() - return cls - - @property - def handlers(self) -> List[HandlerType]: - result: List[HandlerType] = [] - - entry_points: List[HandlerType] = [] - states: Dict[Any, List[HandlerType]] = {} - fallbacks: List[HandlerType] = [] - for attr in dir(self): - # noinspection PyUnboundLocalVariable - if ( - not (attr.startswith("_") or attr == "handlers") - and isinstance(func := getattr(self, attr), Callable) - and (handler_datas := getattr(func, _NORMAL_HANDLER_ATTR_NAME, None)) - ): - conversation_data = getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None) - if attr == "cancel": - handler_datas = copy.deepcopy(handler_datas) - conversation_data = copy.deepcopy(conversation_data) - _handlers = self._make_handler(handler_datas) - if conversation_data: - if (_type := conversation_data.pop("type")) == "entry": - entry_points.extend(_handlers) - elif _type == "state": - if (key := conversation_data.pop("state")) in states: - states[key].extend(_handlers) - else: - states[key] = _handlers - elif _type == "fallback": - fallbacks.extend(_handlers) - else: - result.extend(_handlers) - if entry_points or states or fallbacks: - result.append( - ConversationHandler( - entry_points, states, fallbacks, **self.__class__._conversation_kwargs # pylint: disable=W0212 - ) - ) - return result - - -class Plugin(_Plugin): - Conversation = _Conversation - - -class _Handler: - def __init__(self, **kwargs): - self.kwargs = kwargs - - @property - def _type(self) -> Type[BaseHandler]: - return getattr(_Module, f"{self.__class__.__name__.strip('_')}Handler") - - def __call__(self, func: Callable[P, T]) -> Callable[P, T]: - data = {"type": self._type, "func": func.__name__, "kwargs": self.kwargs} - if hasattr(func, _NORMAL_HANDLER_ATTR_NAME): - handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME) - handler_datas.append(data) - setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas) - else: - setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data]) - return func - - -class _CallbackQuery(_Handler): - def __init__( - self, - pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None, - block: DVInput[bool] = DEFAULT_TRUE, - ): - super(_CallbackQuery, self).__init__(pattern=pattern, block=block) - - -class _ChatJoinRequest(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE): - super(_ChatJoinRequest, self).__init__(block=block) - - -class _ChatMember(_Handler): - def __init__(self, chat_member_types: int = -1, block: DVInput[bool] = DEFAULT_TRUE): - super().__init__(chat_member_types=chat_member_types, block=block) - - -class _ChosenInlineResult(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE, pattern: Union[str, Pattern] = None): - super().__init__(block=block, pattern=pattern) - - -class _Command(_Handler): - def __init__(self, command: str, filters: "BaseFilter" = None, block: DVInput[bool] = DEFAULT_TRUE): - super(_Command, self).__init__(command=command, filters=filters, block=block) - - -class _InlineQuery(_Handler): - def __init__( - self, pattern: Union[str, Pattern] = None, block: DVInput[bool] = DEFAULT_TRUE, chat_types: List[str] = None - ): - super().__init__(pattern=pattern, block=block, chat_types=chat_types) - - -class _MessageNewChatMembers(_Handler): - def __init__(self, func: Callable[P, T] = None, *, priority: int = 5): - super().__init__() - self.func = func - self.priority = priority - - def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: - self.func = self.func or func - data = {"type": "new_chat_member", "priority": self.priority} - if hasattr(func, _NORMAL_HANDLER_ATTR_NAME): - handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME) - handler_datas.append(data) - setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas) - else: - setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data]) - return func - - -class _Message(_Handler): - def __init__( - self, - filters: "BaseFilter", - block: DVInput[bool] = DEFAULT_TRUE, - ): - super(_Message, self).__init__(filters=filters, block=block) - - new_chat_members = _MessageNewChatMembers - - -class _PollAnswer(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE): - super(_PollAnswer, self).__init__(block=block) - - -class _Poll(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE): - super(_Poll, self).__init__(block=block) - - -class _PreCheckoutQuery(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE): - super(_PreCheckoutQuery, self).__init__(block=block) - - -class _Prefix(_Handler): - def __init__( - self, - prefix: str, - command: str, - filters: BaseFilter = None, - block: DVInput[bool] = DEFAULT_TRUE, - ): - super(_Prefix, self).__init__(prefix=prefix, command=command, filters=filters, block=block) - - -class _ShippingQuery(_Handler): - def __init__(self, block: DVInput[bool] = DEFAULT_TRUE): - super(_ShippingQuery, self).__init__(block=block) - - -class _StringCommand(_Handler): - def __init__(self, command: str): - super(_StringCommand, self).__init__(command=command) - - -class _StringRegex(_Handler): - def __init__(self, pattern: Union[str, Pattern], block: DVInput[bool] = DEFAULT_TRUE): - super(_StringRegex, self).__init__(pattern=pattern, block=block) - - -class _Type(_Handler): - # noinspection PyShadowingBuiltins - def __init__( - self, type: Type, strict: bool = False, block: DVInput[bool] = DEFAULT_TRUE # pylint: disable=redefined-builtin - ): - super(_Type, self).__init__(type=type, strict=strict, block=block) - - -# noinspection PyPep8Naming -class handler(_Handler): - def __init__(self, handler_type: Callable[P, HandlerType], **kwargs: P.kwargs): - self._type_ = handler_type - super(handler, self).__init__(**kwargs) - - @property - def _type(self) -> Type[BaseHandler]: - # noinspection PyTypeChecker - return self._type_ - - callback_query = _CallbackQuery - chat_join_request = _ChatJoinRequest - chat_member = _ChatMember - chosen_inline_result = _ChosenInlineResult - command = _Command - inline_query = _InlineQuery - message = _Message - poll_answer = _PollAnswer - pool = _Poll - pre_checkout_query = _PreCheckoutQuery - prefix = _Prefix - shipping_query = _ShippingQuery - string_command = _StringCommand - string_regex = _StringRegex - type = _Type - - -# noinspection PyPep8Naming -class error_handler: - def __init__(self, func: Callable[P, T] = None, *, block: bool = DEFAULT_TRUE): - self._func = func - self._block = block - - def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: - self._func = func or self._func - data = {"type": "error", "block": self._block} - if hasattr(func, _NORMAL_HANDLER_ATTR_NAME): - handler_datas = getattr(func, _NORMAL_HANDLER_ATTR_NAME) - handler_datas.append(data) - setattr(func, _NORMAL_HANDLER_ATTR_NAME, handler_datas) - else: - setattr(func, _NORMAL_HANDLER_ATTR_NAME, [data]) - return func - - -def _entry(func: Callable[P, T]) -> Callable[P, T]: - setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "entry"}) - return func - - -class _State: - def __init__(self, state: Any): - self.state = state - - def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: - setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "state", "state": self.state}) - return func - - -def _fallback(func: Callable[P, T]) -> Callable[P, T]: - setattr(func, _CONVERSATION_HANDLER_ATTR_NAME, {"type": "fallback"}) - return func - - -# noinspection PyPep8Naming -class conversation(_Handler): - entry_point = _entry - state = _State - fallback = _fallback - - -class _Job: - kwargs: Dict = {} - - def __init__( - self, - name: str = None, - data: object = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - **kwargs, - ): - self.name = name - self.data = data - self.chat_id = chat_id - self.user_id = user_id - self.job_kwargs = {} if job_kwargs is None else job_kwargs - self.kwargs = kwargs - - def __call__(self, func: JobCallback) -> JobCallback: - data = { - "name": self.name, - "data": self.data, - "chat_id": self.chat_id, - "user_id": self.user_id, - "job_kwargs": self.job_kwargs, - "kwargs": self.kwargs, - "type": re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"), - } - if hasattr(func, _JOB_ATTR_NAME): - job_datas = getattr(func, _JOB_ATTR_NAME) - job_datas.append(data) - setattr(func, _JOB_ATTR_NAME, job_datas) - else: - setattr(func, _JOB_ATTR_NAME, [data]) - return func - - -class _RunOnce(_Job): - def __init__( - self, - when: TimeType, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, when=when) - - -class _RunRepeating(_Job): - def __init__( - self, - interval: Union[float, datetime.timedelta], - first: TimeType = None, - last: TimeType = None, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, interval=interval, first=first, last=last) - - -class _RunMonthly(_Job): - def __init__( - self, - when: datetime.time, - day: int, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, when=when, day=day) - - -class _RunDaily(_Job): - def __init__( - self, - time: datetime.time, - days: Tuple[int, ...] = tuple(range(7)), - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs, time=time, days=days) - - -class _RunCustom(_Job): - def __init__( - self, - data: object = None, - name: str = None, - chat_id: int = None, - user_id: int = None, - job_kwargs: JSONDict = None, - ): - super().__init__(name, data, chat_id, user_id, job_kwargs) - - -# noinspection PyPep8Naming -class job: - run_once = _RunOnce - run_repeating = _RunRepeating - run_monthly = _RunMonthly - run_daily = _RunDaily - run_custom = _RunCustom diff --git a/core/plugin/__init__.py b/core/plugin/__init__.py new file mode 100644 index 00000000..fcd11c67 --- /dev/null +++ b/core/plugin/__init__.py @@ -0,0 +1,16 @@ +"""插件""" + +from core.plugin._handler import conversation, error_handler, handler +from core.plugin._job import TimeType, job +from core.plugin._plugin import Plugin, PluginType, get_all_plugins + +__all__ = ( + "Plugin", + "PluginType", + "get_all_plugins", + "handler", + "error_handler", + "conversation", + "job", + "TimeType", +) diff --git a/core/plugin/_funcs.py b/core/plugin/_funcs.py new file mode 100644 index 00000000..474a2f3d --- /dev/null +++ b/core/plugin/_funcs.py @@ -0,0 +1,175 @@ +from pathlib import Path +from typing import List, Optional, Union, TYPE_CHECKING + +import aiofiles +import httpx +from httpx import UnsupportedProtocol +from telegram import Chat, Message, ReplyKeyboardRemove, Update +from telegram.error import BadRequest, Forbidden +from telegram.ext import CallbackContext, ConversationHandler, Job + +from core.dependence.redisdb import RedisDB +from core.plugin._handler import conversation, handler +from utils.const import CACHE_DIR, REQUEST_HEADERS +from utils.error import UrlResourcesNotFoundError +from utils.helpers import sha1 +from utils.log import logger + +if TYPE_CHECKING: + from core.application import Application + +try: + import ujson as json +except ImportError: + import json + +__all__ = ( + "PluginFuncs", + "ConversationFuncs", +) + + +class PluginFuncs: + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError("No application was set for this PluginManager.") + return self._application + + async def _delete_message(self, context: CallbackContext) -> None: + job = context.job + message_id = job.data + chat_info = f"chat_id[{job.chat_id}]" + + try: + chat = await self.get_chat(job.chat_id) + full_name = chat.full_name + if full_name: + chat_info = f"{full_name}[{chat.id}]" + else: + chat_info = f"{chat.title}[{chat.id}]" + except (BadRequest, Forbidden) as exc: + logger.warning("获取 chat info 失败 %s", exc.message) + except Exception as exc: + logger.warning("获取 chat info 消息失败 %s", str(exc)) + + logger.debug("删除消息 %s message_id[%s]", chat_info, message_id) + + try: + # noinspection PyTypeChecker + await context.bot.delete_message(chat_id=job.chat_id, message_id=message_id) + except BadRequest as exc: + logger.warning("删除消息 %s message_id[%s] 失败 %s", chat_info, message_id, exc.message) + + async def get_chat(self, chat_id: Union[str, int], redis_db: Optional[RedisDB] = None, ttl: int = 86400) -> Chat: + application = self.application + redis_db: RedisDB = redis_db or self.application.managers.services_map.get(RedisDB, None) + + if not redis_db: + return await application.bot.get_chat(chat_id) + + qname = f"bot:chat:{chat_id}" + + data = await redis_db.client.get(qname) + if data: + json_data = json.loads(data) + return Chat.de_json(json_data, application.telegram.bot) + + chat_info = await application.telegram.bot.get_chat(chat_id) + await redis_db.client.set(qname, chat_info.to_json()) + await redis_db.client.expire(qname, ttl) + return chat_info + + def add_delete_message_job( + self, + message: Optional[Union[int, Message]] = None, + *, + delay: int = 60, + name: Optional[str] = None, + chat: Optional[Union[int, Chat]] = None, + context: Optional[CallbackContext] = None, + ) -> Job: + """延迟删除消息""" + + if isinstance(message, Message): + if chat is None: + chat = message.chat_id + message = message.id + + chat = chat.id if isinstance(chat, Chat) else chat + + job_queue = self.application.job_queue or context.job_queue + + if job_queue is None: + raise RuntimeError + + return job_queue.run_once( + callback=self._delete_message, + when=delay, + data=message, + name=f"{chat}|{message}|{name}|delete_message" if name else f"{chat}|{message}|delete_message", + chat_id=chat, + job_kwargs={"replace_existing": True, "id": f"{chat}|{message}|delete_message"}, + ) + + @staticmethod + async def download_resource(url: str, return_path: bool = False) -> str: + url_sha1 = sha1(url) # url 的 hash 值 + pathed_url = Path(url) + + file_name = url_sha1 + pathed_url.suffix + file_path = CACHE_DIR.joinpath(file_name) + + if not file_path.exists(): # 若文件不存在,则下载 + async with httpx.AsyncClient(headers=REQUEST_HEADERS) as client: + try: + response = await client.get(url) + except UnsupportedProtocol: + logger.error("链接不支持 url[%s]", url) + return "" + + if response.is_error: + logger.error("请求出现错误 url[%s] status_code[%s]", url, response.status_code) + raise UrlResourcesNotFoundError(url) + + if response.status_code != 200: + logger.error("download_resource 获取url[%s] 错误 status_code[%s]", url, response.status_code) + raise UrlResourcesNotFoundError(url) + + async with aiofiles.open(file_path, mode="wb") as f: + await f.write(response.content) + + logger.debug("download_resource 获取url[%s] 并下载到 file_dir[%s]", url, file_path) + + return file_path if return_path else Path(file_path).as_uri() + + @staticmethod + def get_args(context: Optional[CallbackContext] = None) -> List[str]: + args = context.args + match = context.match + + if args is None: + if match is not None and (command := match.groups()[0]): + temp = [] + command_parts = command.split(" ") + for command_part in command_parts: + if command_part: + temp.append(command_part) + return temp + return [] + if len(args) >= 1: + return args + return [] + + +class ConversationFuncs: + @conversation.fallback + @handler.command(command="cancel", block=True) + async def cancel(self, update: Update, _) -> int: + await update.effective_message.reply_text("退出命令", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END diff --git a/core/plugin/_handler.py b/core/plugin/_handler.py new file mode 100644 index 00000000..223b631c --- /dev/null +++ b/core/plugin/_handler.py @@ -0,0 +1,380 @@ +from dataclasses import dataclass +from enum import Enum +from functools import wraps +from importlib import import_module +from typing import ( + Any, + Callable, + ClassVar, + Dict, + List, + Optional, + Pattern, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) + +from pydantic import BaseModel + +# noinspection PyProtectedMember +from telegram._utils.defaultvalue import DEFAULT_TRUE + +# noinspection PyProtectedMember +from telegram._utils.types import DVInput +from telegram.ext import BaseHandler +from telegram.ext.filters import BaseFilter +from typing_extensions import ParamSpec + +from core.handler.callbackqueryhandler import CallbackQueryHandler +from utils.const import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS + +if TYPE_CHECKING: + from core.builtins.dispatcher import AbstractDispatcher + +__all__ = ( + "handler", + "conversation", + "ConversationDataType", + "ConversationData", + "HandlerData", + "ErrorHandlerData", + "error_handler", +) + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") +UT = TypeVar("UT") + +HandlerType = TypeVar("HandlerType", bound=BaseHandler) +HandlerCls = Type[HandlerType] + +Module = import_module("telegram.ext") + +HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + +ERROR_HANDLER_ATTR_NAME = "_error_handler_data" + +CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" +"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" + +WRAPPER_ASSIGNMENTS = list( + set( + _WRAPPER_ASSIGNMENTS + + [ + HANDLER_DATA_ATTR_NAME, + ERROR_HANDLER_ATTR_NAME, + CONVERSATION_HANDLER_ATTR_NAME, + ] + ) +) + + +@dataclass(init=True) +class HandlerData: + type: Type[HandlerType] + admin: bool + kwargs: Dict[str, Any] + dispatcher: Optional[Type["AbstractDispatcher"]] = None + + +class _Handler: + _type: Type["HandlerType"] + + kwargs: Dict[str, Any] = {} + + def __init_subclass__(cls, **kwargs) -> None: + """用于获取 python-telegram-bot 中对应的 handler class""" + + handler_name = f"{cls.__name__.strip('_')}Handler" + + if handler_name == "CallbackQueryHandler": + cls._type = CallbackQueryHandler + return + + cls._type = getattr(Module, handler_name, None) + + def __init__(self, admin: bool = False, dispatcher: Optional[Type["AbstractDispatcher"]] = None, **kwargs) -> None: + self.dispatcher = dispatcher + self.admin = admin + self.kwargs = kwargs + + def __call__(self, func: Callable[P, R]) -> Callable[P, R]: + """decorator实现,从 func 生成 Handler""" + + handler_datas = getattr(func, HANDLER_DATA_ATTR_NAME, []) + handler_datas.append( + HandlerData(type=self._type, admin=self.admin, kwargs=self.kwargs, dispatcher=self.dispatcher) + ) + setattr(func, HANDLER_DATA_ATTR_NAME, handler_datas) + + return func + + +class _CallbackQuery(_Handler): + def __init__( + self, + pattern: Union[str, Pattern, type, Callable[[object], Optional[bool]]] = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_CallbackQuery, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) + + +class _ChatJoinRequest(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_ChatJoinRequest, self).__init__(block=block, dispatcher=dispatcher) + + +class _ChatMember(_Handler): + def __init__( + self, + chat_member_types: int = -1, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(chat_member_types=chat_member_types, block=block, dispatcher=dispatcher) + + +class _ChosenInlineResult(_Handler): + def __init__( + self, + block: DVInput[bool] = DEFAULT_TRUE, + *, + pattern: Union[str, Pattern] = None, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(block=block, pattern=pattern, dispatcher=dispatcher) + + +class _Command(_Handler): + def __init__( + self, + command: Union[str, List[str]], + filters: "BaseFilter" = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_Command, self).__init__( + command=command, filters=filters, block=block, admin=admin, dispatcher=dispatcher + ) + + +class _InlineQuery(_Handler): + def __init__( + self, + pattern: Union[str, Pattern] = None, + chat_types: List[str] = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_InlineQuery, self).__init__(pattern=pattern, block=block, chat_types=chat_types, dispatcher=dispatcher) + + +class _Message(_Handler): + def __init__( + self, + filters: BaseFilter, + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ) -> None: + super(_Message, self).__init__(filters=filters, block=block, admin=admin, dispatcher=dispatcher) + + +class _PollAnswer(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_PollAnswer, self).__init__(block=block, dispatcher=dispatcher) + + +class _Poll(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_Poll, self).__init__(block=block, dispatcher=dispatcher) + + +class _PreCheckoutQuery(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_PreCheckoutQuery, self).__init__(block=block, dispatcher=dispatcher) + + +class _Prefix(_Handler): + def __init__( + self, + prefix: str, + command: str, + filters: BaseFilter = None, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_Prefix, self).__init__( + prefix=prefix, command=command, filters=filters, block=block, dispatcher=dispatcher + ) + + +class _ShippingQuery(_Handler): + def __init__(self, *, block: DVInput[bool] = DEFAULT_TRUE, dispatcher: Optional[Type["AbstractDispatcher"]] = None): + super(_ShippingQuery, self).__init__(block=block, dispatcher=dispatcher) + + +class _StringCommand(_Handler): + def __init__( + self, + command: str, + *, + admin: bool = False, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_StringCommand, self).__init__(command=command, block=block, admin=admin, dispatcher=dispatcher) + + +class _StringRegex(_Handler): + def __init__( + self, + pattern: Union[str, Pattern], + *, + block: DVInput[bool] = DEFAULT_TRUE, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super(_StringRegex, self).__init__(pattern=pattern, block=block, admin=admin, dispatcher=dispatcher) + + +class _Type(_Handler): + # noinspection PyShadowingBuiltins + def __init__( + self, + type: Type[UT], # pylint: disable=W0622 + strict: bool = False, + *, + block: DVInput[bool] = DEFAULT_TRUE, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): # pylint: disable=redefined-builtin + super(_Type, self).__init__(type=type, strict=strict, block=block, dispatcher=dispatcher) + + +# noinspection PyPep8Naming +class handler(_Handler): + callback_query = _CallbackQuery + chat_join_request = _ChatJoinRequest + chat_member = _ChatMember + chosen_inline_result = _ChosenInlineResult + command = _Command + inline_query = _InlineQuery + message = _Message + poll_answer = _PollAnswer + pool = _Poll + pre_checkout_query = _PreCheckoutQuery + prefix = _Prefix + shipping_query = _ShippingQuery + string_command = _StringCommand + string_regex = _StringRegex + type = _Type + + def __init__( + self, + handler_type: Union[Callable[P, "HandlerType"], Type["HandlerType"]], + *, + admin: bool = False, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + **kwargs: P.kwargs, + ) -> None: + self._type = handler_type + super().__init__(admin=admin, dispatcher=dispatcher, **kwargs) + + +class ConversationDataType(Enum): + """conversation handler 的类型""" + + Entry = "entry" + State = "state" + Fallback = "fallback" + + +class ConversationData(BaseModel): + """用于储存 conversation handler 的数据""" + + type: ConversationDataType + state: Optional[Any] = None + + +class _ConversationType: + _type: ClassVar[ConversationDataType] + + def __init_subclass__(cls, **kwargs) -> None: + cls._type = ConversationDataType(cls.__name__.lstrip("_").lower()) + + +def _entry(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Entry)) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapped + + +class _State(_ConversationType): + def __init__(self, state: Any) -> None: + self.state = state + + def __call__(self, func: Callable[P, T] = None) -> Callable[P, T]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=self._type, state=self.state)) + return func + + +def _fallback(func: Callable[P, R]) -> Callable[P, R]: + setattr(func, CONVERSATION_HANDLER_ATTR_NAME, ConversationData(type=ConversationDataType.Fallback)) + + @wraps(func, assigned=WRAPPER_ASSIGNMENTS) + def wrapped(*args: P.args, **kwargs: P.kwargs) -> R: + return func(*args, **kwargs) + + return wrapped + + +# noinspection PyPep8Naming +class conversation(_Handler): + entry_point = _entry + state = _State + fallback = _fallback + + +@dataclass(init=True) +class ErrorHandlerData: + block: bool + func: Optional[Callable] = None + + +# noinspection PyPep8Naming +class error_handler: + _func: Callable[P, R] + + def __init__( + self, + *, + block: bool = DEFAULT_TRUE, + ): + self._block = block + + def __call__(self, func: Callable[P, T]) -> Callable[P, T]: + self._func = func + wraps(func, assigned=WRAPPER_ASSIGNMENTS)(self) + + handler_datas = getattr(func, ERROR_HANDLER_ATTR_NAME, []) + handler_datas.append(ErrorHandlerData(block=self._block)) + setattr(self._func, ERROR_HANDLER_ATTR_NAME, handler_datas) + + return self._func diff --git a/core/plugin/_job.py b/core/plugin/_job.py new file mode 100644 index 00000000..393ad874 --- /dev/null +++ b/core/plugin/_job.py @@ -0,0 +1,173 @@ +"""插件""" +import datetime +import re +from dataclasses import dataclass, field +from typing import Any, Dict, Optional, TYPE_CHECKING, Tuple, Type, TypeVar, Union + +# noinspection PyProtectedMember +from telegram._utils.types import JSONDict + +# noinspection PyProtectedMember +from telegram.ext._utils.types import JobCallback +from typing_extensions import ParamSpec + +if TYPE_CHECKING: + from core.builtins.dispatcher import AbstractDispatcher + +__all__ = ["TimeType", "job", "JobData"] + +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") + +TimeType = Union[float, datetime.timedelta, datetime.datetime, datetime.time] + +_JOB_ATTR_NAME = "_job_data" + + +@dataclass(init=True) +class JobData: + name: str + data: Any + chat_id: int + user_id: int + type: str + job_kwargs: JSONDict = field(default_factory=dict) + kwargs: JSONDict = field(default_factory=dict) + dispatcher: Optional[Type["AbstractDispatcher"]] = None + + +class _Job: + kwargs: Dict = {} + + def __init__( + self, + name: str = None, + data: object = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + **kwargs, + ): + self.name = name + self.data = data + self.chat_id = chat_id + self.user_id = user_id + self.job_kwargs = {} if job_kwargs is None else job_kwargs + self.kwargs = kwargs + if dispatcher is None: + from core.builtins.dispatcher import JobDispatcher + + dispatcher = JobDispatcher + + self.dispatcher = dispatcher + + def __call__(self, func: JobCallback) -> JobCallback: + data = JobData( + name=self.name, + data=self.data, + chat_id=self.chat_id, + user_id=self.user_id, + job_kwargs=self.job_kwargs, + kwargs=self.kwargs, + type=re.sub(r"([A-Z])", lambda x: "_" + x.group().lower(), self.__class__.__name__).lstrip("_"), + dispatcher=self.dispatcher, + ) + if hasattr(func, _JOB_ATTR_NAME): + job_datas = getattr(func, _JOB_ATTR_NAME) + job_datas.append(data) + setattr(func, _JOB_ATTR_NAME, job_datas) + else: + setattr(func, _JOB_ATTR_NAME, [data]) + return func + + +class _RunOnce(_Job): + def __init__( + self, + when: TimeType, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when) + + +class _RunRepeating(_Job): + def __init__( + self, + interval: Union[float, datetime.timedelta], + first: TimeType = None, + last: TimeType = None, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__( + name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, interval=interval, first=first, last=last + ) + + +class _RunMonthly(_Job): + def __init__( + self, + when: datetime.time, + day: int, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, when=when, day=day) + + +class _RunDaily(_Job): + def __init__( + self, + time: datetime.time, + days: Tuple[int, ...] = tuple(range(7)), + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher, time=time, days=days) + + +class _RunCustom(_Job): + def __init__( + self, + data: object = None, + name: str = None, + chat_id: int = None, + user_id: int = None, + job_kwargs: JSONDict = None, + *, + dispatcher: Optional[Type["AbstractDispatcher"]] = None, + ): + super().__init__(name, data, chat_id, user_id, job_kwargs, dispatcher=dispatcher) + + +# noinspection PyPep8Naming +class job: + run_once = _RunOnce + run_repeating = _RunRepeating + run_monthly = _RunMonthly + run_daily = _RunDaily + run_custom = _RunCustom diff --git a/core/plugin/_plugin.py b/core/plugin/_plugin.py new file mode 100644 index 00000000..d3c293bc --- /dev/null +++ b/core/plugin/_plugin.py @@ -0,0 +1,303 @@ +"""插件""" +import asyncio +from abc import ABC +from dataclasses import asdict +from datetime import timedelta +from functools import partial, wraps +from itertools import chain +from multiprocessing import RLock as Lock +from types import MethodType +from typing import ( + Any, + ClassVar, + Dict, + Iterable, + List, + Optional, + TYPE_CHECKING, + Type, + TypeVar, + Union, +) + +from pydantic import BaseModel +from telegram.ext import BaseHandler, ConversationHandler, Job, TypeHandler +from typing_extensions import ParamSpec + +from core.handler.adminhandler import AdminHandler +from core.plugin._funcs import ConversationFuncs, PluginFuncs +from core.plugin._handler import ConversationDataType +from utils.const import WRAPPER_ASSIGNMENTS +from utils.helpers import isabstract +from utils.log import logger + +if TYPE_CHECKING: + from core.application import Application + from core.plugin._handler import ConversationData, HandlerData, ErrorHandlerData + from core.plugin._job import JobData + from multiprocessing.synchronize import RLock as LockType + +__all__ = ("Plugin", "PluginType", "get_all_plugins") + +wraps = partial(wraps, assigned=WRAPPER_ASSIGNMENTS) +P = ParamSpec("P") +T = TypeVar("T") +R = TypeVar("R") + +HandlerType = TypeVar("HandlerType", bound=BaseHandler) + +_HANDLER_DATA_ATTR_NAME = "_handler_datas" +"""用于储存生成 handler 时所需要的参数(例如 block)的属性名""" + +_CONVERSATION_HANDLER_ATTR_NAME = "_conversation_handler_data" +"""用于储存生成 ConversationHandler 时所需要的参数(例如 block)的属性名""" + +_ERROR_HANDLER_ATTR_NAME = "_error_handler_data" + +_JOB_ATTR_NAME = "_job_data" + +_EXCLUDE_ATTRS = ["handlers", "jobs", "error_handlers"] + + +class _Plugin(PluginFuncs): + """插件""" + + _lock: ClassVar["LockType"] = Lock() + _asyncio_lock: ClassVar["LockType"] = asyncio.Lock() + _installed: bool = False + + _handlers: Optional[List[HandlerType]] = None + _error_handlers: Optional[List["ErrorHandlerData"]] = None + _jobs: Optional[List[Job]] = None + _application: "Optional[Application]" = None + + def set_application(self, application: "Application") -> None: + self._application = application + + @property + def application(self) -> "Application": + if self._application is None: + raise RuntimeError("No application was set for this Plugin.") + return self._application + + @property + def handlers(self) -> List[HandlerType]: + """该插件的所有 handler""" + with self._lock: + if self._handlers is None: + self._handlers = [] + + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) + ): + for data in datas: + data: "HandlerData" + if data.admin: + self._handlers.append( + AdminHandler( + handler=data.type( + callback=func, + **data.kwargs, + ), + application=self.application, + ) + ) + else: + self._handlers.append( + data.type( + callback=func, + **data.kwargs, + ) + ) + return self._handlers + + @property + def error_handlers(self) -> List["ErrorHandlerData"]: + with self._lock: + if self._error_handlers is None: + self._error_handlers = [] + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _ERROR_HANDLER_ATTR_NAME, [])) + ): + for data in datas: + data: "ErrorHandlerData" + data.func = func + self._error_handlers.append(data) + + return self._error_handlers + + def _install_jobs(self) -> None: + if self._jobs is None: + self._jobs = [] + for attr in dir(self): + # noinspection PyUnboundLocalVariable + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and isinstance(func := getattr(self, attr), MethodType) + and (datas := getattr(func, _JOB_ATTR_NAME, [])) + ): + for data in datas: + data: "JobData" + self._jobs.append( + getattr(self.application.telegram.job_queue, data.type)( + callback=func, + **data.kwargs, + **{ + key: value + for key, value in asdict(data).items() + if key not in ["type", "kwargs", "dispatcher"] + }, + ) + ) + + @property + def jobs(self) -> List[Job]: + with self._lock: + if self._jobs is None: + self._jobs = [] + self._install_jobs() + return self._jobs + + async def initialize(self) -> None: + """初始化插件""" + + async def shutdown(self) -> None: + """销毁插件""" + + async def install(self) -> None: + """安装""" + group = id(self) + if not self._installed: + await self.initialize() + # initialize 必须先执行 如果出现异常不会执行 add_handler 以免出现问题 + async with self._asyncio_lock: + self._install_jobs() + + for h in self.handlers: + if not isinstance(h, TypeHandler): + self.application.telegram.add_handler(h, group) + else: + self.application.telegram.add_handler(h, -1) + + for h in self.error_handlers: + self.application.telegram.add_error_handler(h.func, h.block) + self._installed = True + + async def uninstall(self) -> None: + """卸载""" + group = id(self) + + with self._lock: + if self._installed: + if group in self.application.telegram.handlers: + del self.application.telegram.handlers[id(self)] + + for h in self.handlers: + if isinstance(h, TypeHandler): + self.application.telegram.remove_handler(h, -1) + for h in self.error_handlers: + self.application.telegram.remove_error_handler(h.func) + + for j in self.application.telegram.job_queue.jobs(): + j.schedule_removal() + await self.shutdown() + self._installed = False + + async def reload(self) -> None: + await self.uninstall() + await self.install() + + +class _Conversation(_Plugin, ConversationFuncs, ABC): + """Conversation类""" + + # noinspection SpellCheckingInspection + class Config(BaseModel): + allow_reentry: bool = False + per_chat: bool = True + per_user: bool = True + per_message: bool = False + conversation_timeout: Optional[Union[float, timedelta]] = None + name: Optional[str] = None + map_to_parent: Optional[Dict[object, object]] = None + block: bool = False + + def __init_subclass__(cls, **kwargs): + cls._conversation_kwargs = kwargs + super(_Conversation, cls).__init_subclass__() + return cls + + @property + def handlers(self) -> List[HandlerType]: + with self._lock: + if self._handlers is None: + self._handlers = [] + + entry_points: List[HandlerType] = [] + states: Dict[Any, List[HandlerType]] = {} + fallbacks: List[HandlerType] = [] + for attr in dir(self): + if ( + not (attr.startswith("_") or attr in _EXCLUDE_ATTRS) + and (func := getattr(self, attr, None)) is not None + and (datas := getattr(func, _HANDLER_DATA_ATTR_NAME, [])) + ): + conversation_data: "ConversationData" + + handlers: List[HandlerType] = [] + for data in datas: + handlers.append( + data.type( + callback=func, + **data.kwargs, + ) + ) + + if conversation_data := getattr(func, _CONVERSATION_HANDLER_ATTR_NAME, None): + if (_type := conversation_data.type) is ConversationDataType.Entry: + entry_points.extend(handlers) + elif _type is ConversationDataType.State: + if conversation_data.state in states: + states[conversation_data.state].extend(handlers) + else: + states[conversation_data.state] = handlers + elif _type is ConversationDataType.Fallback: + fallbacks.extend(handlers) + else: + self._handlers.extend(handlers) + else: + self._handlers.extend(handlers) + if entry_points and states and fallbacks: + kwargs = self._conversation_kwargs + kwargs.update(self.Config().dict()) + self._handlers.append(ConversationHandler(entry_points, states, fallbacks, **kwargs)) + else: + temp_dict = {"entry_points": entry_points, "states": states, "fallbacks": fallbacks} + reason = map(lambda x: f"'{x[0]}'", filter(lambda x: not x[1], temp_dict.items())) + logger.warning( + "'%s' 因缺少 '%s' 而生成无法生成 ConversationHandler", self.__class__.__name__, ", ".join(reason) + ) + return self._handlers + + +class Plugin(_Plugin, ABC): + """插件""" + + Conversation = _Conversation + + +PluginType = TypeVar("PluginType", bound=_Plugin) + + +def get_all_plugins() -> Iterable[Type[PluginType]]: + """获取所有 Plugin 的子类""" + return filter( + lambda x: x.__name__[0] != "_" and not isabstract(x), + chain(Plugin.__subclasses__(), _Conversation.__subclasses__()), + ) diff --git a/core/quiz/__init__.py b/core/quiz/__init__.py deleted file mode 100644 index f2eddf4d..00000000 --- a/core/quiz/__init__.py +++ /dev/null @@ -1,14 +0,0 @@ -from core.base.mysql import MySQL -from core.base.redisdb import RedisDB -from core.service import init_service -from .cache import QuizCache -from .repositories import QuizRepository -from .services import QuizService - - -@init_service -def create_quiz_service(mysql: MySQL, redis: RedisDB): - _repository = QuizRepository(mysql) - _cache = QuizCache(redis) - _service = QuizService(_repository, _cache) - return _service diff --git a/core/quiz/base.py b/core/quiz/base.py deleted file mode 100644 index 77bcca75..00000000 --- a/core/quiz/base.py +++ /dev/null @@ -1,19 +0,0 @@ -from typing import List - -from .models import Answer, Question - - -def CreatQuestionFromSQLData(data: tuple) -> List[Question]: - temp_list = [] - for temp_data in data: - (question_id, text) = temp_data - temp_list.append(Question(question_id, text)) - return temp_list - - -def CreatAnswerFromSQLData(data: tuple) -> List[Answer]: - temp_list = [] - for temp_data in data: - (answer_id, question_id, is_correct, text) = temp_data - temp_list.append(Answer(answer_id, question_id, is_correct, text)) - return temp_list diff --git a/core/quiz/models.py b/core/quiz/models.py deleted file mode 100644 index ad41a38a..00000000 --- a/core/quiz/models.py +++ /dev/null @@ -1,98 +0,0 @@ -from typing import List, Optional - -from sqlmodel import SQLModel, Field, Column, Integer, ForeignKey - -from utils.baseobject import BaseObject -from utils.typedefs import JSONDict - - -class AnswerDB(SQLModel, table=True): - __tablename__ = "answer" - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - question_id: Optional[int] = Field( - sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT")) - ) - is_correct: Optional[bool] = Field() - text: Optional[str] = Field() - - -class QuestionDB(SQLModel, table=True): - __tablename__ = "question" - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - text: Optional[str] = Field() - - -class Answer(BaseObject): - def __init__(self, answer_id: int = 0, question_id: int = 0, is_correct: bool = True, text: str = ""): - """Answer类 - - :param answer_id: 答案ID - :param question_id: 与之对应的问题ID - :param is_correct: 该答案是否正确 - :param text: 答案文本 - """ - self.answer_id = answer_id - self.question_id = question_id - self.text = text - self.is_correct = is_correct - - __slots__ = ("answer_id", "question_id", "text", "is_correct") - - def to_database_data(self) -> AnswerDB: - data = AnswerDB() - data.id = self.answer_id - data.question_id = self.question_id - data.text = self.text - data.is_correct = self.is_correct - return data - - @classmethod - def de_database_data(cls, data: Optional[AnswerDB]) -> Optional["Answer"]: - if data is None: - return cls() - return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct) - - -class Question(BaseObject): - def __init__(self, question_id: int = 0, text: str = "", answers: List[Answer] = None): - """Question类 - - :param question_id: 问题ID - :param text: 问题文本 - :param answers: 答案列表 - """ - self.question_id = question_id - self.text = text - self.answers = [] if answers is None else answers - - def to_database_data(self) -> QuestionDB: - data = QuestionDB() - data.text = self.text - data.id = self.question_id - return data - - @classmethod - def de_database_data(cls, data: Optional[QuestionDB]) -> Optional["Question"]: - if data is None: - return cls() - return cls(question_id=data.id, text=data.text) - - def to_dict(self) -> JSONDict: - data = super().to_dict() - if self.answers: - data["answers"] = [e.to_dict() for e in self.answers] - return data - - @classmethod - def de_json(cls, data: Optional[JSONDict]) -> Optional["Question"]: - data = cls._parse_data(data) - if not data: - return None - data["answers"] = Answer.de_list(data.get("answers")) - return cls(**data) - - __slots__ = ("question_id", "text", "answers") diff --git a/core/search/__init__.py b/core/search/__init__.py deleted file mode 100644 index 1d8d1b26..00000000 --- a/core/search/__init__.py +++ /dev/null @@ -1,10 +0,0 @@ -from core.service import init_service -from .services import SearchServices as _SearchServices - -__all__ = [] - - -@init_service -def create_search_service(): - _service = _SearchServices() - return _service diff --git a/core/service.py b/core/service.py deleted file mode 100644 index a68b5ee4..00000000 --- a/core/service.py +++ /dev/null @@ -1,31 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Callable - -from utils.log import logger - -__all__ = ["Service", "init_service"] - - -class Service(ABC): - @abstractmethod - def __init__(self, *args, **kwargs): - """初始化""" - - async def start(self): - """启动 service""" - - async def stop(self): - """关闭 service""" - - -def init_service(func: Callable): - from core.bot import bot - - if bot.is_running: - try: - service = bot.init_inject(func) - logger.success(f'服务 "{service.__class__.__name__}" 初始化成功') - bot.add_service(service) - except Exception as e: # pylint: disable=W0703 - logger.exception(f"来自{func.__module__}的服务初始化失败:{e}") - return func diff --git a/core/services/__init__.py b/core/services/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/core/services/cookies/__init__.py b/core/services/cookies/__init__.py new file mode 100644 index 00000000..608582b9 --- /dev/null +++ b/core/services/cookies/__init__.py @@ -0,0 +1,5 @@ +"""CookieService""" + +from core.services.cookies.services import CookiesService, PublicCookiesService + +__all__ = ("CookiesService", "PublicCookiesService") diff --git a/core/cookies/cache.py b/core/services/cookies/cache.py similarity index 90% rename from core/cookies/cache.py rename to core/services/cookies/cache.py index 7b36900d..834fed2d 100644 --- a/core/cookies/cache.py +++ b/core/services/cookies/cache.py @@ -1,12 +1,15 @@ from typing import List, Union -from core.base.redisdb import RedisDB +from core.base_service import BaseService +from core.basemodel import RegionEnum +from core.dependence.redisdb import RedisDB +from core.services.cookies.error import CookiesCachePoolExhausted from utils.error import RegionNotFoundError -from utils.models.base import RegionEnum -from .error import CookiesCachePoolExhausted +__all__ = ("PublicCookiesCache",) -class PublicCookiesCache: + +class PublicCookiesCache(BaseService.Component): """使用优先级(score)进行排序,对使用次数最少的Cookies进行审核""" def __init__(self, redis: RedisDB): @@ -19,10 +22,9 @@ def __init__(self, redis: RedisDB): def get_public_cookies_queue_name(self, region: RegionEnum): if region == RegionEnum.HYPERION: return f"{self.score_qname}:yuanshen" - elif region == RegionEnum.HOYOLAB: + if region == RegionEnum.HOYOLAB: return f"{self.score_qname}:genshin" - else: - raise RegionNotFoundError(region.name) + raise RegionNotFoundError(region.name) async def putback_public_cookies(self, uid: int, region: RegionEnum): """重新添加单个到缓存列表 diff --git a/core/cookies/error.py b/core/services/cookies/error.py similarity index 71% rename from core/cookies/error.py rename to core/services/cookies/error.py index 5873e674..239110af 100644 --- a/core/cookies/error.py +++ b/core/services/cookies/error.py @@ -7,11 +7,6 @@ def __init__(self): super().__init__("Cookies cache pool is exhausted") -class CookiesNotFoundError(CookieServiceError): - def __init__(self, user_id): - super().__init__(f"{user_id} cookies not found") - - class TooManyRequestPublicCookies(CookieServiceError): def __init__(self, user_id): super().__init__(f"{user_id} too many request public cookies") diff --git a/core/services/cookies/models.py b/core/services/cookies/models.py new file mode 100644 index 00000000..0cafa345 --- /dev/null +++ b/core/services/cookies/models.py @@ -0,0 +1,39 @@ +import enum +from typing import Optional, Dict + +from sqlmodel import SQLModel, Field, Boolean, Column, Enum, JSON, Integer, BigInteger, Index + +from core.basemodel import RegionEnum + +__all__ = ("Cookies", "CookiesDataBase", "CookiesStatusEnum") + + +class CookiesStatusEnum(int, enum.Enum): + STATUS_SUCCESS = 0 + INVALID_COOKIES = 1 + TOO_MANY_REQUESTS = 2 + + +class Cookies(SQLModel): + __table_args__ = ( + Index("index_user_account", "user_id", "account_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) + user_id: int = Field( + sa_column=Column(BigInteger()), + ) + account_id: int = Field( + default=None, + sa_column=Column( + BigInteger(), + ), + ) + data: Optional[Dict[str, str]] = Field(sa_column=Column(JSON)) + status: Optional[CookiesStatusEnum] = Field(sa_column=Column(Enum(CookiesStatusEnum))) + region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) + is_share: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class CookiesDataBase(Cookies, table=True): + __tablename__ = "cookies" diff --git a/core/services/cookies/repositories.py b/core/services/cookies/repositories.py new file mode 100644 index 00000000..a11b9d89 --- /dev/null +++ b/core/services/cookies/repositories.py @@ -0,0 +1,55 @@ +from typing import Optional, List + +from sqlmodel import select + +from core.base_service import BaseService +from core.basemodel import RegionEnum +from core.dependence.mysql import MySQL +from core.services.cookies.models import CookiesDataBase as Cookies +from core.sqlmodel.session import AsyncSession + +__all__ = ("CookiesRepository",) + + +class CookiesRepository(BaseService.Component): + def __init__(self, mysql: MySQL): + self.engine = mysql.engine + + async def get( + self, + user_id: int, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + ) -> Optional[Cookies]: + async with AsyncSession(self.engine) as session: + statement = select(Cookies).where(Cookies.user_id == user_id) + if account_id is not None: + statement = statement.where(Cookies.account_id == account_id) + if region is not None: + statement = statement.where(Cookies.region == region) + results = await session.exec(statement) + return results.first() + + async def add(self, cookies: Cookies) -> None: + async with AsyncSession(self.engine) as session: + session.add(cookies) + await session.commit() + + async def update(self, cookies: Cookies) -> Cookies: + async with AsyncSession(self.engine) as session: + session.add(cookies) + await session.commit() + await session.refresh(cookies) + return cookies + + async def delete(self, cookies: Cookies) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(cookies) + await session.commit() + + async def get_all_by_region(self, region: RegionEnum) -> List[Cookies]: + async with AsyncSession(self.engine) as session: + statement = select(Cookies).where(Cookies.region == region) + results = await session.exec(statement) + cookies = results.all() + return cookies diff --git a/core/cookies/services.py b/core/services/cookies/services.py similarity index 63% rename from core/cookies/services.py rename to core/services/cookies/services.py index d5f54c22..6f5a5b0d 100644 --- a/core/cookies/services.py +++ b/core/services/cookies/services.py @@ -1,67 +1,73 @@ -from typing import List +from typing import List, Optional import genshin -from genshin import GenshinException, InvalidCookies, TooManyRequests, types, Game +from genshin import Game, GenshinException, InvalidCookies, TooManyRequests, types +from core.base_service import BaseService +from core.basemodel import RegionEnum +from core.services.cookies.cache import PublicCookiesCache +from core.services.cookies.error import CookieServiceError, TooManyRequestPublicCookies +from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum +from core.services.cookies.repositories import CookiesRepository from utils.log import logger -from utils.models.base import RegionEnum -from .cache import PublicCookiesCache -from .error import TooManyRequestPublicCookies, CookieServiceError -from .models import CookiesStatusEnum -from .repositories import CookiesNotFoundError, CookiesRepository +__all__ = ("CookiesService", "PublicCookiesService") -class CookiesService: + +class CookiesService(BaseService): def __init__(self, cookies_repository: CookiesRepository) -> None: self._repository: CookiesRepository = cookies_repository - async def update_cookies(self, user_id: int, cookies: dict, region: RegionEnum): - await self._repository.update_cookies(user_id, cookies, region) - - async def add_cookies(self, user_id: int, cookies: dict, region: RegionEnum): - await self._repository.add_cookies(user_id, cookies, region) + async def update(self, cookies: Cookies): + await self._repository.update(cookies) - async def get_cookies(self, user_id: int, region: RegionEnum): - return await self._repository.get_cookies(user_id, region) + async def add(self, cookies: Cookies): + await self._repository.add(cookies) - async def del_cookies(self, user_id: int, region: RegionEnum): - return await self._repository.del_cookies(user_id, region) + async def get( + self, + user_id: int, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + ) -> Optional[Cookies]: + return await self._repository.get(user_id, account_id, region) - async def add_or_update_cookies(self, user_id: int, cookies: dict, region: RegionEnum): - try: - await self.get_cookies(user_id, region) - await self.update_cookies(user_id, cookies, region) - except CookiesNotFoundError: - await self.add_cookies(user_id, cookies, region) + async def delete(self, cookies: Cookies) -> None: + return await self._repository.delete(cookies) -class PublicCookiesService: +class PublicCookiesService(BaseService): def __init__(self, cookies_repository: CookiesRepository, public_cookies_cache: PublicCookiesCache): self._cache = public_cookies_cache self._repository: CookiesRepository = cookies_repository self.count: int = 0 self.user_times_limiter = 3 * 3 + async def initialize(self) -> None: + logger.info("正在初始化公共Cookies池") + await self.refresh() + logger.success("刷新公共Cookies池成功") + async def refresh(self): """刷新公共Cookies 定时任务 :return: """ user_list: List[int] = [] - cookies_list = await self._repository.get_all_cookies(RegionEnum.HYPERION) # 从数据库获取2 + cookies_list = await self._repository.get_all_by_region(RegionEnum.HYPERION) # 从数据库获取2 for cookies in cookies_list: if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: user_list.append(cookies.user_id) if len(user_list) > 0: add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HYPERION) - logger.info(f"国服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]") + logger.info("国服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) user_list.clear() - cookies_list = await self._repository.get_all_cookies(RegionEnum.HOYOLAB) + cookies_list = await self._repository.get_all_by_region(RegionEnum.HOYOLAB) for cookies in cookies_list: if cookies.status is None or cookies.status == CookiesStatusEnum.STATUS_SUCCESS: user_list.append(cookies.user_id) if len(user_list) > 0: add, count = await self._cache.add_public_cookies(user_list, RegionEnum.HOYOLAB) - logger.info(f"国际服公共Cookies池已经添加[{add}]个 当前成员数为[{count}]") + logger.info("国际服公共Cookies池已经添加[%s]个 当前成员数为[%s]", add, count) async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): """获取公共Cookies @@ -71,20 +77,19 @@ async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): """ user_times = await self._cache.incr_by_user_times(user_id) if int(user_times) > self.user_times_limiter: - logger.warning(f"用户 [{user_id}] 使用公共Cookie次数已经到达上限") + logger.warning("用户 %s 使用公共Cookie次数已经到达上限", user_id) raise TooManyRequestPublicCookies(user_id) while True: public_id, count = await self._cache.get_public_cookies(region) - try: - cookies = await self._repository.get_cookies(public_id, region) - except CookiesNotFoundError: + cookies = await self._repository.get(public_id, region=region) + if cookies is None: await self._cache.delete_public_cookies(public_id, region) continue if region == RegionEnum.HYPERION: - client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE) + client = genshin.Client(cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.CHINESE) elif region == RegionEnum.HOYOLAB: client = genshin.Client( - cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" + cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" ) else: raise CookieServiceError @@ -101,13 +106,13 @@ async def get_cookies(self, user_id: int, region: RegionEnum = RegionEnum.NULL): logger.warning("Cookies无效 ") logger.exception(exc) cookies.status = CookiesStatusEnum.INVALID_COOKIES - await self._repository.update_cookies_ex(cookies, region) + await self._repository.update(cookies) await self._cache.delete_public_cookies(cookies.user_id, region) continue except TooManyRequests: logger.warning("用户 [%s] 查询次数太多或操作频繁", public_id) cookies.status = CookiesStatusEnum.TOO_MANY_REQUESTS - await self._repository.update_cookies_ex(cookies, region) + await self._repository.update(cookies) await self._cache.delete_public_cookies(cookies.user_id, region) continue except GenshinException as exc: diff --git a/core/services/game/__init__.py b/core/services/game/__init__.py new file mode 100644 index 00000000..78e0e2f1 --- /dev/null +++ b/core/services/game/__init__.py @@ -0,0 +1 @@ +"""GameService""" diff --git a/core/game/cache.py b/core/services/game/cache.py similarity index 60% rename from core/game/cache.py rename to core/services/game/cache.py index 24123966..cfb69821 100644 --- a/core/game/cache.py +++ b/core/services/game/cache.py @@ -1,12 +1,16 @@ from typing import List -from core.base.redisdb import RedisDB +from core.base_service import BaseService +from core.dependence.redisdb import RedisDB + +__all__ = ["GameCache", "GameCacheForStrategy", "GameCacheForMaterial"] class GameCache: - def __init__(self, redis: RedisDB, qname: str, ttl: int = 3600): + qname: str + + def __init__(self, redis: RedisDB, ttl: int = 3600): self.client = redis.client - self.qname = qname self.ttl = ttl async def get_url_list(self, character_name: str): @@ -19,3 +23,11 @@ async def set_url_list(self, character_name: str, str_list: List[str]): await self.client.lpush(qname, *str_list) await self.client.expire(qname, self.ttl) return await self.client.llen(qname) + + +class GameCacheForStrategy(BaseService.Component, GameCache): + qname = "game:strategy" + + +class GameCacheForMaterial(BaseService.Component, GameCache): + qname = "game:material" diff --git a/core/game/services.py b/core/services/game/services.py similarity index 86% rename from core/game/services.py rename to core/services/game/services.py index 04a349c9..170de632 100644 --- a/core/game/services.py +++ b/core/services/game/services.py @@ -1,11 +1,14 @@ from typing import List, Optional +from core.base_service import BaseService +from core.services.game.cache import GameCacheForMaterial, GameCacheForStrategy from modules.apihelper.client.components.hyperion import Hyperion -from .cache import GameCache +__all__ = ("GameMaterialService", "GameStrategyService") -class GameStrategyService: - def __init__(self, cache: GameCache, collections: Optional[List[int]] = None): + +class GameStrategyService(BaseService): + def __init__(self, cache: GameCacheForStrategy, collections: Optional[List[int]] = None): self._cache = cache self._hyperion = Hyperion() if collections is None: @@ -49,8 +52,8 @@ async def get_strategy(self, character_name: str) -> str: return artwork_info.image_urls[0] -class GameMaterialService: - def __init__(self, cache: GameCache, collections: Optional[List[int]] = None): +class GameMaterialService(BaseService): + def __init__(self, cache: GameCacheForMaterial, collections: Optional[List[int]] = None): self._cache = cache self._hyperion = Hyperion() self._collections = [428421, 1164644] if collections is None else collections @@ -91,9 +94,8 @@ async def get_material(self, character_name: str) -> str: await self._cache.set_url_list(character_name, image_url_list) if len(image_url_list) == 0: return "" - elif len(image_url_list) == 1: + if len(image_url_list) == 1: return image_url_list[0] - elif character_name in self._special: + if character_name in self._special: return image_url_list[2] - else: - return image_url_list[1] + return image_url_list[1] diff --git a/core/services/players/__init__.py b/core/services/players/__init__.py new file mode 100644 index 00000000..5cfee96c --- /dev/null +++ b/core/services/players/__init__.py @@ -0,0 +1,3 @@ +from .services import PlayersService + +__all__ = ("PlayersService",) diff --git a/core/services/players/error.py b/core/services/players/error.py new file mode 100644 index 00000000..623bed84 --- /dev/null +++ b/core/services/players/error.py @@ -0,0 +1,2 @@ +class PlayerNotFoundError(Exception): + pass diff --git a/core/services/players/models.py b/core/services/players/models.py new file mode 100644 index 00000000..8962235e --- /dev/null +++ b/core/services/players/models.py @@ -0,0 +1,96 @@ +from datetime import datetime +from typing import Optional + +from pydantic import BaseModel, BaseSettings +from sqlalchemy import TypeDecorator +from sqlmodel import Boolean, Column, Enum, Field, SQLModel, Integer, Index, BigInteger, VARCHAR, func, DateTime + +from core.basemodel import RegionEnum + +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + +__all__ = ("Player", "PlayersDataBase", "PlayerInfo", "PlayerInfoSQLModel") + + +class Player(SQLModel): + __table_args__ = ( + Index("index_user_account_player", "user_id", "account_id", "player_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + account_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + region: RegionEnum = Field(sa_column=Column(Enum(RegionEnum))) + is_chosen: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class PlayersDataBase(Player, table=True): + __tablename__ = "players" + + +class ExtraPlayerInfo(BaseModel): + class Config(BaseSettings.Config): + json_loads = jsonlib.loads + json_dumps = jsonlib.dumps + + waifu_id: Optional[int] = None + + +class ExtraPlayerType(TypeDecorator): # pylint: disable=W0223 + impl = VARCHAR(length=521) + + cache_ok = True + + def process_bind_param(self, value, dialect): + """ + :param value: ExtraPlayerInfo | obj | None + :param dialect: + :return: + """ + if value is not None: + if isinstance(value, ExtraPlayerInfo): + return value.json() + raise TypeError + return value + + def process_result_value(self, value, dialect): + """ + :param value: str | obj | None + :param dialect: + :return: + """ + if value is not None: + return ExtraPlayerInfo.parse_raw(value) + return None + + +class PlayerInfo(SQLModel): + __table_args__ = ( + Index("index_user_account_player", "user_id", "player_id", unique=True), + dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"), + ) + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + player_id: int = Field(primary_key=True, sa_column=Column(BigInteger())) + nickname: Optional[str] = Field() + signature: Optional[str] = Field() + hand_image: Optional[int] = Field() + name_card: Optional[int] = Field() + extra_data: Optional[ExtraPlayerInfo] = Field(sa_column=Column(ExtraPlayerType)) + create_time: Optional[datetime] = Field( + sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 + ) + last_save_time: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 + is_update: Optional[bool] = Field(sa_column=Column(Boolean)) + + +class PlayerInfoSQLModel(PlayerInfo, table=True): + __tablename__ = "players_info" diff --git a/core/services/players/repositories.py b/core/services/players/repositories.py new file mode 100644 index 00000000..91f68743 --- /dev/null +++ b/core/services/players/repositories.py @@ -0,0 +1,109 @@ +from typing import List, Optional + +from sqlmodel import select, delete + +from core.base_service import BaseService +from core.basemodel import RegionEnum +from core.dependence.mysql import MySQL +from core.services.players.models import PlayerInfoSQLModel +from core.services.players.models import PlayersDataBase as Player +from core.sqlmodel.session import AsyncSession + +__all__ = ("PlayersRepository", "PlayerInfoRepository") + + +class PlayersRepository(BaseService.Component): + def __init__(self, mysql: MySQL): + self.engine = mysql.engine + + async def get( + self, + user_id: int, + player_id: Optional[int] = None, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + is_chosen: Optional[bool] = None, + ) -> Optional[Player]: + async with AsyncSession(self.engine) as session: + statement = select(Player).where(Player.user_id == user_id) + if player_id is not None: + statement = statement.where(Player.player_id == player_id) + if account_id is not None: + statement = statement.where(Player.account_id == account_id) + if region is not None: + statement = statement.where(Player.region == region) + if is_chosen is not None: + statement = statement.where(Player.is_chosen == is_chosen) + results = await session.exec(statement) + return results.first() + + async def add(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + + async def delete(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(player) + await session.commit() + + async def update(self, player: Player) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + await session.refresh(player) + + async def get_all_by_user_id(self, user_id: int) -> List[Player]: + async with AsyncSession(self.engine) as session: + statement = select(Player).where(Player.user_id == user_id) + results = await session.exec(statement) + players = results.all() + return players + + +class PlayerInfoRepository(BaseService.Component): + def __init__(self, mysql: MySQL): + self.engine = mysql.engine + + async def get( + self, + user_id: int, + player_id: int, + ) -> Optional[PlayerInfoSQLModel]: + async with AsyncSession(self.engine) as session: + statement = ( + select(PlayerInfoSQLModel) + .where(PlayerInfoSQLModel.player_id == player_id) + .where(PlayerInfoSQLModel.user_id == user_id) + ) + results = await session.exec(statement) + return results.first() + + async def add(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + + async def delete(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + await session.delete(player) + await session.commit() + + async def delete_by_id( + self, + user_id: int, + player_id: int, + ) -> None: + async with AsyncSession(self.engine) as session: + statement = ( + delete(PlayerInfoSQLModel) + .where(PlayerInfoSQLModel.player_id == player_id) + .where(PlayerInfoSQLModel.user_id == user_id) + ) + await session.execute(statement) + + async def update(self, player: PlayerInfoSQLModel) -> None: + async with AsyncSession(self.engine) as session: + session.add(player) + await session.commit() + await session.refresh(player) diff --git a/core/services/players/services.py b/core/services/players/services.py new file mode 100644 index 00000000..5eaf8129 --- /dev/null +++ b/core/services/players/services.py @@ -0,0 +1,184 @@ +from datetime import datetime, timedelta +from typing import List, Optional + +from aiohttp import ClientConnectorError +from enkanetwork import ( + EnkaNetworkAPI, + VaildateUIDError, + HTTPException, + EnkaPlayerNotFound, + PlayerInfo as EnkaPlayerInfo, +) + +from core.base_service import BaseService +from core.basemodel import RegionEnum +from core.config import config +from core.dependence.redisdb import RedisDB +from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel, PlayerInfo +from core.services.players.repositories import PlayersRepository, PlayerInfoRepository +from utils.enkanetwork import RedisCache +from utils.log import logger +from utils.patch.aiohttp import AioHttpTimeoutException + +__all__ = ("PlayersService", "PlayerInfoService") + + +class PlayersService(BaseService): + def __init__(self, players_repository: PlayersRepository) -> None: + self._repository = players_repository + + async def get( + self, + user_id: int, + player_id: Optional[int] = None, + account_id: Optional[int] = None, + region: Optional[RegionEnum] = None, + is_chosen: Optional[bool] = None, + ) -> Optional[Player]: + return await self._repository.get(user_id, player_id, account_id, region, is_chosen) + + async def get_player(self, user_id: int, region: Optional[RegionEnum] = None) -> Optional[Player]: + return await self._repository.get(user_id, region=region, is_chosen=True) + + async def add(self, player: Player) -> None: + await self._repository.add(player) + + async def update(self, player: Player) -> None: + await self._repository.update(player) + + async def get_all_by_user_id(self, user_id: int) -> List[Player]: + return await self._repository.get_all_by_user_id(user_id) + + async def remove_all_by_user_id(self, user_id: int): + players = await self._repository.get_all_by_user_id(user_id) + for player in players: + await self._repository.delete(player) + + async def delete(self, player: Player): + await self._repository.delete(player) + + +class PlayerInfoService(BaseService): + def __init__(self, redis: RedisDB, players_info_repository: PlayerInfoRepository): + self.cache = redis.client + self._players_info_repository = players_info_repository + self.enka_client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent) + self.enka_client.set_cache(RedisCache(redis.client, key="players_info:enka_network", ttl=60)) + self.qname = "players_info" + + async def get_form_cache(self, player: Player): + qname = f"{self.qname}:{player.user_id}:{player.player_id}" + data = await self.cache.get(qname) + if data is None: + return None + json_data = str(data, encoding="utf-8") + return PlayerInfo.parse_raw(json_data) + + async def set_form_cache(self, player: PlayerInfo): + qname = f"{self.qname}:{player.user_id}:{player.player_id}" + await self.cache.set(qname, player.json(), ex=60) + + async def get_player_info_from_enka(self, player_id: int) -> Optional[EnkaPlayerInfo]: + try: + response = await self.enka_client.fetch_user(player_id, info=True) + return response.player + except (VaildateUIDError, EnkaPlayerNotFound, HTTPException) as exc: + logger.warning("EnkaNetwork 请求失败: %s", str(exc)) + except AioHttpTimeoutException as exc: + logger.warning("EnkaNetwork 请求超时: %s", str(exc)) + except ClientConnectorError as exc: + logger.warning("EnkaNetwork 请求错误: %s", str(exc)) + except Exception as exc: + logger.error("EnkaNetwork 请求失败: %s", exc_info=exc) + return None + + async def get(self, player: Player) -> Optional[PlayerInfo]: + player_info = await self.get_form_cache(player) + if player_info is not None: + return player_info + player_info = await self._players_info_repository.get(player.user_id, player.player_id) + if player_info is None: + player_info_enka = await self.get_player_info_from_enka(player.player_id) + if player_info_enka is None: + return None + player_info = PlayerInfo( + user_id=player.user_id, + player_id=player.player_id, + nickname=player_info_enka.nickname, + signature=player_info_enka.signature, + name_card=player_info_enka.namecard.id, + hand_image=player_info_enka.avatar.id, + create_time=datetime.now(), + last_save_time=datetime.now(), + is_update=True, + ) + await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info)) + await self.set_form_cache(player_info) + return player_info + if player_info.is_update: + expiration_time = datetime.now() - timedelta(days=7) + if player_info.last_save_time is None or player_info.last_save_time <= expiration_time: + player_info_enka = await self.get_player_info_from_enka(player.player_id) + if player_info_enka is None: + player_info.last_save_time = datetime.now() + await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info)) + await self.set_form_cache(player_info) + return player_info + player_info.nickname = player_info_enka.nickname + player_info.name_card = player_info_enka.namecard.id + player_info.signature = player_info_enka.signature + player_info.hand_image = player_info_enka.avatar.id + player_info.nickname = player_info_enka.nickname + player_info.last_save_time = datetime.now() + await self._players_info_repository.update(PlayerInfoSQLModel.from_orm(player_info)) + await self.set_form_cache(player_info) + return player_info + + async def update_from_enka(self, player: Player) -> bool: + player_info = await self._players_info_repository.get(player.user_id, player.player_id) + if player_info is not None: + player_info_enka = await self.get_player_info_from_enka(player.player_id) + if player_info_enka is None: + return False + player_info.nickname = player_info_enka.nickname + player_info.name_card = player_info_enka.namecard.id + player_info.signature = player_info_enka.signature + player_info.hand_image = player_info_enka.avatar.id + player_info.nickname = player_info_enka.nickname + player_info.last_save_time = datetime.now() + await self._players_info_repository.update(player_info) + return True + return False + + async def add_from_enka(self, player: Player) -> bool: + player_info = await self._players_info_repository.get(player.user_id, player.player_id) + if player_info is None: + player_info_enka = await self.get_player_info_from_enka(player.player_id) + if player_info_enka is None: + return False + player_info = PlayerInfoSQLModel( + user_id=player.user_id, + player_id=player.player_id, + nickname=player_info_enka.nickname, + signature=player_info_enka.signature, + name_card=player_info_enka.namecard.id, + hand_image=player_info_enka.avatar.id, + create_time=datetime.now(), + last_save_time=datetime.now(), + is_update=True, + ) + await self._players_info_repository.add(player_info) + return True + return False + + async def get_form_sql(self, player: Player): + return await self._players_info_repository.get(player.user_id, player.player_id) + + async def delete_form_player(self, player: Player): + await self._players_info_repository.delete_by_id(user_id=player.user_id, player_id=player.player_id) + + async def add(self, player_info: PlayerInfo): + await self._players_info_repository.add(PlayerInfoSQLModel.from_orm(player_info)) + + async def delete(self, player_info: PlayerInfo): + await self._players_info_repository.delete(PlayerInfoSQLModel.from_orm(player_info)) diff --git a/core/services/quiz/__init__.py b/core/services/quiz/__init__.py new file mode 100644 index 00000000..eff5d659 --- /dev/null +++ b/core/services/quiz/__init__.py @@ -0,0 +1 @@ +"""QuizService""" diff --git a/core/quiz/cache.py b/core/services/quiz/cache.py similarity index 85% rename from core/quiz/cache.py rename to core/services/quiz/cache.py index 867d87cf..22830dc6 100644 --- a/core/quiz/cache.py +++ b/core/services/quiz/cache.py @@ -1,12 +1,13 @@ from typing import List -import ujson +from core.base_service import BaseService +from core.dependence.redisdb import RedisDB +from core.services.quiz.models import Answer, Question -from core.base.redisdb import RedisDB -from .models import Answer, Question +__all__ = ("QuizCache",) -class QuizCache: +class QuizCache(BaseService.Component): def __init__(self, redis: RedisDB): self.client = redis.client self.question_qname = "quiz:question" @@ -18,7 +19,7 @@ async def get_all_question(self) -> List[Question]: data_list = [self.question_qname + f":{question_id}" for question_id in await self.client.lrange(qname, 0, -1)] data = await self.client.mget(data_list) for i in data: - temp_list.append(Question.de_json(ujson.loads(i))) + temp_list.append(Question.parse_raw(i)) return temp_list async def get_all_question_id_list(self) -> List[str]: @@ -29,19 +30,19 @@ async def get_one_question(self, question_id: int) -> Question: qname = f"{self.question_qname}:{question_id}" data = await self.client.get(qname) json_data = str(data, encoding="utf-8") - return Question.de_json(ujson.loads(json_data)) + return Question.parse_raw(json_data) async def get_one_answer(self, answer_id: int) -> Answer: qname = f"{self.answer_qname}:{answer_id}" data = await self.client.get(qname) json_data = str(data, encoding="utf-8") - return Answer.de_json(ujson.loads(json_data)) + return Answer.parse_raw(json_data) async def add_question(self, question_list: List[Question] = None) -> int: if not question_list: return 0 for question in question_list: - await self.client.set(f"{self.question_qname}:{question.question_id}", ujson.dumps(question.to_dict())) + await self.client.set(f"{self.question_qname}:{question.question_id}", question.json()) question_id_list = [question.question_id for question in question_list] await self.client.lpush(f"{self.question_qname}:id_list", *question_id_list) return await self.client.llen(f"{self.question_qname}:id_list") @@ -62,7 +63,7 @@ async def add_answer(self, answer_list: List[Answer] = None) -> int: if not answer_list: return 0 for answer in answer_list: - await self.client.set(f"{self.answer_qname}:{answer.answer_id}", ujson.dumps(answer.to_dict())) + await self.client.set(f"{self.answer_qname}:{answer.answer_id}", answer.json()) answer_id_list = [answer.answer_id for answer in answer_list] await self.client.lpush(f"{self.answer_qname}:id_list", *answer_id_list) return await self.client.llen(f"{self.answer_qname}:id_list") diff --git a/core/services/quiz/models.py b/core/services/quiz/models.py new file mode 100644 index 00000000..99b0074a --- /dev/null +++ b/core/services/quiz/models.py @@ -0,0 +1,57 @@ +from typing import List, Optional + +from pydantic import BaseModel +from sqlmodel import Column, Field, ForeignKey, Integer, SQLModel + +__all__ = ("Answer", "AnswerDB", "Question", "QuestionDB") + + +class AnswerDB(SQLModel, table=True): + __tablename__ = "answer" + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True) + ) + question_id: Optional[int] = Field( + sa_column=Column(Integer, ForeignKey("question.id", ondelete="RESTRICT", onupdate="RESTRICT")) + ) + is_correct: Optional[bool] = Field() + text: Optional[str] = Field() + + +class QuestionDB(SQLModel, table=True): + __tablename__ = "question" + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer, primary_key=True, autoincrement=True) + ) + text: Optional[str] = Field() + + +class Answer(BaseModel): + answer_id: int = 0 + question_id: int = 0 + is_correct: bool = True + text: str = "" + + def to_database_data(self) -> AnswerDB: + return AnswerDB(id=self.answer_id, question_id=self.question_id, text=self.text, is_correct=self.is_correct) + + @classmethod + def de_database_data(cls, data: AnswerDB) -> Optional["Answer"]: + return cls(answer_id=data.id, question_id=data.question_id, text=data.text, is_correct=data.is_correct) + + +class Question(BaseModel): + question_id: int = 0 + text: str = "" + answers: List[Answer] = [] + + def to_database_data(self) -> QuestionDB: + return QuestionDB(text=self.text, id=self.question_id) + + @classmethod + def de_database_data(cls, data: QuestionDB) -> Optional["Question"]: + return cls(question_id=data.id, text=data.text) diff --git a/core/quiz/repositories.py b/core/services/quiz/repositories.py similarity index 63% rename from core/quiz/repositories.py rename to core/services/quiz/repositories.py index 5096b243..28542e8f 100644 --- a/core/quiz/repositories.py +++ b/core/services/quiz/repositories.py @@ -2,54 +2,55 @@ from sqlmodel import select -from core.base.mysql import MySQL -from .models import AnswerDB, QuestionDB +from core.base_service import BaseService +from core.dependence.mysql import MySQL +from core.services.quiz.models import AnswerDB, QuestionDB +from core.sqlmodel.session import AsyncSession +__all__ = ("QuizRepository",) -class QuizRepository: + +class QuizRepository(BaseService.Component): def __init__(self, mysql: MySQL): - self.mysql = mysql + self.engine = mysql.engine async def get_question_list(self) -> List[QuestionDB]: - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: query = select(QuestionDB) results = await session.exec(query) - questions = results.all() - return questions + return results.all() async def get_answers_from_question_id(self, question_id: int) -> List[AnswerDB]: - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: query = select(AnswerDB).where(AnswerDB.question_id == question_id) results = await session.exec(query) - answers = results.all() - return answers + return results.all() async def add_question(self, question: QuestionDB): - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: session.add(question) await session.commit() async def get_question_by_text(self, text: str) -> QuestionDB: - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: query = select(QuestionDB).where(QuestionDB.text == text) results = await session.exec(query) - question = results.first() - return question[0] + return results.first() async def add_answer(self, answer: AnswerDB): - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: session.add(answer) await session.commit() async def delete_question_by_id(self, question_id: int): - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: statement = select(QuestionDB).where(QuestionDB.id == question_id) results = await session.exec(statement) question = results.one() await session.delete(question) async def delete_answer_by_id(self, answer_id: int): - async with self.mysql.Session() as session: + async with AsyncSession(self.engine) as session: statement = select(AnswerDB).where(AnswerDB.id == answer_id) results = await session.exec(statement) answer = results.one() diff --git a/core/quiz/services.py b/core/services/quiz/services.py similarity index 89% rename from core/quiz/services.py rename to core/services/quiz/services.py index a5f543f3..d3e658ff 100644 --- a/core/quiz/services.py +++ b/core/services/quiz/services.py @@ -1,12 +1,15 @@ import asyncio from typing import List -from .cache import QuizCache -from .models import Answer, Question -from .repositories import QuizRepository +from core.base_service import BaseService +from core.services.quiz.cache import QuizCache +from core.services.quiz.models import Answer, Question +from core.services.quiz.repositories import QuizRepository +__all__ = ("QuizService",) -class QuizService: + +class QuizService(BaseService): def __init__(self, repository: QuizRepository, cache: QuizCache): self._repository = repository self._cache = cache diff --git a/core/services/search/__init__.py b/core/services/search/__init__.py new file mode 100644 index 00000000..32e4d963 --- /dev/null +++ b/core/services/search/__init__.py @@ -0,0 +1 @@ +"""SearchService""" diff --git a/core/search/models.py b/core/services/search/models.py similarity index 94% rename from core/search/models.py rename to core/services/search/models.py index 524e42d1..03e875e1 100644 --- a/core/search/models.py +++ b/core/services/search/models.py @@ -1,12 +1,11 @@ from abc import abstractmethod -from typing import Optional, List +from typing import List, Optional from pydantic import BaseModel - -__all__ = ["BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList"] - from thefuzz import fuzz +__all__ = ("BaseEntry", "WeaponEntry", "WeaponsEntry", "StrategyEntry", "StrategyEntryList") + class BaseEntry(BaseModel): """所有可搜索条目的基类。 diff --git a/core/search/services.py b/core/services/search/services.py similarity index 95% rename from core/search/services.py rename to core/services/search/services.py index 94d3de74..a8c40a7c 100644 --- a/core/search/services.py +++ b/core/services/search/services.py @@ -5,19 +5,22 @@ import os import time from pathlib import Path -from typing import Tuple, List, Optional, Dict +from typing import Dict, List, Optional, Tuple import aiofiles from async_lru import alru_cache -from core.search.models import WeaponEntry, BaseEntry, WeaponsEntry, StrategyEntry, StrategyEntryList +from core.base_service import BaseService +from core.services.search.models import BaseEntry, StrategyEntry, StrategyEntryList, WeaponEntry, WeaponsEntry from utils.const import PROJECT_ROOT +__all__ = ("SearchServices",) + ENTRY_DAYA_PATH = PROJECT_ROOT.joinpath("data", "entry") ENTRY_DAYA_PATH.mkdir(parents=True, exist_ok=True) -class SearchServices: +class SearchServices(BaseService): def __init__(self): self._lock = asyncio.Lock() # 访问和修改操作成员变量必须加锁操作 self.weapons: List[WeaponEntry] = [] diff --git a/core/services/sign/__init__.py b/core/services/sign/__init__.py new file mode 100644 index 00000000..9b51e2f5 --- /dev/null +++ b/core/services/sign/__init__.py @@ -0,0 +1 @@ +"""SignService""" diff --git a/core/sign/models.py b/core/services/sign/models.py similarity index 57% rename from core/sign/models.py rename to core/services/sign/models.py index 4a46aece..dd62239d 100644 --- a/core/sign/models.py +++ b/core/services/sign/models.py @@ -2,8 +2,10 @@ from datetime import datetime from typing import Optional -from sqlalchemy import func -from sqlmodel import SQLModel, Field, Enum, Column, DateTime +from sqlalchemy import func, BigInteger +from sqlmodel import Column, DateTime, Enum, Field, SQLModel, Integer + +__all__ = ("SignStatusEnum", "Sign") class SignStatusEnum(int, enum.Enum): @@ -19,10 +21,13 @@ class SignStatusEnum(int, enum.Enum): class Sign(SQLModel, table=True): __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - user_id: int = Field(foreign_key="user.user_id") + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(primary_key=True, sa_column=Column(BigInteger(), index=True)) chat_id: Optional[int] = Field(default=None) - time_created: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), server_default=func.now())) - time_updated: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True), onupdate=func.now())) + time_created: Optional[datetime] = Field( + sa_column=Column(DateTime, server_default=func.now()) # pylint: disable=E1102 + ) + time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 status: Optional[SignStatusEnum] = Field(sa_column=Column(Enum(SignStatusEnum))) diff --git a/core/services/sign/repositories.py b/core/services/sign/repositories.py new file mode 100644 index 00000000..ef5d108a --- /dev/null +++ b/core/services/sign/repositories.py @@ -0,0 +1,50 @@ +from typing import List, Optional + +from sqlmodel import select + +from core.base_service import BaseService +from core.dependence.mysql import MySQL +from core.services.sign.models import Sign +from core.sqlmodel.session import AsyncSession + +__all__ = ("SignRepository",) + + +class SignRepository(BaseService.Component): + def __init__(self, mysql: MySQL): + self.engine = mysql.engine + + async def add(self, sign: Sign): + async with AsyncSession(self.engine) as session: + session.add(sign) + await session.commit() + + async def remove(self, sign: Sign): + async with AsyncSession(self.engine) as session: + await session.delete(sign) + await session.commit() + + async def update(self, sign: Sign) -> Sign: + async with AsyncSession(self.engine) as session: + session.add(sign) + await session.commit() + await session.refresh(sign) + return sign + + async def get_by_user_id(self, user_id: int) -> Optional[Sign]: + async with AsyncSession(self.engine) as session: + statement = select(Sign).where(Sign.user_id == user_id) + results = await session.exec(statement) + return results.first() + + async def get_by_chat_id(self, chat_id: int) -> Optional[List[Sign]]: + async with AsyncSession(self.engine) as session: + statement = select(Sign).where(Sign.chat_id == chat_id) + results = await session.exec(statement) + return results.all() + + async def get_all(self) -> List[Sign]: + async with AsyncSession(self.engine) as session: + query = select(Sign) + results = await session.exec(query) + return results.all() diff --git a/core/sign/services.py b/core/services/sign/services.py similarity index 77% rename from core/sign/services.py rename to core/services/sign/services.py index 9772ea74..74de7aa8 100644 --- a/core/sign/services.py +++ b/core/services/sign/services.py @@ -1,8 +1,11 @@ -from .models import Sign -from .repositories import SignRepository +from core.base_service import BaseService +from core.services.sign.models import Sign +from core.services.sign.repositories import SignRepository +__all__ = ["SignServices"] -class SignServices: + +class SignServices(BaseService): def __init__(self, sign_repository: SignRepository) -> None: self._repository: SignRepository = sign_repository diff --git a/core/template/README.md b/core/services/template/README.md similarity index 100% rename from core/template/README.md rename to core/services/template/README.md diff --git a/core/services/template/__init__.py b/core/services/template/__init__.py new file mode 100644 index 00000000..79551dd6 --- /dev/null +++ b/core/services/template/__init__.py @@ -0,0 +1 @@ +"""TemplateService""" diff --git a/core/template/cache.py b/core/services/template/cache.py similarity index 86% rename from core/template/cache.py rename to core/services/template/cache.py index 5df3529f..139763a7 100644 --- a/core/template/cache.py +++ b/core/services/template/cache.py @@ -3,10 +3,14 @@ from hashlib import sha256 from typing import Any, Optional -from core.base.redisdb import RedisDB +from core.base_service import BaseService +from core.dependence.redisdb import RedisDB -class TemplatePreviewCache: +__all__ = ["TemplatePreviewCache", "HtmlToFileIdCache"] + + +class TemplatePreviewCache(BaseService.Component): """暂存渲染模板的数据用于预览""" def __init__(self, redis: RedisDB): @@ -29,7 +33,7 @@ def cache_key(self, key: str) -> str: return f"{self.qname}:{key}" -class HtmlToFileIdCache: +class HtmlToFileIdCache(BaseService.Component): """html to file_id 的缓存""" def __init__(self, redis: RedisDB): diff --git a/core/template/error.py b/core/services/template/error.py similarity index 100% rename from core/template/error.py rename to core/services/template/error.py diff --git a/core/template/models.py b/core/services/template/models.py similarity index 91% rename from core/template/models.py rename to core/services/template/models.py index b6cbdfd6..9af68f34 100644 --- a/core/template/models.py +++ b/core/services/template/models.py @@ -1,10 +1,12 @@ from enum import Enum -from typing import Optional, Union, List +from typing import List, Optional, Union -from telegram import Message, InputMediaPhoto, InputMediaDocument +from telegram import InputMediaDocument, InputMediaPhoto, Message -from core.template.cache import HtmlToFileIdCache -from core.template.error import ErrorFileType, FileIdNotFound +from core.services.template.cache import HtmlToFileIdCache +from core.services.template.error import ErrorFileType, FileIdNotFound + +__all__ = ["FileType", "RenderResult", "RenderGroupResult"] class FileType(Enum): @@ -16,10 +18,9 @@ def media_type(file_type: "FileType"): """对应的 Telegram media 类型""" if file_type == FileType.PHOTO: return InputMediaPhoto - elif file_type == FileType.DOCUMENT: + if file_type == FileType.DOCUMENT: return InputMediaDocument - else: - raise ErrorFileType + raise ErrorFileType class RenderResult: diff --git a/core/template/services.py b/core/services/template/services.py similarity index 71% rename from core/template/services.py rename to core/services/template/services.py index 52a585bc..dd61d7c3 100644 --- a/core/template/services.py +++ b/core/services/template/services.py @@ -1,44 +1,31 @@ -import time +import asyncio from typing import Optional -from urllib.parse import ( - urlencode, - urljoin, - urlsplit, -) +from urllib.parse import urlencode, urljoin, urlsplit from uuid import uuid4 -from fastapi import HTTPException -from fastapi.responses import ( - FileResponse, - HTMLResponse, -) +from fastapi import FastAPI, HTTPException +from fastapi.responses import FileResponse, HTMLResponse from fastapi.staticfiles import StaticFiles -from jinja2 import ( - Environment, - FileSystemLoader, - Template, -) +from jinja2 import Environment, FileSystemLoader, Template from playwright.async_api import ViewportSize -from core.base.aiobrowser import AioBrowser -from core.base.webserver import webapp -from core.bot import bot -from core.template.cache import ( - HtmlToFileIdCache, - TemplatePreviewCache, -) -from core.template.error import QuerySelectorNotFound -from core.template.models import ( - FileType, - RenderResult, -) +from core.application import Application +from core.base_service import BaseService +from core.config import config as application_config +from core.dependence.aiobrowser import AioBrowser +from core.services.template.cache import HtmlToFileIdCache, TemplatePreviewCache +from core.services.template.error import QuerySelectorNotFound +from core.services.template.models import FileType, RenderResult from utils.const import PROJECT_ROOT from utils.log import logger +__all__ = ("TemplateService", "TemplatePreviewer") -class TemplateService: + +class TemplateService(BaseService): def __init__( self, + app: Application, browser: AioBrowser, html_to_file_id_cache: HtmlToFileIdCache, preview_cache: TemplatePreviewCache, @@ -51,10 +38,12 @@ def __init__( loader=FileSystemLoader(template_dir), enable_async=True, autoescape=True, - auto_reload=bot.config.debug, + auto_reload=application_config.debug, ) + self.using_preview = application_config.debug and application_config.webserver.enable - self.previewer = TemplatePreviewer(self, preview_cache) + if self.using_preview: + self.previewer = TemplatePreviewer(self, preview_cache, app.web_app) self.html_to_file_id_cache = html_to_file_id_cache @@ -66,10 +55,11 @@ async def render_async(self, template_name: str, template_data: dict) -> str: :param template_name: 模板文件名 :param template_data: 模板数据 """ - start_time = time.time() + loop = asyncio.get_event_loop() + start_time = loop.time() template = self.get_template(template_name) html = await template.render_async(**template_data) - logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}") + logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) return html async def render( @@ -100,19 +90,20 @@ async def render( :param filename: 文件名字 :return: """ - start_time = time.time() + loop = asyncio.get_event_loop() + start_time = loop.time() template = self.get_template(template_name) - if bot.config.debug: + if self.using_preview: preview_url = await self.previewer.get_preview_url(template_name, template_data) - logger.debug(f"调试模板 URL: {preview_url}") + logger.debug("调试模板 URL: \n%s", preview_url) html = await template.render_async(**template_data) - logger.debug(f"{template_name} 模板渲染使用了 {str(time.time() - start_time)}") + logger.debug("%s 模板渲染使用了 %s", template_name, str(loop.time() - start_time)) file_id = await self.html_to_file_id_cache.get_data(html, file_type.name) - if file_id and not bot.config.debug: - logger.debug(f"{template_name} 命中缓存,返回 file_id {file_id}") + if file_id and not application_config.debug: + logger.debug("%s 命中缓存,返回 file_id[%s]", template_name, file_id) return RenderResult( html=html, photo=file_id, @@ -125,7 +116,7 @@ async def render( ) browser = await self._browser.get_browser() - start_time = time.time() + start_time = loop.time() page = await browser.new_page(viewport=viewport) uri = (PROJECT_ROOT / template.filename).as_uri() await page.goto(uri) @@ -142,10 +133,10 @@ async def render( if not clip: raise QuerySelectorNotFound except QuerySelectorNotFound: - logger.warning(f"未找到 {query_selector} 元素") + logger.warning("未找到 %s 元素", query_selector) png_data = await page.screenshot(clip=clip, full_page=full_page) await page.close() - logger.debug(f"{template_name} 图片渲染使用了 {str(time.time() - start_time)}") + logger.debug("%s 图片渲染使用了 %s", template_name, str(loop.time() - start_time)) return RenderResult( html=html, photo=png_data, @@ -158,15 +149,21 @@ async def render( ) -class TemplatePreviewer: - def __init__(self, template_service: TemplateService, cache: TemplatePreviewCache): +class TemplatePreviewer(BaseService, load=application_config.webserver.enable and application_config.debug): + def __init__( + self, + template_service: TemplateService, + cache: TemplatePreviewCache, + web_app: FastAPI, + ): + self.web_app = web_app self.template_service = template_service self.cache = cache self.register_routes() async def get_preview_url(self, template: str, data: dict): """获取预览 URL""" - components = urlsplit(bot.config.webserver.url) + components = urlsplit(application_config.webserver.url) path = urljoin("/preview/", template) query = {} @@ -176,12 +173,13 @@ async def get_preview_url(self, template: str, data: dict): await self.cache.set_data(key, data) query["key"] = key + # noinspection PyProtectedMember return components._replace(path=path, query=urlencode(query)).geturl() def register_routes(self): """注册预览用到的路由""" - @webapp.get("/preview/{path:path}") + @self.web_app.get("/preview/{path:path}") async def preview_template(path: str, key: Optional[str] = None): # pylint: disable=W0612 # 如果是 /preview/ 开头的静态文件,直接返回内容。比如使用相对链接 ../ 引入的静态资源 if not path.endswith(".html"): @@ -206,4 +204,4 @@ async def preview_template(path: str, key: Optional[str] = None): # pylint: dis for name in ["cache", "resources"]: directory = PROJECT_ROOT / name directory.mkdir(exist_ok=True) - webapp.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) + self.web_app.mount(f"/{name}", StaticFiles(directory=PROJECT_ROOT / name), name=name) diff --git a/core/services/users/__init__.py b/core/services/users/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/core/services/users/cache.py b/core/services/users/cache.py new file mode 100644 index 00000000..28afee81 --- /dev/null +++ b/core/services/users/cache.py @@ -0,0 +1,24 @@ +from typing import List + +from core.base_service import BaseService +from core.dependence.redisdb import RedisDB + +__all__ = ("UserAdminCache",) + + +class UserAdminCache(BaseService.Component): + def __init__(self, redis: RedisDB): + self.client = redis.client + self.qname = "users:admin" + + async def ismember(self, user_id: int) -> bool: + return self.client.sismember(self.qname, user_id) + + async def get_all(self) -> List[int]: + return [int(str_data) for str_data in await self.client.smembers(self.qname)] + + async def set(self, user_id: int) -> bool: + return await self.client.sadd(self.qname, user_id) + + async def remove(self, user_id: int) -> bool: + return await self.client.srem(self.qname, user_id) diff --git a/core/services/users/models.py b/core/services/users/models.py new file mode 100644 index 00000000..5d5fa1ce --- /dev/null +++ b/core/services/users/models.py @@ -0,0 +1,34 @@ +import enum +from datetime import datetime +from typing import Optional + +from sqlmodel import SQLModel, Field, DateTime, Column, Enum, BigInteger, Integer + +__all__ = ( + "User", + "UserDataBase", + "PermissionsEnum", +) + + +class PermissionsEnum(int, enum.Enum): + OWNER = 1 + ADMIN = 2 + PUBLIC = 3 + + +class User(SQLModel): + __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") + id: Optional[int] = Field( + default=None, primary_key=True, sa_column=Column(Integer(), primary_key=True, autoincrement=True) + ) + user_id: int = Field(unique=True, sa_column=Column(BigInteger())) + permissions: Optional[PermissionsEnum] = Field(sa_column=Column(Enum(PermissionsEnum))) + locale: Optional[str] = Field() + ban_end_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) + ban_start_time: Optional[datetime] = Field(sa_column=Column(DateTime(timezone=True))) + is_banned: Optional[int] = Field() + + +class UserDataBase(User, table=True): + __tablename__ = "users" diff --git a/core/services/users/repositories.py b/core/services/users/repositories.py new file mode 100644 index 00000000..71d963bc --- /dev/null +++ b/core/services/users/repositories.py @@ -0,0 +1,44 @@ +from typing import Optional, List + +from sqlmodel import select + +from core.base_service import BaseService +from core.dependence.mysql import MySQL +from core.services.users.models import UserDataBase as User +from core.sqlmodel.session import AsyncSession + +__all__ = ("UserRepository",) + + +class UserRepository(BaseService.Component): + def __init__(self, mysql: MySQL): + self.engine = mysql.engine + + async def get_by_user_id(self, user_id: int) -> Optional[User]: + async with AsyncSession(self.engine) as session: + statement = select(User).where(User.user_id == user_id) + results = await session.exec(statement) + return results.first() + + async def add(self, user: User): + async with AsyncSession(self.engine) as session: + session.add(user) + await session.commit() + + async def update(self, user: User) -> User: + async with AsyncSession(self.engine) as session: + session.add(user) + await session.commit() + await session.refresh(user) + return user + + async def remove(self, user: User): + async with AsyncSession(self.engine) as session: + await session.delete(user) + await session.commit() + + async def get_all(self) -> List[User]: + async with AsyncSession(self.engine) as session: + statement = select(User) + results = await session.exec(statement) + return results.all() diff --git a/core/services/users/services.py b/core/services/users/services.py new file mode 100644 index 00000000..25d08f90 --- /dev/null +++ b/core/services/users/services.py @@ -0,0 +1,79 @@ +from typing import List, Optional + +from core.base_service import BaseService +from core.config import config +from core.services.users.cache import UserAdminCache +from core.services.users.models import PermissionsEnum, UserDataBase as User +from core.services.users.repositories import UserRepository + +__all__ = ("UserService", "UserAdminService") + +from utils.log import logger + + +class UserService(BaseService): + def __init__(self, user_repository: UserRepository) -> None: + self._repository: UserRepository = user_repository + + async def get_user_by_id(self, user_id: int) -> Optional[User]: + """从数据库获取用户信息 + :param user_id:用户ID + :return: User + """ + return await self._repository.get_by_user_id(user_id) + + async def remove(self, user: User): + return await self._repository.remove(user) + + async def update_user(self, user: User): + return await self._repository.add(user) + + +class UserAdminService(BaseService): + def __init__(self, user_repository: UserRepository, cache: UserAdminCache): + self.user_repository = user_repository + self._cache = cache + + async def initialize(self): + owner = config.owner + if owner: + user = await self.user_repository.get_by_user_id(owner) + await self._cache.set(user.user_id) + if user: + if user.permissions != PermissionsEnum.OWNER: + user.permissions = PermissionsEnum.OWNER + await self.user_repository.update(user) + else: + user = User(user_id=owner, permissions=PermissionsEnum.OWNER) + await self.user_repository.add(user) + else: + logger.warning("检测到未配置Bot所有者 会导无法正常使用管理员权限") + + async def is_admin(self, user_id: int) -> bool: + return await self._cache.ismember(user_id) + + async def get_admin_list(self) -> List[int]: + return await self._cache.get_all() + + async def add_admin(self, user_id: int) -> bool: + user = await self.user_repository.get_by_user_id(user_id) + if user: + if user.permissions == PermissionsEnum.OWNER: + return False + if user.permissions != PermissionsEnum.ADMIN: + user.permissions = PermissionsEnum.ADMIN + await self.user_repository.update(user) + else: + user = User(user_id=user_id, permissions=PermissionsEnum.ADMIN) + await self.user_repository.add(user) + return await self._cache.set(user.user_id) + + async def delete_admin(self, user_id: int) -> bool: + user = await self.user_repository.get_by_user_id(user_id) + if user: + if user.permissions == PermissionsEnum.OWNER: + return True # 假装移除成功 + user.permissions = PermissionsEnum.PUBLIC + await self.user_repository.update(user) + return await self._cache.remove(user.user_id) + return False diff --git a/core/services/wiki/__init__.py b/core/services/wiki/__init__.py new file mode 100644 index 00000000..042a3d34 --- /dev/null +++ b/core/services/wiki/__init__.py @@ -0,0 +1 @@ +"""WikiService""" diff --git a/core/wiki/cache.py b/core/services/wiki/cache.py similarity index 86% rename from core/wiki/cache.py rename to core/services/wiki/cache.py index a73ee80f..213227f2 100644 --- a/core/wiki/cache.py +++ b/core/services/wiki/cache.py @@ -1,10 +1,13 @@ import ujson as json -from core.base.redisdb import RedisDB +from core.base_service import BaseService +from core.dependence.redisdb import RedisDB from modules.wiki.base import Model +__all__ = ["WikiCache"] -class WikiCache: + +class WikiCache(BaseService.Component): def __init__(self, redis: RedisDB): self.client = redis.client self.qname = "wiki" diff --git a/core/wiki/services.py b/core/services/wiki/services.py similarity index 87% rename from core/wiki/services.py rename to core/services/wiki/services.py index 891fc08c..b8758d66 100644 --- a/core/wiki/services.py +++ b/core/services/wiki/services.py @@ -1,12 +1,15 @@ from typing import List, NoReturn, Optional -from core.wiki.cache import WikiCache +from core.base_service import BaseService +from core.services.wiki.cache import WikiCache from modules.wiki.character import Character from modules.wiki.weapon import Weapon from utils.log import logger +__all__ = ["WikiService"] -class WikiService: + +class WikiService(BaseService): def __init__(self, cache: WikiCache): self._cache = cache """Redis 在这里的作用是作为持久化""" @@ -18,7 +21,7 @@ def __init__(self, cache: WikiCache): async def refresh_weapon(self) -> NoReturn: weapon_name_list = await Weapon.get_name_list() - logger.info(f"一共找到 {len(weapon_name_list)} 把武器信息") + logger.info("一共找到 %s 把武器信息", len(weapon_name_list)) weapon_list = [] num = 0 @@ -26,7 +29,7 @@ async def refresh_weapon(self) -> NoReturn: weapon_list.append(weapon) num += 1 if num % 10 == 0: - logger.info(f"现在已经获取到 {num} 把武器信息") + logger.info("现在已经获取到 %s 把武器信息", num) logger.info("写入武器信息到Redis") self._weapon_list = weapon_list @@ -35,7 +38,7 @@ async def refresh_weapon(self) -> NoReturn: async def refresh_characters(self) -> NoReturn: character_name_list = await Character.get_name_list() - logger.info(f"一共找到 {len(character_name_list)} 个角色信息") + logger.info("一共找到 %s 个角色信息", len(character_name_list)) character_list = [] num = 0 @@ -43,7 +46,7 @@ async def refresh_characters(self) -> NoReturn: character_list.append(character) num += 1 if num % 10 == 0: - logger.info(f"现在已经获取到 {num} 个角色信息") + logger.info("现在已经获取到 %s 个角色信息", num) logger.info("写入角色信息到Redis") self._character_list = character_list diff --git a/core/sign/__init__.py b/core/sign/__init__.py deleted file mode 100644 index e396a02b..00000000 --- a/core/sign/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from core.base.mysql import MySQL -from core.service import init_service -from .repositories import SignRepository -from .services import SignServices - - -@init_service -def create_game_strategy_service(mysql: MySQL): - _repository = SignRepository(mysql) - _service = SignServices(_repository) - return _service diff --git a/core/sign/repositories.py b/core/sign/repositories.py deleted file mode 100644 index 54c0b133..00000000 --- a/core/sign/repositories.py +++ /dev/null @@ -1,53 +0,0 @@ -from typing import List, Optional, cast - -from sqlalchemy import select -from sqlalchemy.ext.asyncio import AsyncSession - -from core.base.mysql import MySQL -from .models import Sign - - -class SignRepository: - def __init__(self, mysql: MySQL): - self.mysql = mysql - - async def add(self, sign: Sign): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - session.add(sign) - await session.commit() - - async def remove(self, sign: Sign): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - await session.delete(sign) - await session.commit() - - async def update(self, sign: Sign): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - session.add(sign) - await session.commit() - await session.refresh(sign) - - async def get_by_user_id(self, user_id: int) -> Optional[Sign]: - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - statement = select(Sign).where(Sign.user_id == user_id) - results = await session.exec(statement) - return sign[0] if (sign := results.first()) else None - - async def get_by_chat_id(self, chat_id: int) -> Optional[List[Sign]]: - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - statement = select(Sign).where(Sign.chat_id == chat_id) - results = await session.exec(statement) - signs = results.all() - return [sign[0] for sign in signs] - - async def get_all(self) -> List[Sign]: - async with self.mysql.Session() as session: - query = select(Sign) - results = await session.exec(query) - signs = results.all() - return [sign[0] for sign in signs] diff --git a/core/sqlmodel/session.py b/core/sqlmodel/session.py new file mode 100644 index 00000000..88e4d3da --- /dev/null +++ b/core/sqlmodel/session.py @@ -0,0 +1,118 @@ +from typing import Any, Mapping, Optional, Sequence, Type, TypeVar, Union, overload + +from sqlalchemy import util +from sqlalchemy.ext.asyncio import AsyncSession as _AsyncSession +from sqlalchemy.ext.asyncio.engine import AsyncConnection, AsyncEngine +from sqlalchemy.sql.base import Executable as _Executable +from sqlmodel.engine.result import Result, ScalarResult +from sqlmodel.orm.session import Session +from sqlmodel.sql.base import Executable +from sqlmodel.sql.expression import Select, SelectOfScalar +from typing_extensions import Literal + +_TSelectParam = TypeVar("_TSelectParam") + +__all__ = ("AsyncSession",) + + +class AsyncSession(_AsyncSession): # pylint: disable=W0223 + sync_session_class = Session + sync_session: Session + + def __init__( + self, + bind: Optional[Union[AsyncConnection, AsyncEngine]] = None, + binds: Optional[Mapping[object, Union[AsyncConnection, AsyncEngine]]] = None, + sync_session_class: Type[Session] = Session, + **kw: Any, + ): + super().__init__( + bind=bind, + binds=binds, + sync_session_class=sync_session_class, + **kw, + ) + + @overload + async def exec( + self, + statement: Select[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Result[_TSelectParam]: + ... + + @overload + async def exec( + self, + statement: SelectOfScalar[_TSelectParam], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> ScalarResult[_TSelectParam]: + ... + + async def exec( + self, + statement: Union[ + Select[_TSelectParam], + SelectOfScalar[_TSelectParam], + Executable[_TSelectParam], + ], + *, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Mapping[str, Any] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Union[Result[_TSelectParam], ScalarResult[_TSelectParam]]: + results = super().execute( + statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + if isinstance(statement, SelectOfScalar): + return (await results).scalars() # type: ignore + return await results # type: ignore + + async def execute( # pylint: disable=W0221 + self, + statement: _Executable, + params: Optional[Union[Mapping[str, Any], Sequence[Mapping[str, Any]]]] = None, + execution_options: Optional[Mapping[str, Any]] = util.EMPTY_DICT, + bind_arguments: Optional[Mapping[str, Any]] = None, + **kw: Any, + ) -> Result[Any]: + return await super().execute( # type: ignore + statement=statement, + params=params, + execution_options=execution_options, + bind_arguments=bind_arguments, + **kw, + ) + + async def get( # pylint: disable=W0221 + self, + entity: Type[_TSelectParam], + ident: Any, + options: Optional[Sequence[Any]] = None, + populate_existing: bool = False, + with_for_update: Optional[Union[Literal[True], Mapping[str, Any]]] = None, + identity_token: Optional[Any] = None, + execution_options: Optional[Mapping[Any, Any]] = util.EMPTY_DICT, + ) -> Optional[_TSelectParam]: + return await super().get( + entity=entity, + ident=ident, + options=options, + populate_existing=populate_existing, + with_for_update=with_for_update, + identity_token=identity_token, + execution_options=execution_options, + ) diff --git a/core/template/__init__.py b/core/template/__init__.py deleted file mode 100644 index d6c695fe..00000000 --- a/core/template/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -from core.base.aiobrowser import AioBrowser -from core.service import init_service -from core.base.redisdb import RedisDB -from core.template.services import TemplateService -from core.template.cache import TemplatePreviewCache, HtmlToFileIdCache - - -@init_service -def create_template_service(browser: AioBrowser, redis: RedisDB): - _preview_cache = TemplatePreviewCache(redis) - _html_to_file_id_cache = HtmlToFileIdCache(redis) - _service = TemplateService(browser, _html_to_file_id_cache, _preview_cache) - return _service diff --git a/core/user/__init__.py b/core/user/__init__.py deleted file mode 100644 index 2189d238..00000000 --- a/core/user/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from core.base.mysql import MySQL -from core.service import init_service -from .repositories import UserRepository -from .services import UserService - - -@init_service -def create_user_service(mysql: MySQL): - _repository = UserRepository(mysql) - _service = UserService(_repository) - return _service diff --git a/core/user/error.py b/core/user/error.py deleted file mode 100644 index 13b8b9c2..00000000 --- a/core/user/error.py +++ /dev/null @@ -1,3 +0,0 @@ -class UserNotFoundError(Exception): - def __init__(self, user_id): - super().__init__(f"user not found, user_id: {user_id}") diff --git a/core/user/models.py b/core/user/models.py deleted file mode 100644 index f796b9d3..00000000 --- a/core/user/models.py +++ /dev/null @@ -1,15 +0,0 @@ -from typing import Optional - -from sqlmodel import SQLModel, Field, Enum, Column - -from utils.models.base import RegionEnum - - -class User(SQLModel, table=True): - __table_args__ = dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci") - - id: int = Field(primary_key=True) - user_id: int = Field(unique=True) - yuanshen_uid: Optional[int] = Field() - genshin_uid: Optional[int] = Field() - region: Optional[RegionEnum] = Field(sa_column=Column(Enum(RegionEnum))) diff --git a/core/user/repositories.py b/core/user/repositories.py deleted file mode 100644 index 616ab760..00000000 --- a/core/user/repositories.py +++ /dev/null @@ -1,48 +0,0 @@ -from typing import cast - -from sqlalchemy import select -from sqlmodel.ext.asyncio.session import AsyncSession - -from core.base.mysql import MySQL -from .error import UserNotFoundError -from .models import User - - -class UserRepository: - def __init__(self, mysql: MySQL): - self.mysql = mysql - - async def get_by_user_id(self, user_id: int) -> User: - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - statement = select(User).where(User.user_id == user_id) - results = await session.exec(statement) - if user := results.first(): - return user[0] - else: - raise UserNotFoundError(user_id) - - async def update_user(self, user: User): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - session.add(user) - await session.commit() - await session.refresh(user) - - async def add_user(self, user: User): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - session.add(user) - await session.commit() - - async def del_user_by_id(self, user_id): - async with self.mysql.Session() as session: - session = cast(AsyncSession, session) - statement = select(User).where(User.user_id == user_id) - results = await session.execute(statement) - user = results.unique().scalar_one() - if user: - await session.delete(user) - await session.commit() - else: - raise UserNotFoundError(user_id) diff --git a/core/user/services.py b/core/user/services.py deleted file mode 100644 index 1b00b284..00000000 --- a/core/user/services.py +++ /dev/null @@ -1,23 +0,0 @@ -from .models import User -from .repositories import UserRepository - - -class UserService: - def __init__(self, user_repository: UserRepository) -> None: - self._repository: UserRepository = user_repository - - async def get_user_by_id(self, user_id: int) -> User: - """从数据库获取用户信息 - :param user_id:用户ID - :return: User - """ - return await self._repository.get_by_user_id(user_id) - - async def del_user_by_id(self, user_id: int) -> User: - return await self._repository.del_user_by_id(user_id) - - async def update_user(self, user: User) -> User: - return await self._repository.update_user(user) - - async def add_user(self, user: User) -> User: - return await self._repository.add_user(user) diff --git a/core/wiki/__init__.py b/core/wiki/__init__.py deleted file mode 100644 index 9c975445..00000000 --- a/core/wiki/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from core.base.redisdb import RedisDB -from core.service import init_service -from .cache import WikiCache -from .services import WikiService - - -@init_service -def create_wiki_service(redis: RedisDB): - _cache = WikiCache(redis) - _service = WikiService(_cache) - return _service diff --git a/metadata/genshin.py b/metadata/genshin.py index 33aade94..5e54efc8 100644 --- a/metadata/genshin.py +++ b/metadata/genshin.py @@ -11,7 +11,7 @@ from utils.log import logger from utils.typedefs import StrOrInt -__all__ = [ +__all__ = ( "HONEY_DATA", "AVATAR_DATA", "WEAPON_DATA", @@ -23,7 +23,7 @@ "Data", "weapon_to_game_id", "avatar_to_game_id", -] +) K = TypeVar("K") V = TypeVar("V") @@ -46,7 +46,8 @@ def data(self) -> dict[K, V]: path = data_dir.joinpath(self._file_name).with_suffix(".json") if not path.exists(): logger.error( - f'暂未找到名为 "{self._file_name}.json" 的 metadata , ' "请先使用 [yellow bold]/refresh_metadata[/] 命令下载", + '暂未找到名为 "%s" 的 metadata , ' "请先使用 [yellow bold]/refresh_metadata[/] 命令下载", + self._file_name, extra={"markup": True}, ) self._dict = {} diff --git a/metadata/pool/pool.py b/metadata/pool/pool.py index 163e9169..8b09eea0 100644 --- a/metadata/pool/pool.py +++ b/metadata/pool/pool.py @@ -6,7 +6,8 @@ def get_pool_by_id(pool_type): if pool_type == 200: return POOL_200 - elif pool_type == 301: + if pool_type == 301: return POOL_301 - elif pool_type == 302: + if pool_type == 302: return POOL_302 + return None diff --git a/metadata/scripts/metadatas.py b/metadata/scripts/metadatas.py index 8f6fde64..39f7723a 100644 --- a/metadata/scripts/metadatas.py +++ b/metadata/scripts/metadatas.py @@ -4,7 +4,7 @@ import ujson as json from aiofiles import open as async_open -from httpx import AsyncClient, RemoteProtocolError, Response, URL +from httpx import URL, AsyncClient, RemoteProtocolError, Response from utils.const import AMBR_HOST, PROJECT_ROOT from utils.log import logger @@ -68,7 +68,7 @@ async def update_metadata_from_github(overwrite: bool = True): if line == " {\n": started = True continue - elif line in [" },\n", " }\n"]: + if line in [" },\n", " }\n"]: started = False if any("MATERIAL_NAMECARD" in x for x in cell): material_json_data.append(json.loads("{" + "".join(cell) + "}")) diff --git a/metadata/scripts/paimon_moe.py b/metadata/scripts/paimon_moe.py index 472ad461..db734c27 100644 --- a/metadata/scripts/paimon_moe.py +++ b/metadata/scripts/paimon_moe.py @@ -1,6 +1,7 @@ -from utils.const import PROJECT_ROOT from aiofiles import open as async_open -from httpx import AsyncClient, URL +from httpx import URL, AsyncClient + +from utils.const import PROJECT_ROOT GACHA_LOG_PAIMON_MOE_PATH = PROJECT_ROOT.joinpath("metadata/data/paimon_moe_zh.json") diff --git a/modules/apihelper/__init__.py b/modules/apihelper/__init__.py index 8b137891..e69de29b 100644 --- a/modules/apihelper/__init__.py +++ b/modules/apihelper/__init__.py @@ -1 +0,0 @@ - diff --git a/modules/apihelper/client/base/hyperionrequest.py b/modules/apihelper/client/base/hyperionrequest.py index dacb2939..1495a548 100644 --- a/modules/apihelper/client/base/hyperionrequest.py +++ b/modules/apihelper/client/base/hyperionrequest.py @@ -33,8 +33,7 @@ async def get( if return_code != 0: if message is None: raise ResponseException(message=f"response error in return code: {return_code}") - else: - raise ResponseException(response=json_data) + raise ResponseException(response=json_data) if not re_json_data and data is not None: return data return json_data @@ -61,8 +60,7 @@ async def post( if return_code != 0: if message is None: raise ResponseException(message=f"response error in return code: {return_code}") - else: - raise ResponseException(response=json_data) + raise ResponseException(response=json_data) if not re_json_data and data is not None: return data return json_data diff --git a/modules/apihelper/client/components/__init__.py b/modules/apihelper/client/components/__init__.py index 8b137891..e69de29b 100644 --- a/modules/apihelper/client/components/__init__.py +++ b/modules/apihelper/client/components/__init__.py @@ -1 +0,0 @@ - diff --git a/modules/apihelper/client/components/authclient.py b/modules/apihelper/client/components/authclient.py index 9ad54f6a..f0a938b6 100644 --- a/modules/apihelper/client/components/authclient.py +++ b/modules/apihelper/client/components/authclient.py @@ -5,8 +5,9 @@ from io import BytesIO from string import ascii_letters, digits -from typing import Dict, Union, Optional, Tuple, Any +from typing import Dict, Union, Optional from httpx import AsyncClient +from qrcode.image.pure import PyPNGImage from ...logger import logger from ...models.genshin.cookies import CookiesModel @@ -68,9 +69,8 @@ async def get_stoken_by_login_ticket(self) -> bool: for i in res_data: name = i.get("name") token = i.get("token") - if name and token: - if hasattr(self.cookies, name): - setattr(self.cookies, name, token) + if name and token and hasattr(self.cookies, name): + setattr(self.cookies, name, token) if self.cookies.stoken: if self.cookies.stuid: self.cookies.stuid = self.user_id @@ -221,7 +221,7 @@ def generate_qrcode(url: str) -> bytes: qr = qrcode.QRCode(version=1, error_correction=qrcode.constants.ERROR_CORRECT_L, box_size=10, border=4) qr.add_data(url) qr.make(fit=True) - img = qr.make_image(fill_color="black", back_color="white") + img = qr.make_image(image_factory=PyPNGImage, fill_color="black", back_color="white") bio = BytesIO() img.save(bio) return bio.getvalue() diff --git a/modules/apihelper/client/components/calendar.py b/modules/apihelper/client/components/calendar.py index 39aeccac..00c866b0 100644 --- a/modules/apihelper/client/components/calendar.py +++ b/modules/apihelper/client/components/calendar.py @@ -1,10 +1,9 @@ import re from datetime import datetime, timedelta -from typing import List, Tuple, Optional, Dict, Union +from typing import List, Tuple, Optional, Dict, Union, TYPE_CHECKING from httpx import AsyncClient -from core.base.assets import AssetsService from metadata.genshin import AVATAR_DATA from metadata.shortname import roleToId from modules.apihelper.client.components.remote import Remote @@ -12,6 +11,10 @@ from modules.wiki.character import Character +if TYPE_CHECKING: + from core.dependence.assets import AssetsService + + class Calendar: """原神活动日历""" @@ -189,7 +192,7 @@ def parse_label(self, act: FinalAct, is_act: bool, s_date: datetime, e_date: dat act.label = label @staticmethod - async def parse_type(act: FinalAct, assets: AssetsService) -> None: + async def parse_type(act: FinalAct, assets: "AssetsService") -> None: """解析活动类型""" if "神铸赋形" in act.title: act.type = ActEnum.weapon @@ -216,7 +219,7 @@ async def get_list( total_range: timedelta, time_map: Dict[str, ActTime], is_act: bool, - assets: AssetsService, + assets: "AssetsService", ) -> Optional[FinalAct]: """获取活动列表""" act = FinalAct( @@ -269,7 +272,7 @@ def end(date: datetime, up: bool = False): return ret async def get_birthday_char( - self, date_list: List[Date], assets: AssetsService + self, date_list: List[Date], assets: "AssetsService" ) -> Tuple[int, Dict[str, Dict[str, List[BirthChar]]]]: """获取生日角色""" birthday_list = await self.async_gen_birthday_list() @@ -322,7 +325,7 @@ def merge_list(self, target: List[FinalAct]) -> Tuple[List[List[FinalAct]], int, ret.append([li]) return ret, char_count, char_old - async def get_photo_data(self, assets: AssetsService) -> Dict: + async def get_photo_data(self, assets: "AssetsService") -> Dict: """获取数据""" now = self.get_now_hour() list_data, time_map = await self.req_cal_data() diff --git a/modules/apihelper/client/components/signin.py b/modules/apihelper/client/components/signin.py new file mode 100644 index 00000000..e69de29b diff --git a/modules/apihelper/error.py b/modules/apihelper/error.py index 94c38f25..42c29cfd 100644 --- a/modules/apihelper/error.py +++ b/modules/apihelper/error.py @@ -1,4 +1,4 @@ -from typing import Mapping, Any, Optional +from typing import Any, Mapping, Optional class APIHelperException(Exception): diff --git a/modules/apihelper/models/genshin/calendar.py b/modules/apihelper/models/genshin/calendar.py index f951c99d..89469307 100644 --- a/modules/apihelper/models/genshin/calendar.py +++ b/modules/apihelper/models/genshin/calendar.py @@ -8,7 +8,7 @@ class Date(BaseModel): """日历日期""" month: int - date: List[int] + date: List[int] # skipcq: PTC-W0052 week: List[str] is_today: List[bool] diff --git a/modules/apihelper/models/genshin/hyperion.py b/modules/apihelper/models/genshin/hyperion.py index a396e229..0d9cba74 100644 --- a/modules/apihelper/models/genshin/hyperion.py +++ b/modules/apihelper/models/genshin/hyperion.py @@ -1,6 +1,6 @@ -import imghdr -from typing import List, Any, Union +from typing import Any, List, Union, Optional +from PIL import Image, UnidentifiedImageError from pydantic import BaseModel, PrivateAttr __all__ = ("ArtworkImage", "PostInfo") @@ -16,8 +16,14 @@ class ArtworkImage(BaseModel): is_error: bool = False @property - def format(self) -> str: - return "" if self.is_error else (imghdr.what(None, self.data) or self.ext) + def format(self) -> Optional[str]: + if not self.is_error: + try: + with Image.open(self.data) as im: + return im.format + except UnidentifiedImageError: + pass + return None def input_media(self, *args, **kwargs) -> Union[None, InputMediaDocument, InputMediaPhoto, InputMediaVideo]: file_type = self.format @@ -34,8 +40,8 @@ class PostInfo(BaseModel): user_uid: int subject: str image_urls: List[str] - video_urls: List[str] created_at: int + video_urls: List[str] def __init__(self, _data: dict, **data: Any): super().__init__(**data) diff --git a/modules/apihelper/typedefs.py b/modules/apihelper/typedefs.py index 9fc972b1..3504cf1a 100644 --- a/modules/apihelper/typedefs.py +++ b/modules/apihelper/typedefs.py @@ -1,4 +1,4 @@ -from typing import Dict, Any +from typing import Any, Dict __all__ = ("POST_DATA", "JSON_DATA") diff --git a/modules/apihelper/utility/helpers.py b/modules/apihelper/utility/helpers.py index 4828e631..601d4766 100644 --- a/modules/apihelper/utility/helpers.py +++ b/modules/apihelper/utility/helpers.py @@ -4,7 +4,7 @@ import string import time import uuid -from typing import Mapping, Any, Optional +from typing import Any, Mapping, Optional __all__ = ("get_device_id", "hex_digest", "get_ds", "get_recognize_server", "get_ua") @@ -78,8 +78,7 @@ def get_recognize_server(uid: int) -> str: server = RECOGNIZE_SERVER.get(str(uid)[0]) if server: return server - else: - raise TypeError(f"UID {uid} isn't associated with any recognize server") + raise TypeError(f"UID {uid} isn't associated with any recognize server") def get_ua(device: str = "Paimon Build", version: str = "2.36.1"): diff --git a/modules/errorpush/__init__.py b/modules/errorpush/__init__.py index 8126606e..9da1e789 100644 --- a/modules/errorpush/__init__.py +++ b/modules/errorpush/__init__.py @@ -1,5 +1,4 @@ from .pb import PbClient, PbClientException from .sentry import SentryClient, SentryClientException - __all__ = ["PbClient", "PbClientException", "SentryClient", "SentryClientException"] diff --git a/modules/gacha/banner.py b/modules/gacha/banner.py index 2185647e..76611dc2 100644 --- a/modules/gacha/banner.py +++ b/modules/gacha/banner.py @@ -53,10 +53,9 @@ class GachaBanner(BaseModel): def get_weight(self, rarity: int, pity: int) -> int: if rarity == 4: return lerp(pity, self.weight4) - elif rarity == 5: + if rarity == 5: return lerp(pity, self.weight5) - else: - raise GachaIllegalArgument + raise GachaIllegalArgument def has_epitomized(self): return self.banner_type == BannerType.WEAPON @@ -64,17 +63,15 @@ def has_epitomized(self): def get_event_chance(self, rarity: int) -> int: if rarity == 4: return self.event_chance4 - elif rarity == 5: + if rarity == 5: return self.event_chance5 - elif self.event_chance >= -1: + if self.event_chance >= -1: return self.event_chance - else: - raise GachaIllegalArgument + raise GachaIllegalArgument def get_pool_balance_weight(self, rarity: int, pity: int) -> int: if rarity == 4: return lerp(pity, self.pool_balance_weights4) - elif rarity == 5: + if rarity == 5: return lerp(pity, self.pool_balance_weights5) - else: - raise GachaIllegalArgument + raise GachaIllegalArgument diff --git a/modules/gacha/player/banner.py b/modules/gacha/player/banner.py index 9e54a6f9..504eb8f3 100644 --- a/modules/gacha/player/banner.py +++ b/modules/gacha/player/banner.py @@ -29,10 +29,9 @@ def inc_pity_all(self): def get_failed_featured_item_pulls(self, rarity: int) -> int: if rarity == 4: return self.failed_featured4_item_pulls - elif rarity == 5: + if rarity == 5: return self.failed_featured_item_pulls - else: - raise GachaIllegalArgument + raise GachaIllegalArgument def set_failed_featured_item_pulls(self, rarity: int, amount: int): if rarity == 4: @@ -53,7 +52,7 @@ def add_failed_featured_item_pulls(self, rarity: int, amount: int): def get_pity_pool(self, rarity: int, param: int) -> int: if rarity == 4: return self.pity4_pool1 if param == 1 else self.pity4_pool2 - elif rarity == 5: + if rarity == 5: return self.pity5_pool1 if param == 1 else self.pity5_pool2 raise GachaIllegalArgument diff --git a/modules/gacha/player/info.py b/modules/gacha/player/info.py index 03962567..838355b1 100644 --- a/modules/gacha/player/info.py +++ b/modules/gacha/player/info.py @@ -2,7 +2,7 @@ from pydantic import BaseModel -from modules.gacha.banner import GachaBanner, BannerType +from modules.gacha.banner import BannerType, GachaBanner from modules.gacha.player.banner import PlayerGachaBannerInfo @@ -25,6 +25,6 @@ def __init__(self, **kwargs): def get_banner_info(self, banner: GachaBanner) -> PlayerGachaBannerInfo: if banner.banner_type == BannerType.EVENT: return self.event_character_banner - elif banner.banner_type == BannerType.WEAPON: + if banner.banner_type == BannerType.WEAPON: return self.event_weapon_banner return self.standard_banner diff --git a/modules/gacha/system.py b/modules/gacha/system.py index 24a4bb67..9b85729b 100644 --- a/modules/gacha/system.py +++ b/modules/gacha/system.py @@ -1,10 +1,9 @@ import random -from typing import Tuple, List +from typing import List, Tuple from modules.gacha.banner import GachaBanner -from modules.gacha.error import GachaInvalidTimes, GachaIllegalArgument -from modules.gacha.player.info import PlayerGachaBannerInfo -from modules.gacha.player.info import PlayerGachaInfo +from modules.gacha.error import GachaIllegalArgument, GachaInvalidTimes +from modules.gacha.player.info import PlayerGachaBannerInfo, PlayerGachaInfo from modules.gacha.pool import BannerPool @@ -58,13 +57,12 @@ def do_pull(self, banner: GachaBanner, gacha_info: PlayerGachaBannerInfo, pools: return self.do_rare_pull( pools.rate_up_items5, pools.fallback_items5_pool1, pools.fallback_items5_pool2, 5, banner, gacha_info ) - elif leval_won == 4: + if leval_won == 4: gacha_info.pity4 = 0 return self.do_rare_pull( pools.rate_up_items4, pools.fallback_items4_pool1, pools.fallback_items4_pool2, 4, banner, gacha_info ) - else: - return self.get_random(banner.fallback_items3) + return self.get_random(banner.fallback_items3) @staticmethod def draw_roulette(weights, cutoff: int) -> int: @@ -126,17 +124,15 @@ def do_fallback_rare_pull( return self.get_random( self.fallback_items5_pool2_default if rarity == 5 else self.fallback_items4_pool2_default ) - else: - return self.get_random(fallback2) - elif len(fallback2) < 1: + return self.get_random(fallback2) + if len(fallback2) < 1: return self.get_random(fallback1) + pity_pool1 = banner.get_pool_balance_weight(rarity, gacha_info.get_pity_pool(rarity, 1)) + pity_pool2 = banner.get_pool_balance_weight(rarity, gacha_info.get_pity_pool(rarity, 2)) + if pity_pool1 >= pity_pool2: + chosen_pool = 1 + self.draw_roulette((pity_pool1, pity_pool2), 10000) else: - pity_pool1 = banner.get_pool_balance_weight(rarity, gacha_info.get_pity_pool(rarity, 1)) - pity_pool2 = banner.get_pool_balance_weight(rarity, gacha_info.get_pity_pool(rarity, 2)) - if pity_pool1 >= pity_pool2: - chosen_pool = 1 + self.draw_roulette((pity_pool1, pity_pool2), 10000) - else: - chosen_pool = 2 - self.draw_roulette((pity_pool2, pity_pool1), 10000) + chosen_pool = 2 - self.draw_roulette((pity_pool2, pity_pool1), 10000) if chosen_pool == 1: gacha_info.set_pity_pool(rarity, 1, 0) return self.get_random(fallback1) diff --git a/modules/gacha/utils.py b/modules/gacha/utils.py index e9558f8c..b982cdf7 100644 --- a/modules/gacha/utils.py +++ b/modules/gacha/utils.py @@ -6,7 +6,7 @@ def lerp(x: int, x_y_array) -> int: with contextlib.suppress(KeyError, IndexError): if x <= x_y_array[0][0]: return x_y_array[0][1] - elif x >= x_y_array[-1][0]: + if x >= x_y_array[-1][0]: return x_y_array[-1][1] for index, _ in enumerate(x_y_array): if x == x_y_array[index + 1][0]: diff --git a/modules/gacha_log/log.py b/modules/gacha_log/log.py index b0206e6b..fc379271 100644 --- a/modules/gacha_log/log.py +++ b/modules/gacha_log/log.py @@ -8,23 +8,23 @@ from typing import Dict, IO, List, Optional, Tuple, Union import aiofiles -from genshin import Client, InvalidAuthkey, AuthkeyTimeout +from genshin import AuthkeyTimeout, Client, InvalidAuthkey from genshin.models import BannerType from openpyxl import load_workbook -from core.base.assets import AssetsService +from core.dependence.assets import AssetsService from metadata.pool.pool import get_pool_by_id from metadata.shortname import roleToId, weaponToId from modules.gacha_log.const import GACHA_TYPE_LIST, PAIMONMOE_VERSION from modules.gacha_log.error import ( GachaLogAccountNotFound, + GachaLogAuthkeyTimeout, GachaLogException, GachaLogFileError, GachaLogInvalidAuthkey, GachaLogMixedProvider, GachaLogNotFound, PaimonMoeGachaLogFileError, - GachaLogAuthkeyTimeout, ) from modules.gacha_log.models import ( FiveStarItem, @@ -271,15 +271,15 @@ async def get_gacha_log_data(self, user_id: int, client: Client, authkey: str) - def check_avatar_up(name: str, gacha_time: datetime.datetime) -> bool: if name in {"莫娜", "七七", "迪卢克", "琴"}: return False - elif name == "刻晴": + if name == "刻晴": start_time = datetime.datetime.strptime("2021-02-17 18:00:00", "%Y-%m-%d %H:%M:%S") end_time = datetime.datetime.strptime("2021-03-02 15:59:59", "%Y-%m-%d %H:%M:%S") - if not (start_time < gacha_time < end_time): + if not start_time < gacha_time < end_time: return False elif name == "提纳里": start_time = datetime.datetime.strptime("2022-08-24 06:00:00", "%Y-%m-%d %H:%M:%S") end_time = datetime.datetime.strptime("2022-09-09 17:59:59", "%Y-%m-%d %H:%M:%S") - if not (start_time < gacha_time < end_time): + if not start_time < gacha_time < end_time: return False return True @@ -475,14 +475,13 @@ def count_fortune(pool_name: str, summon_data, weapon: bool = False): num = j.get("num", 0) if num == 0: return pool_name - elif num <= data[0]: + if num <= data[0]: return f"{pool_name} · 欧" - elif num <= data[1]: + if num <= data[1]: return f"{pool_name} · 吉" - elif num <= data[2]: + if num <= data[2]: return f"{pool_name} · 普通" - else: - return f"{pool_name} · 非" + return f"{pool_name} · 非" return pool_name async def get_analysis(self, user_id: int, client: Client, pool: BannerType, assets: AssetsService): diff --git a/modules/gacha_log/models.py b/modules/gacha_log/models.py index 8b1085fc..805084fc 100644 --- a/modules/gacha_log/models.py +++ b/modules/gacha_log/models.py @@ -1,10 +1,10 @@ import datetime from enum import Enum -from typing import List, Dict, Union, Any +from typing import Any, Dict, List, Union from pydantic import BaseModel, validator -from metadata.shortname import roleToId, weaponToId, not_real_roles +from metadata.shortname import not_real_roles, roleToId, weaponToId from modules.gacha_log.const import UIGF_VERSION diff --git a/modules/wiki/base.py b/modules/wiki/base.py index 423d789d..3b922388 100644 --- a/modules/wiki/base.py +++ b/modules/wiki/base.py @@ -7,15 +7,18 @@ from typing import AsyncIterator, ClassVar, List, Optional, Tuple, Union import anyio -import ujson as json from bs4 import BeautifulSoup -from httpx import AsyncClient, HTTPError, Response, URL -from pydantic import ( - BaseConfig as PydanticBaseConfig, - BaseModel as PydanticBaseModel, -) +from httpx import URL, AsyncClient, HTTPError, Response +from pydantic import BaseConfig as PydanticBaseConfig +from pydantic import BaseModel as PydanticBaseModel from typing_extensions import Self +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib + + __all__ = ["Model", "WikiModel", "HONEY_HOST"] HONEY_HOST = URL("https://genshin.honeyhunterworld.com/") @@ -27,12 +30,12 @@ class Model(PydanticBaseModel): def __new__(cls, *args, **kwargs): # 让每次new的时候都解析 cls.update_forward_refs() - return super(Model, cls).__new__(cls) + return super(Model, cls).__new__(cls) # pylint: disable=E1120 class Config(PydanticBaseConfig): # 使用 ujson 作为解析库 - json_dumps = json.dumps - json_loads = json.loads + json_dumps = jsonlib.dumps + json_loads = jsonlib.loads class WikiModel(Model): @@ -203,7 +206,7 @@ async def task(page: URL): response = await cls._client_get(page) # 从页面中获取对应的 chaos data (未处理的json格式字符串) chaos_data = re.findall(r"sortable_data\.push\((.*?)\);\s*sortable_cur_page", response.text)[0] - json_data = json.loads(chaos_data) # 转为 json + json_data = jsonlib.loads(chaos_data) # 转为 json for data in json_data: # 遍历 json data_name = re.findall(r">(.*)<", data[1])[0] # 获取 Model 的名称 if with_url: # 如果需要返回对应的 url diff --git a/modules/wiki/character.py b/modules/wiki/character.py index e38dc008..6820c77e 100644 --- a/modules/wiki/character.py +++ b/modules/wiki/character.py @@ -4,8 +4,7 @@ from bs4 import BeautifulSoup from httpx import URL -from modules.wiki.base import Model, HONEY_HOST -from modules.wiki.base import WikiModel +from modules.wiki.base import HONEY_HOST, Model, WikiModel from modules.wiki.other import Association, Element, WeaponType diff --git a/modules/wiki/material.py b/modules/wiki/material.py index cfdbdcd8..ce9b1932 100644 --- a/modules/wiki/material.py +++ b/modules/wiki/material.py @@ -48,6 +48,7 @@ def get_table_row(target: str): for row in table_rows: if target in row.find("td").text: return row.find_all("td")[-1] + return None def get_table_text(row_num: int) -> str: """一个便捷函数,用于返回表格对应行的最后一个单元格中的文本""" diff --git a/modules/wiki/other.py b/modules/wiki/other.py index 4677f048..5fa76bf0 100644 --- a/modules/wiki/other.py +++ b/modules/wiki/other.py @@ -102,6 +102,7 @@ def convert(cls, string: str) -> Optional[Self]: for k, v in _ATTR_TYPE_MAP.items(): if string == k or string in v or string.upper() == k: return cls[k] + return None _ASSOCIATION_MAP = { diff --git a/modules/wiki/weapon.py b/modules/wiki/weapon.py index 65d6feac..b48e5b65 100644 --- a/modules/wiki/weapon.py +++ b/modules/wiki/weapon.py @@ -5,7 +5,7 @@ from bs4 import BeautifulSoup from httpx import URL -from modules.wiki.base import Model, HONEY_HOST, WikiModel +from modules.wiki.base import HONEY_HOST, Model, WikiModel from modules.wiki.other import AttributeType, WeaponType __all__ = ["Weapon", "WeaponAffix", "WeaponAttribute"] @@ -135,8 +135,7 @@ async def get_name_list(cls, *, with_url: bool = False) -> List[Union[str, Tuple name_list = [i async for i in cls._name_list_generator(with_url=with_url)] if with_url: return [(i[0], list(i[1])[0][1]) for i in itertools.groupby(name_list, lambda x: x[0])] - else: - return [i[0] for i in itertools.groupby(name_list, lambda x: x)] + return [i[0] for i in itertools.groupby(name_list, lambda x: x)] @property def icon(self) -> WeaponIcon: diff --git a/plugins/__init__.py b/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/account/account.py b/plugins/account/account.py new file mode 100644 index 00000000..3c6af508 --- /dev/null +++ b/plugins/account/account.py @@ -0,0 +1,190 @@ +from datetime import datetime +from typing import Optional + +import genshin +from genshin import DataNotPublic, GenshinException, types +from genshin.models import RecordCard +from telegram import ReplyKeyboardMarkup, ReplyKeyboardRemove, TelegramObject, Update +from telegram.ext import CallbackContext, ConversationHandler, filters +from telegram.helpers import escape_markdown + +from core.basemodel import RegionEnum +from core.plugin import Plugin, conversation, handler +from core.services.cookies.error import TooManyRequestPublicCookies +from core.services.cookies.services import CookiesService, PublicCookiesService +from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel +from core.services.players.services import PlayersService +from utils.log import logger + +__all__ = ("BindAccountPlugin",) + + +class BindAccountPluginData(TelegramObject): + player: Optional[Player] = None + record_card: Optional[RecordCard] = None + region: RegionEnum = RegionEnum.HYPERION + account_id: int = 0 + # player_id: int = 0 + + def reset(self): + self.player = None + self.region = RegionEnum.NULL + self.account_id = 0 + self.record_card = None + + +CHECK_SERVER, CHECK_UID, COMMAND_RESULT = range(10100, 10103) + + +class BindAccountPlugin(Plugin.Conversation): + """UID用户绑定""" + + def __init__( + self, + players_service: PlayersService = None, + cookies_service: CookiesService = None, + public_cookies_service: PublicCookiesService = None, + ): + self.public_cookies_service = public_cookies_service + self.cookies_service = cookies_service + self.players_service = players_service + + @conversation.entry_point + @handler.command(command="setuid", filters=filters.ChatType.PRIVATE, block=True) + async def command_start(self, update: Update, context: CallbackContext) -> int: + user = update.effective_user + message = update.effective_message + logger.info("用户 %s[%s] 绑定账号命令请求", user.full_name, user.id) + bind_account_plugin_data: BindAccountPluginData = context.chat_data.get("bind_account_plugin_data") + if bind_account_plugin_data is None: + bind_account_plugin_data = BindAccountPluginData() + context.chat_data["bind_account_plugin_data"] = bind_account_plugin_data + else: + bind_account_plugin_data.reset() + text = ( + f"你好 {user.mention_markdown_v2()} " + f'{escape_markdown("!请输入通行证ID(非游戏玩家ID),BOT将会通过通行证UID查找游戏UID。请选择要绑定的服务器!或回复退出取消操作")}' + ) + reply_keyboard = [["米游社", "HoYoLab"], ["退出"]] + await message.reply_markdown_v2(text, reply_markup=ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)) + return CHECK_SERVER + + @conversation.state(state=CHECK_SERVER) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + async def check_server(self, update: Update, context: CallbackContext) -> int: + message = update.effective_message + bind_account_plugin_data: BindAccountPluginData = context.chat_data.get("bind_account_plugin_data") + if message.text == "退出": + await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + if message.text == "米游社": + bind_account_plugin_data.region = RegionEnum.HYPERION + elif message.text == "HoYoLab": + bind_account_plugin_data.region = RegionEnum.HOYOLAB + else: + await message.reply_text("选择错误,请重新选择") + return CHECK_SERVER + await message.reply_text("请输入你的通行证ID(非玩家ID)", reply_markup=ReplyKeyboardRemove()) + return CHECK_UID + + @conversation.state(state=CHECK_UID) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + async def check_cookies(self, update: Update, context: CallbackContext) -> int: + user = update.effective_user + message = update.effective_message + bind_account_plugin_data: BindAccountPluginData = context.chat_data.get("bind_account_plugin_data") + region = bind_account_plugin_data.region + if message.text == "退出": + await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + try: + account_id = int(message.text) + except ValueError: + await message.reply_text("ID 格式有误,请检查", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + try: + cookies = await self.public_cookies_service.get_cookies(user.id, region) + except TooManyRequestPublicCookies: + await message.reply_text("用户查询次数过多,请稍后重试", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + if region == RegionEnum.HYPERION: + client = genshin.Client(cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.CHINESE) + elif region == RegionEnum.HOYOLAB: + client = genshin.Client( + cookies=cookies.data, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" + ) + else: + return ConversationHandler.END + try: + record_card = await client.get_record_card(account_id) + except DataNotPublic: + await message.reply_text("角色未公开", reply_markup=ReplyKeyboardRemove()) + logger.warning("获取账号信息发生错误 %s 账户信息未公开", account_id) + return ConversationHandler.END + except GenshinException as exc: + await message.reply_text("获取账号信息发生错误", reply_markup=ReplyKeyboardRemove()) + logger.error("获取账号信息发生错误") + logger.exception(exc) + return ConversationHandler.END + if record_card.game != types.Game.GENSHIN: + await message.reply_text("角色信息查询返回非原神游戏信息," "请设置展示主界面为原神", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + player_info = await self.players_service.get( + user.id, player_id=record_card.uid, region=bind_account_plugin_data.region + ) + if player_info: + await message.reply_text("你已经绑定该账号") + return ConversationHandler.END + bind_account_plugin_data.account_id = account_id + reply_keyboard = [["确认", "退出"]] + await message.reply_text("获取角色基础信息成功,请检查是否正确!") + logger.info("用户 %s[%s] 获取账号 %s[%s] 信息成功", user.full_name, user.id, record_card.nickname, record_card.uid) + text = ( + f"*角色信息*\n" + f"角色名称:{escape_markdown(record_card.nickname, version=2)}\n" + f"角色等级:{record_card.level}\n" + f"UID:`{record_card.uid}`\n" + f"服务器名称:`{record_card.server_name}`\n" + ) + bind_account_plugin_data.record_card = record_card + await message.reply_markdown_v2(text, reply_markup=ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)) + return COMMAND_RESULT + + @conversation.state(state=COMMAND_RESULT) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + async def command_result(self, update: Update, context: CallbackContext) -> int: + user = update.effective_user + message = update.effective_message + bind_account_plugin_data: BindAccountPluginData = context.chat_data.get("bind_account_plugin_data") + if message.text == "退出": + await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + if message.text == "确认": + record_card = bind_account_plugin_data.record_card + is_chosen = True + player_info = await self.players_service.get_player(user.id) # 寻找主账号 + if player_info.is_chosen: + is_chosen = False + player = Player( + user_id=user.id, + account_id=bind_account_plugin_data.account_id, + player_id=record_card.uid, + region=bind_account_plugin_data.region, + is_chosen=is_chosen, # todo 多账号 + ) + await self.players_service.add(player) + player_info = await self.player_info_service.get(player) + if player_info is None: + player_info = PlayerInfoSQLModel( + user_id=player.user_id, + player_id=player.player_id, + nickname=record_card.nickname, + create_time=datetime.now(), + is_update=True, + ) # 不添加更新时间 + await self.player_info_service.add(player_info) + logger.success("用户 %s[%s] 绑定UID账号成功", user.full_name, user.id) + await message.reply_text("保存成功", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END + await message.reply_text("回复错误,请重新输入") + return COMMAND_RESULT diff --git a/plugins/genshin/cookies.py b/plugins/account/cookies.py similarity index 61% rename from plugins/genshin/cookies.py rename to plugins/account/cookies.py index 5eafd3c0..0141dc1f 100644 --- a/plugins/genshin/cookies.py +++ b/plugins/account/cookies.py @@ -1,4 +1,5 @@ import contextlib +from datetime import datetime from typing import Dict, Optional import genshin @@ -10,38 +11,56 @@ from telegram.ext import CallbackContext, ConversationHandler, filters from telegram.helpers import escape_markdown -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService +from core.basemodel import RegionEnum from core.plugin import Plugin, conversation, handler -from core.user.error import UserNotFoundError -from core.user.models import User -from core.user.services import UserService +from core.services.cookies.models import CookiesDataBase as Cookies, CookiesStatusEnum +from core.services.cookies.services import CookiesService +from core.services.players.models import PlayersDataBase as Player, PlayerInfoSQLModel +from core.services.players.services import PlayersService, PlayerInfoService from modules.apihelper.client.components.authclient import AuthClient from modules.apihelper.models.genshin.cookies import CookiesModel -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger -from utils.models.base import RegionEnum +__all__ = ("AccountCookiesPlugin",) -class AddUserCommandData(TelegramObject): - user: Optional[User] = None + +class AccountIdNotFound(Exception): + pass + + +class AccountCookiesPluginData(TelegramObject): + player: Optional[Player] = None + cookies_data_base: Optional[Cookies] = None region: RegionEnum = RegionEnum.NULL cookies: dict = {} - game_uid: int = 0 - phone: int = 0 + account_id: int = 0 + # player_id: int = 0 + genshin_account: Optional[GenshinAccount] = None + + def reset(self): + self.player = None + self.cookies_data_base = None + self.region = RegionEnum.NULL + self.cookies = {} + self.account_id = 0 + self.genshin_account = None CHECK_SERVER, INPUT_COOKIES, COMMAND_RESULT = range(10100, 10103) -class SetUserCookies(Plugin.Conversation, BasePlugin.Conversation): +class AccountCookiesPlugin(Plugin.Conversation): """Cookie绑定""" - def __init__(self, user_service: UserService = None, cookies_service: CookiesService = None): + def __init__( + self, + players_service: PlayersService = None, + cookies_service: CookiesService = None, + player_info_service: PlayerInfoService = None, + ): self.cookies_service = cookies_service - self.user_service = user_service + self.players_service = players_service + self.player_info_service = player_info_service # noinspection SpellCheckingInspection @staticmethod @@ -59,14 +78,16 @@ def parse_cookie(cookie: Dict[str, str]) -> Dict[str, str]: @conversation.entry_point @handler.command(command="setcookie", filters=filters.ChatType.PRIVATE, block=True) @handler.command(command="setcookies", filters=filters.ChatType.PRIVATE, block=True) - @restricts() - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> int: user = update.effective_user message = update.effective_message logger.info("用户 %s[%s] 绑定账号命令请求", user.full_name, user.id) - cookies_command_data = AddUserCommandData() - context.chat_data["add_user_command_data"] = cookies_command_data + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") + if account_cookies_plugin_data is None: + account_cookies_plugin_data = AccountCookiesPluginData() + context.chat_data["account_cookies_plugin_data"] = account_cookies_plugin_data + else: + account_cookies_plugin_data.reset() text = f'你好 {user.mention_markdown_v2()} {escape_markdown("!请选择要绑定的服务器!或回复退出取消操作")}' reply_keyboard = [["米游社", "HoYoLab"], ["退出"]] @@ -74,50 +95,38 @@ async def command_start(self, update: Update, context: CallbackContext) -> int: return CHECK_SERVER @conversation.entry_point - @handler.command("qlogin", filters=filters.ChatType.PRIVATE, block=False) - @error_callable + @handler.command("qlogin", filters=filters.ChatType.PRIVATE, block=True) async def qrcode_login(self, update: Update, context: CallbackContext): user = update.effective_user message = update.effective_message logger.info("用户 %s[%s] 绑定账号命令请求", user.full_name, user.id) - add_user_command_data = AddUserCommandData() - context.chat_data["add_user_command_data"] = add_user_command_data - add_user_command_data.region = RegionEnum.HYPERION - try: - user_info = await self.user_service.get_user_by_id(user.id) - except UserNotFoundError: - user_info = None - if user_info is not None: - try: - await self.cookies_service.get_cookies(user.id, RegionEnum.HYPERION) - except CookiesNotFoundError: - await message.reply_text("你已经绑定UID,如果继续操作会覆盖当前UID。") - else: - await message.reply_text("警告,你已经绑定Cookie,如果继续操作会覆盖当前Cookie。") - add_user_command_data.user = user_info + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") + if account_cookies_plugin_data is None: + account_cookies_plugin_data = AccountCookiesPluginData() + context.chat_data["account_cookies_plugin_data"] = account_cookies_plugin_data + else: + account_cookies_plugin_data.reset() + account_cookies_plugin_data.region = RegionEnum.HYPERION auth_client = AuthClient() url, ticket = await auth_client.create_qrcode_login() data = auth_client.generate_qrcode(url) text = f"你好 {user.mention_html()} !该绑定方法仅支持国服,请在3分钟内使用米游社扫码并确认进行绑定。" await message.reply_photo(data, caption=text, parse_mode=ParseMode.HTML) if await auth_client.check_qrcode_login(ticket): - add_user_command_data.cookies = auth_client.cookies.to_dict() + account_cookies_plugin_data.cookies = auth_client.cookies.to_dict() return await self.check_cookies(update, context) - else: - await message.reply_markdown_v2("可能是验证码已过期或者你没有同意授权,请重新发送命令进行绑定。") - return ConversationHandler.END + await message.reply_markdown_v2("可能是验证码已过期或者你没有同意授权,请重新发送命令进行绑定。") + return ConversationHandler.END @conversation.state(state=CHECK_SERVER) @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable async def check_server(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user message = update.effective_message - add_user_command_data: AddUserCommandData = context.chat_data.get("add_user_command_data") + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") if message.text == "退出": await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - elif message.text == "米游社": + if message.text == "米游社": region = RegionEnum.HYPERION bbs_url = "https://user.mihoyo.com/" bbs_name = "米游社" @@ -128,19 +137,7 @@ async def check_server(self, update: Update, context: CallbackContext) -> int: else: await message.reply_text("选择错误,请重新选择") return CHECK_SERVER - try: - user_info = await self.user_service.get_user_by_id(user.id) - except UserNotFoundError: - user_info = None - if user_info is not None: - try: - await self.cookies_service.get_cookies(user.id, region) - except CookiesNotFoundError: - await message.reply_text("你已经绑定UID,如果继续操作会覆盖当前UID。") - else: - await message.reply_text("警告,你已经绑定Cookie,如果继续操作会覆盖当前Cookie。") - add_user_command_data.user = user_info - add_user_command_data.region = region + account_cookies_plugin_data.region = region await message.reply_text(f"请输入{bbs_name}的Cookies!或回复退出取消操作", reply_markup=ReplyKeyboardRemove()) if bbs_name == "米游社": help_message = ( @@ -185,11 +182,10 @@ async def check_server(self, update: Update, context: CallbackContext) -> int: @conversation.state(state=INPUT_COOKIES) @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable async def input_cookies(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user - add_user_command_data: AddUserCommandData = context.chat_data.get("add_user_command_data") + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") if message.text == "退出": await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END @@ -212,18 +208,17 @@ async def input_cookies(self, update: Update, context: CallbackContext) -> int: logger.info("用户 %s[%s] Cookies格式有误", user.full_name, user.id) await message.reply_text("Cookies格式有误,请检查", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - add_user_command_data.cookies = cookies + account_cookies_plugin_data.cookies = cookies return await self.check_cookies(update, context) - @staticmethod - async def check_cookies(update: Update, context: CallbackContext) -> int: + async def check_cookies(self, update: Update, context: CallbackContext) -> int: user = update.effective_user message = update.effective_message - add_user_command_data: AddUserCommandData = context.chat_data.get("add_user_command_data") - cookies = CookiesModel(**add_user_command_data.cookies) - if add_user_command_data.region == RegionEnum.HYPERION: + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") + cookies = CookiesModel(**account_cookies_plugin_data.cookies) + if account_cookies_plugin_data.region == RegionEnum.HYPERION: client = genshin.Client(cookies=cookies.to_dict(), region=types.Region.CHINESE) - elif add_user_command_data.region == RegionEnum.HOYOLAB: + elif account_cookies_plugin_data.region == RegionEnum.HOYOLAB: client = genshin.Client(cookies=cookies.to_dict(), region=types.Region.OVERSEAS) else: logger.error("用户 %s[%s] region 异常", user.full_name, user.id) @@ -232,16 +227,15 @@ async def check_cookies(update: Update, context: CallbackContext) -> int: if not cookies.check(): await message.reply_text("检测到Cookie不完整,可能会出现问题。", reply_markup=ReplyKeyboardRemove()) try: - if client.cookie_manager.user_id is None: - if cookies.is_v2: - logger.info("检测到用户 %s[%s] 使用 V2 Cookie 正在尝试获取 account_id", user.full_name, user.id) - if client.region == types.Region.CHINESE: - account_info = await client.get_hoyolab_user() - account_id = account_info.hoyolab_id - cookies.set_v2_uid(account_id) - logger.success("获取用户 %s[%s] account_id[%s] 成功", user.full_name, user.id, account_id) - else: - logger.warning("用户 %s[%s] region[%s] 也许是不正确的", user.full_name, user.id, client.region.name) + if client.cookie_manager.user_id is None and cookies.is_v2: + logger.info("检测到用户 %s[%s] 使用 V2 Cookie 正在尝试获取 account_id", user.full_name, user.id) + if client.region == types.Region.CHINESE: + account_info = await client.get_hoyolab_user() + account_id = account_info.hoyolab_id + cookies.set_v2_uid(account_id) + logger.success("获取用户 %s[%s] account_id[%s] 成功", user.full_name, user.id, account_id) + else: + logger.warning("用户 %s[%s] region[%s] 也许是不正确的", user.full_name, user.id, client.region.name) genshin_accounts = await client.genshin_accounts() except DataNotPublic: logger.info("用户 %s[%s] 账号疑似被注销", user.full_name, user.id) @@ -259,6 +253,10 @@ async def check_cookies(update: Update, context: CallbackContext) -> int: f"获取账号信息发生错误,错误信息为 {exc.original},请检查Cookie或者账号是否正常", reply_markup=ReplyKeyboardRemove() ) return ConversationHandler.END + except AccountIdNotFound: + logger.info("用户 %s[%s] 无法获取账号ID", user.full_name, user.id) + await message.reply_text("无法获取账号ID,请检查Cookie是否正常", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END except (AttributeError, ValueError) as exc: logger.warning("用户 %s[%s] Cookies错误", user.full_name, user.id) logger.debug("用户 %s[%s] Cookies错误", user.full_name, user.id, exc_info=exc) @@ -274,74 +272,102 @@ async def check_cookies(update: Update, context: CallbackContext) -> int: if await auth_client.get_ltoken_by_stoken(): logger.success("用户 %s[%s] 绑定时获取 ltoken 成功", user.full_name, user.id) auth_client.cookies.remove_v2() - user_info: Optional[GenshinAccount] = None + genshin_account: Optional[GenshinAccount] = None level: int = 0 # todo : 多账号绑定 - for genshin_account in genshin_accounts: - if genshin_account.level >= level: # 获取账号等级最高的 - level = genshin_account.level - user_info = genshin_account - if user_info is None: + for temp in genshin_accounts: + if temp.level >= level: # 获取账号等级最高的 + level = temp.level + genshin_account = temp + if genshin_account is None: await message.reply_text("未找到原神账号,请确认账号信息无误。") return ConversationHandler.END - add_user_command_data.game_uid = user_info.uid + account_cookies_plugin_data.genshin_account = genshin_account + player_info = await self.players_service.get( + user.id, account_id=genshin_account.uid, region=account_cookies_plugin_data.region + ) + account_cookies_plugin_data.player = player_info + if player_info: + cookies_database = await self.cookies_service.get( + user.id, player_info.account_id, account_cookies_plugin_data.region + ) + if cookies_database: + account_cookies_plugin_data.cookies_data_base = cookies_database + await message.reply_text("警告,你已经绑定Cookie,如果继续操作会覆盖当前Cookie。") reply_keyboard = [["确认", "退出"]] await message.reply_text("获取角色基础信息成功,请检查是否正确!") - logger.info("用户 %s[%s] 获取账号 %s[%s] 信息成功", user.full_name, user.id, user_info.nickname, user_info.uid) + logger.info( + "用户 %s[%s] 获取账号 %s[%s] 信息成功", user.full_name, user.id, genshin_account.nickname, genshin_account.uid + ) text = ( f"*角色信息*\n" - f"角色名称:{escape_markdown(user_info.nickname, version=2)}\n" - f"角色等级:{user_info.level}\n" - f"UID:`{user_info.uid}`\n" - f"服务器名称:`{user_info.server_name}`\n" + f"角色名称:{escape_markdown(genshin_account.nickname, version=2)}\n" + f"角色等级:{genshin_account.level}\n" + f"UID:`{genshin_account.uid}`\n" + f"服务器名称:`{genshin_account.server_name}`\n" ) await message.reply_markdown_v2(text, reply_markup=ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)) - add_user_command_data.cookies = cookies.to_dict() + account_cookies_plugin_data.cookies = cookies.to_dict() return COMMAND_RESULT @conversation.state(state=COMMAND_RESULT) @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable async def command_result(self, update: Update, context: CallbackContext) -> int: user = update.effective_user message = update.effective_message - add_user_command_data: AddUserCommandData = context.chat_data.get("add_user_command_data") + account_cookies_plugin_data: AccountCookiesPluginData = context.chat_data.get("account_cookies_plugin_data") if message.text == "退出": await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - elif message.text == "确认": - if add_user_command_data.user is None: - if add_user_command_data.region == RegionEnum.HYPERION: - user_db = User( + if message.text == "确认": + player = account_cookies_plugin_data.player + genshin_account = account_cookies_plugin_data.genshin_account + if player: + await self.players_service.update(player) + cookies = account_cookies_plugin_data.cookies_data_base + if cookies: + cookies.data = account_cookies_plugin_data.cookies + cookies.status = CookiesStatusEnum.STATUS_SUCCESS + await self.cookies_service.update(cookies) + else: + cookies = Cookies( user_id=user.id, - yuanshen_uid=add_user_command_data.game_uid, - region=add_user_command_data.region, - ) - elif add_user_command_data.region == RegionEnum.HOYOLAB: - user_db = User( - user_id=user.id, genshin_uid=add_user_command_data.game_uid, region=add_user_command_data.region + account_id=account_cookies_plugin_data.account_id, + data=account_cookies_plugin_data.cookies, + region=account_cookies_plugin_data.region, + is_share=True, # todo 用户可以自行选择是否将Cookies加入公共池 ) - else: - await message.reply_text("数据错误") - return ConversationHandler.END - await self.user_service.add_user(user_db) + await self.cookies_service.add(cookies) + logger.success("用户 %s[%s] 更新Cookies", user.full_name, user.id) else: - user_db = add_user_command_data.user - user_db.region = add_user_command_data.region - if add_user_command_data.region == RegionEnum.HYPERION: - user_db.yuanshen_uid = add_user_command_data.game_uid - elif add_user_command_data.region == RegionEnum.HOYOLAB: - user_db.genshin_uid = add_user_command_data.game_uid - else: - await message.reply_text("数据错误") - return ConversationHandler.END - await self.user_service.update_user(user_db) - await self.cookies_service.add_or_update_cookies( - user.id, add_user_command_data.cookies, add_user_command_data.region - ) - logger.info("用户 %s[%s] 绑定账号成功", user.full_name, user.id) + player = Player( + user_id=user.id, + account_id=account_cookies_plugin_data.account_id, + player_id=genshin_account.uid, + region=account_cookies_plugin_data.region, + is_chosen=True, # todo 多账号 + ) + player_info = await self.player_info_service.get(player) + if player_info is None: + player_info = PlayerInfoSQLModel( + user_id=player.user_id, + player_id=player.player_id, + nickname=genshin_account.nickname, + create_time=datetime.now(), + is_update=True, + ) # 不添加更新时间 + await self.player_info_service.add(player_info) + await self.players_service.add(player) + cookies = Cookies( + user_id=user.id, + account_id=account_cookies_plugin_data.account_id, + data=account_cookies_plugin_data.cookies, + region=account_cookies_plugin_data.region, + is_share=True, # todo 用户可以自行选择是否将Cookies加入公共池 + ) + await self.cookies_service.add(cookies) + logger.info("用户 %s[%s] 绑定账号成功", user.full_name, user.id) await message.reply_text("保存成功", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - else: - await message.reply_text("回复错误,请重新输入") - return COMMAND_RESULT + await message.reply_text("回复错误,请重新输入") + return COMMAND_RESULT diff --git a/plugins/account/players.py b/plugins/account/players.py new file mode 100644 index 00000000..f46c8ffc --- /dev/null +++ b/plugins/account/players.py @@ -0,0 +1,243 @@ +from typing import Tuple + +from telegram import Update, InlineKeyboardMarkup, InlineKeyboardButton +from telegram.ext import filters, ContextTypes + +from core.plugin import Plugin, handler +from core.services.cookies import CookiesService +from core.services.players import PlayersService +from core.services.players.services import PlayerInfoService +from utils.log import logger + +__all__ = ("PlayersManagesPlugin",) + + +class PlayersManagesPlugin(Plugin): + def __init__(self, players: PlayersService, cookies: CookiesService, player_info_service: PlayerInfoService): + self.cookies_service = cookies + self.players_service = players + self.player_info_service = player_info_service + + @staticmethod + def players_manager_callback(callback_query_data: str) -> Tuple[str, int, int]: + _data = callback_query_data.split("|") + _handle = _data[-3] + _user_id = int(_data[-2]) + _player_id = int(_data[-1]) + logger.debug("players_manager_callback函数返回 handle[%s] user_id[%s] player_id[%s]", _handle, _user_id, _player_id) + return _handle, _user_id, _player_id + + @handler.command(command="player", filters=filters.ChatType.PRIVATE, block=False) + @handler.command(command="players", filters=filters.ChatType.PRIVATE, block=False) + @handler.callback_query(r"^players_manager\|list", block=False) + async def command_start(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: + callback_query = update.callback_query + user = update.effective_user + message = update.effective_message + players = await self.players_service.get_all_by_user_id(user.id) + if len(players) == 0: + if callback_query: + await callback_query.edit_message_text("未查询到您所绑定的账号信息,请先绑定账号") + else: + await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号") + return + buttons = [] + for player in players: + player_info = await self.player_info_service.get(player) + text = f"{player.player_id} {player_info.nickname}" + buttons.append( + [ + InlineKeyboardButton( + text, + callback_data=f"players_manager|get|{user.id}|{player.player_id}", + ) + ] + ) + if callback_query: + await callback_query.edit_message_text("从下面的列表中选择一个玩家", reply_markup=InlineKeyboardMarkup(buttons)) + else: + await message.reply_text("从下面的列表中选择一个玩家", reply_markup=InlineKeyboardMarkup(buttons)) + + @handler.callback_query(r"^players_manager\|get\|", block=False) + async def get_player(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: + callback_query = update.callback_query + user = callback_query.from_user + + _, user_id, player_id = self.players_manager_callback(callback_query.data) + if user.id != user_id: + if callback_query.message: + await callback_query.message.delete() + return + + player = await self.players_service.get(user.id, player_id=player_id) + if player is None: + await callback_query.edit_message_text(f"账号 {player_id} 未找到") + return + + player_info = await self.player_info_service.get(player) + if player_info is None: + await callback_query.edit_message_text(f"账号 {player_id} 信息未找到") + return + + buttons = [ + [ + InlineKeyboardButton( + "设置为主账号", + callback_data=f"players_manager|main|{user.id}|{player.player_id}", + ), + InlineKeyboardButton( + "删除账号", + callback_data=f"players_manager|del|{user.id}|{player.player_id}", + ), + InlineKeyboardButton( + "更新账号信息", + callback_data=f"players_manager|update|{user.id}|{player.player_id}", + ), + ], + [ + InlineKeyboardButton( + "« 返回玩家列表", + callback_data="players_manager|list", + ) + ], + ] + + await callback_query.edit_message_text( + f"这里是 {player.player_id} {player_info.nickname}\n你想用这个账号做什么?", reply_markup=InlineKeyboardMarkup(buttons) + ) + + @handler.callback_query(r"^players_manager\|update\|", block=False) + async def update_user(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: + callback_query = update.callback_query + user = callback_query.from_user + + _, user_id, player_id = self.players_manager_callback(callback_query.data) + if user.id != user_id: + if callback_query.message: + await callback_query.message.delete() + return + + player = await self.players_service.get(user.id, player_id=player_id) + if player is None: + await callback_query.edit_message_text(f"账号 {player_id} 未找到") + return + + status = await self.player_info_service.update_from_enka(player) + + buttons = [ + [ + InlineKeyboardButton( + "« 返回", + callback_data=f"players_manager|get|{user.id}|{player.player_id}", + ) + ], + ] + + if status: + await callback_query.edit_message_text( + f"更新玩家信息 {player.player_id} 成功", reply_markup=InlineKeyboardMarkup(buttons) + ) + else: + await callback_query.edit_message_text( + f"更新玩家信息 {player.player_id} 更新失败 请稍后重试", reply_markup=InlineKeyboardMarkup(buttons) + ) + + @handler.callback_query(r"^players_manager\|main\|", block=False) + async def set_main(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: + callback_query = update.callback_query + user = callback_query.from_user + + _, user_id, player_id = self.players_manager_callback(callback_query.data) + if user.id != user_id: + if callback_query.message: + await callback_query.message.delete() + return + + player = await self.players_service.get(user.id, player_id=player_id) + if player is None: + await callback_query.edit_message_text(f"账号 {player_id} 未找到") + return + + player_info = await self.player_info_service.get(player) + if player_info is None: + await callback_query.edit_message_text(f"账号 {player_id} 信息未找到") + return + + main_player = await self.players_service.get(user.id, is_chosen=True) + if main_player and player.id != main_player.id: + main_player.is_chosen = False + await self.players_service.update(main_player) + + player.is_chosen = True + await self.players_service.update(player) + + buttons = [ + [ + InlineKeyboardButton( + "« 返回", + callback_data=f"players_manager|get|{user.id}|{player.player_id}", + ) + ], + ] + + await callback_query.edit_message_text( + f"成功设置 {player.player_id} {player_info.nickname} 为主账号", reply_markup=InlineKeyboardMarkup(buttons) + ) + + @handler.callback_query(r"^players_manager\|del\|", block=False) + async def delete(self, update: Update, _: ContextTypes.DEFAULT_TYPE) -> None: + callback_query = update.callback_query + user = callback_query.from_user + + _handle, user_id, player_id = self.players_manager_callback(callback_query.data) + if user.id != user_id: + if callback_query.message: + await callback_query.message.delete() + return + + player = await self.players_service.get(user.id, player_id=player_id) + if player is None: + await callback_query.edit_message_text(f"账号 {player_id} 未找到") + return + + player_info = await self.player_info_service.get(player) + if player_info is None: + await callback_query.edit_message_text(f"账号 {player_id} 信息未找到") + return + + if _handle == "true": + buttons = [ + [ + InlineKeyboardButton( + "« 返回玩家列表", + callback_data="players_manager|list", + ) + ], + ] + await self.players_service.delete(player) + cookies = await self.cookies_service.get(player.user_id, player.account_id, player.region) + if cookies: + await self.cookies_service.delete(cookies) + await self.player_info_service.delete(player_info) + await callback_query.edit_message_text( + f"成功删除 {player.player_id} ", reply_markup=InlineKeyboardMarkup(buttons) + ) + elif _handle == "del": + buttons = [ + [ + InlineKeyboardButton( + "是的我非常确定", + callback_data=f"players_manager|del|true|{user.id}|{player.player_id}", + ) + ], + [ + InlineKeyboardButton( + "取消操作", + callback_data=f"players_manager|get|{user.id}|{player.player_id}", + ) + ], + ] + await callback_query.edit_message_text("请问你真的要从Bot中删除改账号吗?", reply_markup=InlineKeyboardMarkup(buttons)) + else: + if callback_query.message: + await callback_query.message.delete() diff --git a/plugins/admin/admin.py b/plugins/admin/admin.py new file mode 100644 index 00000000..a50f80dd --- /dev/null +++ b/plugins/admin/admin.py @@ -0,0 +1,51 @@ +from telegram import Update +from telegram.ext import ContextTypes + +from core.plugin import Plugin, handler +from core.services.users.services import UserAdminService +from utils.log import logger + + +class AdminPlugin(Plugin): + """有关BOT ADMIN处理""" + + def __init__(self, user_admin_service: UserAdminService = None): + self.user_admin_service = user_admin_service + + @handler.command("add_admin", block=False, admin=True) + async def add_admin(self, update: Update, _: ContextTypes.DEFAULT_TYPE): + message = update.effective_message + reply_to_message = message.reply_to_message + user = update.effective_user + logger.info("用户 %s[%s] add_admin 命令请求", user.full_name, user.id) + if reply_to_message: + from_user = reply_to_message.from_user + if from_user: + if await self.user_admin_service.add_admin(from_user.id): + logger.success("成功添加用户 %s[%s] 到Bot的管理员权限", from_user.full_name, from_user.id) + await message.reply_text("添加成功") + else: + await message.reply_text("该用户已经存在管理员列表") + else: + await message.reply_text("回复的用户不存在") + else: + await message.reply_text("请回复对应消息") + + @handler.command("del_admin", block=False, admin=True) + async def del_admin(self, update: Update, _: ContextTypes.DEFAULT_TYPE): + message = update.effective_message + reply_to_message = message.reply_to_message + user = update.effective_user + logger.info("用户 %s[%s] del_admin 命令请求", user.full_name, user.id) + if reply_to_message: + from_user = reply_to_message.from_user + if from_user: + if await self.user_admin_service.delete_admin(from_user.id): + logger.success("成功移除用户 %s[%s] 在Bot的管理员权限", from_user.full_name, from_user.id) + await message.reply_text("移除成功") + else: + await message.reply_text("移除失败 该用户不存在管理员列表") + else: + await message.reply_text("回复的用户不存在") + else: + await message.reply_text("请回复对应消息") diff --git a/plugins/admin/get_chat.py b/plugins/admin/get_chat.py new file mode 100644 index 00000000..6eb4dc5b --- /dev/null +++ b/plugins/admin/get_chat.py @@ -0,0 +1,94 @@ +import html +from typing import Tuple + +from telegram import Chat, ChatMember, ChatMemberAdministrator, ChatMemberOwner, Update +from telegram.error import BadRequest, Forbidden +from telegram.ext import CallbackContext, CommandHandler + +from core.basemodel import RegionEnum +from core.plugin import Plugin, handler +from core.services.cookies import CookiesService +from core.services.players import PlayersService +from utils.log import logger + + +class GetChat(Plugin): + def __init__( + self, + players_service: PlayersService, + cookies_service: CookiesService, + ): + self.cookies_service = cookies_service + self.players_service = players_service + + @staticmethod + async def parse_group_chat(chat: Chat, admins: Tuple[ChatMember]) -> str: + text = f"群 ID:{chat.id}\n群名称:{chat.title}\n" + if chat.username: + text += f"群用户名:@{chat.username}\n" + if chat.description: + text += f"群简介:{html.escape(chat.description)}\n" + if admins: + for admin in admins: + text += f'{html.escape(admin.user.full_name)} ' + if isinstance(admin, ChatMemberAdministrator): + text += "C" if admin.can_change_info else "_" + text += "D" if admin.can_delete_messages else "_" + text += "R" if admin.can_restrict_members else "_" + text += "I" if admin.can_invite_users else "_" + text += "T" if admin.can_manage_topics else "_" + text += "P" if admin.can_pin_messages else "_" + text += "V" if admin.can_manage_video_chats else "_" + text += "N" if admin.can_promote_members else "_" + text += "A" if admin.is_anonymous else "_" + elif isinstance(admin, ChatMemberOwner): + text += "创建者" + text += "\n" + return text + + async def parse_private_chat(self, chat: Chat) -> str: + text = ( + f'MENTION\n' + f"用户 ID:{chat.id}\n" + f"用户名称:{chat.full_name}\n" + ) + if chat.username: + text += f"用户名:@{chat.username}\n" + player_info = await self.players_service.get_player(chat.id) + if player_info is not None: + if player_info.region == RegionEnum.HYPERION: + text += "米游社绑定:" + else: + text += "原神绑定:" + cookies_info = await self.cookies_service.get(chat.id, player_info.account_id, player_info.region) + if cookies_info is None: + temp = "UID 绑定" + else: + temp = "Cookie 绑定" + text += f"{temp}\n游戏 ID:{player_info.player_id}" + return text + + @handler(CommandHandler, command="get_chat", block=False, admin=True) + async def get_chat_command(self, update: Update, context: CallbackContext): + user = update.effective_user + logger.info("用户 %s[%s] get_chat 命令请求", user.full_name, user.id) + message = update.effective_message + args = self.get_args(context) + if not args: + await message.reply_text("参数错误,请指定群 id !") + return + try: + chat_id = int(args[0]) + except ValueError: + await message.reply_text("参数错误,请指定群 id !") + return + try: + chat = await self.get_chat(args[0]) + if chat_id < 0: + admins = await chat.get_administrators() if chat_id < 0 else None + text = await self.parse_group_chat(chat, admins) + else: + text = await self.parse_private_chat(chat) + await message.reply_text(text, parse_mode="HTML") + except (BadRequest, Forbidden) as exc: + await message.reply_text(f"通过 id 获取会话信息失败,API 返回:{exc.message}") diff --git a/plugins/other/post.py b/plugins/admin/post.py similarity index 86% rename from plugins/other/post.py rename to plugins/admin/post.py index d0d7468e..b0060be6 100644 --- a/plugins/other/post.py +++ b/plugins/admin/post.py @@ -1,29 +1,25 @@ -from typing import Optional, List, Tuple +from typing import List, Optional, Tuple -from bs4 import BeautifulSoup, Tag +from bs4 import BeautifulSoup from telegram import ( - Update, - ReplyKeyboardMarkup, - ReplyKeyboardRemove, InlineKeyboardButton, InlineKeyboardMarkup, + InputMediaPhoto, Message, + ReplyKeyboardMarkup, + ReplyKeyboardRemove, + Update, ) -from telegram.constants import ParseMode, MessageLimit +from telegram.constants import MessageLimit, ParseMode from telegram.error import BadRequest from telegram.ext import CallbackContext, ConversationHandler, filters from telegram.helpers import escape_markdown -from core.baseplugin import BasePlugin -from core.bot import bot from core.config import config from core.plugin import Plugin, conversation, handler from modules.apihelper.client.components.hyperion import Hyperion from modules.apihelper.error import APIHelperException from modules.apihelper.models.genshin.hyperion import ArtworkImage -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger @@ -40,7 +36,7 @@ def __init__(self): GET_POST_CHANNEL, GET_TAGS, GET_TEXT = range(10904, 10907) -class Post(Plugin.Conversation, BasePlugin.Conversation): +class Post(Plugin.Conversation): """文章推送""" MENU_KEYBOARD = ReplyKeyboardMarkup([["推送频道", "添加TAG"], ["编辑文字", "删除图片"], ["退出"]], True, True) @@ -48,9 +44,11 @@ class Post(Plugin.Conversation, BasePlugin.Conversation): def __init__(self): self.bbs = Hyperion() self.last_post_id_list: List[int] = [] + + async def initialize(self): if config.channels and len(config.channels) > 0: logger.success("文章定时推送处理已经开启") - bot.app.job_queue.run_repeating(self.task, 60) + self.application.job_queue.run_repeating(self.task, 60) async def task(self, context: CallbackContext): temp_post_id_list: List[int] = [] @@ -109,8 +107,6 @@ async def task(self, context: CallbackContext): @conversation.entry_point @handler.callback_query(pattern=r"^post_admin\|", block=False) - @bot_admins_rights_check - @error_callable async def callback_query_start(self, update: Update, context: CallbackContext) -> int: post_handler_data = context.chat_data.get("post_handler_data") if post_handler_data is None: @@ -143,10 +139,7 @@ async def get_post_admin_callback(callback_query_data: str) -> Tuple[str, int]: return ConversationHandler.END @conversation.entry_point - @handler.command(command="post", filters=filters.ChatType.PRIVATE, block=True) - @restricts() - @bot_admins_rights_check - @error_callable + @handler.command(command="post", filters=filters.ChatType.PRIVATE, block=False, admin=True) async def command_start(self, update: Update, context: CallbackContext) -> int: user = update.effective_user message = update.effective_message @@ -161,8 +154,7 @@ async def command_start(self, update: Update, context: CallbackContext) -> int: return CHECK_POST @conversation.state(state=CHECK_POST) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def check_post(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") message = update.effective_message @@ -176,41 +168,24 @@ async def check_post(self, update: Update, context: CallbackContext) -> int: return ConversationHandler.END return await self.send_post_info(post_handler_data, message, post_id) - @staticmethod - def parse_post_text(soup: BeautifulSoup, post_subject: str) -> str: - def parse_tag(_tag: Tag) -> str: - if _tag.name == "a" and _tag.get("href"): - return f"[{escape_markdown(_tag.get_text(), version=2)}]({_tag.get('href')})" - return escape_markdown(_tag.get_text(), version=2) - - post_p = soup.find_all("p") - post_text = f"*{escape_markdown(post_subject, version=2)}*\n\n" - start = True - for p in post_p: - t = p.get_text() - if not t and start: - continue - start = False - for tag in p.contents: - post_text += parse_tag(tag) - post_text += "\n" - return post_text - async def send_post_info(self, post_handler_data: PostHandlerData, message: Message, post_id: int) -> int: post_info = await self.bbs.get_post_info(2, post_id) post_images = await self.bbs.get_images_by_post_id(2, post_id) post_data = post_info["post"]["post"] post_subject = post_data["subject"] post_soup = BeautifulSoup(post_data["content"], features="html.parser") - post_text = self.parse_post_text(post_soup, post_subject) + post_p = post_soup.find_all("p") + post_text = f"*{escape_markdown(post_subject, version=2)}*\n" f"\n" + for p in post_p: + post_text += f"{escape_markdown(p.get_text(), version=2)}\n" post_text += f"[source](https://www.miyoushe.com/ys/article/{post_id})" if len(post_text) >= MessageLimit.CAPTION_LENGTH: post_text = post_text[: MessageLimit.CAPTION_LENGTH] await message.reply_text(f"警告!图片字符描述已经超过 {MessageLimit.CAPTION_LENGTH} 个字,已经切割") try: if len(post_images) > 1: - media = [img_info.input_media() for img_info in post_images if img_info.format] - media[0] = post_images[0].input_media(caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) + media = [InputMediaPhoto(img_info.data) for img_info in post_images] + media[0] = InputMediaPhoto(post_images[0].data, caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) if len(media) > 10: media = media[:10] await message.reply_text("获取到的图片已经超过10张,为了保证发送成功,已经删除一部分图片") @@ -235,20 +210,19 @@ async def send_post_info(self, post_handler_data: PostHandlerData, message: Mess return CHECK_COMMAND @conversation.state(state=CHECK_COMMAND) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def check_command(self, update: Update, context: CallbackContext) -> int: message = update.effective_message if message.text == "退出": await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - elif message.text == "推送频道": + if message.text == "推送频道": return await self.get_channel(update, context) - elif message.text == "添加TAG": + if message.text == "添加TAG": return await self.add_tags(update, context) - elif message.text == "编辑文字": + if message.text == "编辑文字": return await self.edit_text(update, context) - elif message.text == "删除图片": + if message.text == "删除图片": return await self.delete_photo(update, context) return ConversationHandler.END @@ -261,8 +235,7 @@ async def delete_photo(update: Update, context: CallbackContext) -> int: return GTE_DELETE_PHOTO @conversation.state(state=GTE_DELETE_PHOTO) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_delete_photo(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") photo_len = len(post_handler_data.post_images) @@ -282,14 +255,13 @@ async def get_delete_photo(self, update: Update, context: CallbackContext) -> in await message.reply_text("请选择你的操作", reply_markup=self.MENU_KEYBOARD) return CHECK_COMMAND - @staticmethod - async def get_channel(update: Update, _: CallbackContext) -> int: + async def get_channel(self, update: Update, _: CallbackContext) -> int: message = update.effective_message reply_keyboard = [] try: - for channel_info in bot.config.channels: - name = channel_info.name - reply_keyboard.append([f"{name}"]) + for channel_id in config.channels: + chat = await self.get_chat(chat_id=channel_id) + reply_keyboard.append([f"{chat.username}"]) except KeyError as error: logger.error("从配置文件获取频道信息发生错误,退出任务", exc_info=error) await message.reply_text("从配置文件获取频道信息发生错误,退出任务", reply_markup=ReplyKeyboardRemove()) @@ -298,16 +270,16 @@ async def get_channel(update: Update, _: CallbackContext) -> int: return GET_POST_CHANNEL @conversation.state(state=GET_POST_CHANNEL) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_post_channel(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") message = update.effective_message channel_id = -1 try: - for channel_info in bot.config.channels: - if message.text == channel_info.name: - channel_id = channel_info.chat_id + for channel_chat_id in config.channels: + chat = await self.get_chat(chat_id=channel_id) + if message.text == chat.username: + channel_id = channel_chat_id except KeyError as exc: logger.error("从配置文件获取频道信息发生错误,退出任务", exc_info=exc) logger.exception(exc) @@ -328,8 +300,7 @@ async def add_tags(update: Update, _: CallbackContext) -> int: return GET_TAGS @conversation.state(state=GET_TAGS) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_tags(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") message = update.effective_message @@ -346,8 +317,7 @@ async def edit_text(update: Update, _: CallbackContext) -> int: return GET_TEXT @conversation.state(state=GET_TEXT) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_edit_text(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") message = update.effective_message @@ -357,8 +327,7 @@ async def get_edit_text(self, update: Update, context: CallbackContext) -> int: return CHECK_COMMAND @conversation.state(state=SEND_POST) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def send_post(self, update: Update, context: CallbackContext) -> int: post_handler_data: PostHandlerData = context.chat_data.get("post_handler_data") message = update.effective_message @@ -369,9 +338,10 @@ async def send_post(self, update: Update, context: CallbackContext) -> int: channel_id = post_handler_data.channel_id channel_name = None try: - for channel_info in bot.config.channels: + for channel_info in config.channels: if post_handler_data.channel_id == channel_info.chat_id: - channel_name = channel_info.name + chat = await self.get_chat(chat_id=channel_id) + channel_name = chat.username except KeyError as exc: logger.error("从配置文件获取频道信息发生错误,退出任务") logger.exception(exc) @@ -382,13 +352,13 @@ async def send_post(self, update: Update, context: CallbackContext) -> int: for index, _ in enumerate(post_handler_data.post_images): if index + 1 not in post_handler_data.delete_photo: post_images.append(post_handler_data.post_images[index]) - post_text += f" @{escape_markdown(channel_name, version=2)}" + post_text += f" @{channel_name}" for tag in post_handler_data.tags: post_text += f" \\#{tag}" try: if len(post_images) > 1: - media = [img_info.input_media() for img_info in post_images if img_info.format] - media[0] = post_images[0].input_media(caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) + media = [InputMediaPhoto(img_info.data) for img_info in post_images] + media[0] = InputMediaPhoto(post_images[0].data, caption=post_text, parse_mode=ParseMode.MARKDOWN_V2) await context.bot.send_media_group(channel_id, media=media) elif len(post_images) == 1: image = post_images[0] diff --git a/plugins/system/set_quiz.py b/plugins/admin/quiz.py similarity index 86% rename from plugins/system/set_quiz.py rename to plugins/admin/quiz.py index f1166381..e816fa72 100644 --- a/plugins/system/set_quiz.py +++ b/plugins/admin/quiz.py @@ -6,13 +6,9 @@ from telegram.ext import CallbackContext, ConversationHandler, filters from telegram.helpers import escape_markdown -from core.baseplugin import BasePlugin from core.plugin import Plugin, conversation, handler -from core.quiz import QuizService -from core.quiz.models import Answer, Question -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from core.services.quiz.models import Answer, Question +from core.services.quiz.services import QuizService from utils.log import logger ( @@ -35,7 +31,7 @@ class QuizCommandData: status: int = 0 -class SetQuizPlugin(Plugin.Conversation, BasePlugin.Conversation): +class SetQuizPlugin(Plugin.Conversation): """派蒙的十万个为什么问题修改/添加/删除""" def __init__(self, quiz_service: QuizService = None): @@ -43,10 +39,7 @@ def __init__(self, quiz_service: QuizService = None): self.time_out = 120 @conversation.entry_point - @handler.command(command="set_quiz", filters=filters.ChatType.PRIVATE, block=True) - @restricts() - @bot_admins_rights_check - @error_callable + @handler.command(command="set_quiz", filters=filters.ChatType.PRIVATE, block=False, admin=True) async def command_start(self, update: Update, context: CallbackContext) -> int: user = update.effective_user message = update.effective_message @@ -67,42 +60,41 @@ async def view_command(self, update: Update, _: CallbackContext) -> int: return CHECK_COMMAND @conversation.state(state=CHECK_QUESTION) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def check_question(self, update: Update, _: CallbackContext) -> int: reply_keyboard = [["删除问题"], ["退出"]] await update.message.reply_text("请选择你的操作", reply_markup=ReplyKeyboardMarkup(reply_keyboard)) return CHECK_COMMAND @conversation.state(state=CHECK_COMMAND) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def check_command(self, update: Update, context: CallbackContext) -> int: quiz_command_data: QuizCommandData = context.chat_data.get("quiz_command_data") if update.message.text == "退出": await update.message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - elif update.message.text == "查看问题": + if update.message.text == "查看问题": return await self.view_command(update, context) - elif update.message.text == "添加问题": + if update.message.text == "添加问题": return await self.add_question(update, context) - elif update.message.text == "删除问题": + if update.message.text == "删除问题": return await self.delete_question(update, context) # elif update.message.text == "修改问题": # return await self.edit_question(update, context) - elif update.message.text == "重载问题": + if update.message.text == "重载问题": return await self.refresh_question(update, context) - else: - result = re.findall(r"问题ID (\d+)", update.message.text) - if len(result) == 1: - try: - question_id = int(result[0]) - except ValueError: - await update.message.reply_text("获取问题ID失败") - return ConversationHandler.END - quiz_command_data.question_id = question_id - await update.message.reply_text("获取问题ID成功") - return await self.check_question(update, context) - await update.message.reply_text("命令错误", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END + result = re.findall(r"问题ID (\d+)", update.message.text) + if len(result) == 1: + try: + question_id = int(result[0]) + except ValueError: + await update.message.reply_text("获取问题ID失败") + return ConversationHandler.END + quiz_command_data.question_id = question_id + await update.message.reply_text("获取问题ID成功") + return await self.check_question(update, context) + await update.message.reply_text("命令错误", reply_markup=ReplyKeyboardRemove()) + return ConversationHandler.END async def refresh_question(self, update: Update, _: CallbackContext) -> int: try: @@ -128,7 +120,7 @@ async def add_question(self, update: Update, context: CallbackContext) -> int: return GET_NEW_QUESTION @conversation.state(state=GET_NEW_QUESTION) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_new_question(self, update: Update, context: CallbackContext) -> int: message = update.effective_message quiz_command_data: QuizCommandData = context.chat_data.get("quiz_command_data") @@ -138,7 +130,7 @@ async def get_new_question(self, update: Update, context: CallbackContext) -> in return GET_NEW_CORRECT_ANSWER @conversation.state(state=GET_NEW_CORRECT_ANSWER) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def get_new_correct_answer(self, update: Update, context: CallbackContext) -> int: quiz_command_data: QuizCommandData = context.chat_data.get("quiz_command_data") reply_text = f"正确答案:`{escape_markdown(update.message.text, version=2)}`\n" f"请填写错误答案:" @@ -147,8 +139,8 @@ async def get_new_correct_answer(self, update: Update, context: CallbackContext) return GET_NEW_WRONG_ANSWER @conversation.state(state=GET_NEW_WRONG_ANSWER) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @handler.command(command="finish_edit", block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) + @handler.command(command="finish_edit", block=False) async def get_new_wrong_answer(self, update: Update, context: CallbackContext) -> int: quiz_command_data: QuizCommandData = context.chat_data.get("quiz_command_data") reply_text = ( @@ -173,13 +165,13 @@ async def finish_edit(self, update: Update, context: CallbackContext): return SAVE_QUESTION @conversation.state(state=SAVE_QUESTION) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) + @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) async def save_question(self, update: Update, context: CallbackContext): quiz_command_data: QuizCommandData = context.chat_data.get("quiz_command_data") if update.message.text == "抛弃修改并退出": await update.message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - elif update.message.text == "保存并重载配置": + if update.message.text == "保存并重载配置": if quiz_command_data.status == 1: answer = [ Answer(text=wrong_answer, is_correct=False) for wrong_answer in quiz_command_data.new_wrong_answer @@ -197,9 +189,8 @@ async def save_question(self, update: Update, context: CallbackContext): return ConversationHandler.END await update.message.reply_text("重载配置成功", reply_markup=ReplyKeyboardRemove()) return ConversationHandler.END - else: - await update.message.reply_text("回复错误,请重新选择") - return SAVE_QUESTION + await update.message.reply_text("回复错误,请重新选择") + return SAVE_QUESTION async def edit_question(self, update: Update, context: CallbackContext) -> int: _ = self diff --git a/plugins/system/refresh_metadata.py b/plugins/admin/refresh_metadata.py similarity index 74% rename from plugins/system/refresh_metadata.py rename to plugins/admin/refresh_metadata.py index 45823a28..a1194a43 100644 --- a/plugins/system/refresh_metadata.py +++ b/plugins/admin/refresh_metadata.py @@ -1,21 +1,21 @@ from telegram import Update +from telegram.ext import CallbackContext from core.plugin import Plugin, handler from metadata.scripts.honey import update_honey_metadata from metadata.scripts.metadatas import update_metadata_from_ambr, update_metadata_from_github from metadata.scripts.paimon_moe import update_paimon_moe_zh -from utils.decorators.admins import bot_admins_rights_check from utils.log import logger +__all__ = ("MetadataPlugin",) + class MetadataPlugin(Plugin): - @handler.command("refresh_metadata") - @bot_admins_rights_check - async def refresh(self, update: Update, _) -> None: - user = update.effective_user + @handler.command("refresh_metadata", admin=True) + async def refresh(self, update: Update, _: CallbackContext) -> None: message = update.effective_message - - logger.info(f"用户 {user.full_name}[{user.id}] 刷新[bold]metadata[/]缓存命令", extra={"markup": True}) + user = update.effective_user + logger.info("用户 %s[%s] 刷新[bold]metadata[/]缓存命令", user.full_name, user.id, extra={"markup": True}) msg = await message.reply_text("正在刷新元数据,请耐心等待...") logger.info("正在从 github 上获取元数据") diff --git a/plugins/system/search.py b/plugins/admin/search.py similarity index 76% rename from plugins/system/search.py rename to plugins/admin/search.py index 99f5b8e9..0d96eda8 100644 --- a/plugins/system/search.py +++ b/plugins/admin/search.py @@ -5,23 +5,21 @@ from telegram.ext import CallbackContext from core.plugin import handler, Plugin, job -from core.search.services import SearchServices -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.restricts import restricts +from core.services.search.services import SearchServices from utils.log import logger -__all__ = [] +__all__ = ("SearchPlugin",) class SearchPlugin(Plugin): def __init__(self, search: SearchServices = None): self.search = search - self._lock = asyncio.Lock() + self.lock = asyncio.Lock() - async def __async_init__(self): + async def initialize(self): async def load_data(): logger.info("Search 插件模块正在加载搜索条目") - async with self._lock: + async with self.lock: await self.search.load_data() logger.success("Search 插件加载模块搜索条目成功") @@ -29,32 +27,28 @@ async def load_data(): @job.run_repeating(interval=datetime.timedelta(hours=1), name="SaveEntryJob") async def save_entry_job(self, _: CallbackContext): - if self._lock.locked(): + if self.lock.locked(): logger.warning("条目数据正在保存 跳过本次定时任务") else: - async with self._lock: + async with self.lock: logger.info("条目数据正在自动保存") await self.search.save_entry() logger.success("条目数据自动保存成功") - @handler.command("save_entry", block=False) - @bot_admins_rights_check - @restricts() + @handler.command("save_entry", block=False, admin=True) async def save_entry(self, update: Update, _: CallbackContext): user = update.effective_user message = update.effective_message logger.info("用户 %s[%s] 保存条目数据命令请求", user.full_name, user.id) - if self._lock.locked(): + if self.lock.locked(): await message.reply_text("条目数据正在保存 请稍后重试") else: - async with self._lock: + async with self.lock: reply_text = await message.reply_text("正在保存数据") await self.search.save_entry() await reply_text.edit_text("数据保存成功") - @handler.command("remove_all_entry", block=False) - @bot_admins_rights_check - @restricts() + @handler.command("remove_all_entry", block=False, admin=True) async def remove_all_entry(self, update: Update, _: CallbackContext): user = update.effective_user message = update.effective_message diff --git a/plugins/admin/sign_all.py b/plugins/admin/sign_all.py new file mode 100644 index 00000000..006ded94 --- /dev/null +++ b/plugins/admin/sign_all.py @@ -0,0 +1,20 @@ +from telegram import Update +from telegram.ext import CallbackContext, CommandHandler + +from core.plugin import Plugin, handler +from plugins.tools.sign import SignSystem, SignJobType +from utils.log import logger + + +class SignAll(Plugin): + def __init__(self, sign_system: SignSystem): + self.sign_system = sign_system + + @handler(CommandHandler, command="sign_all", block=False, admin=True) + async def sign_all(self, update: Update, context: CallbackContext): + user = update.effective_user + logger.info("用户 %s[%s] sign_all 命令请求", user.full_name, user.id) + message = update.effective_message + reply = await message.reply_text("正在全部重新签到,请稍后...") + await self.sign_system.do_sign_job(context, job_type=SignJobType.START) + await reply.edit_text("全部账号重新签到完成") diff --git a/plugins/system/sign_status.py b/plugins/admin/sign_status.py similarity index 72% rename from plugins/system/sign_status.py rename to plugins/admin/sign_status.py index 6da58057..8dcfea15 100644 --- a/plugins/system/sign_status.py +++ b/plugins/admin/sign_status.py @@ -1,14 +1,13 @@ from telegram import Update -from telegram.ext import CommandHandler, CallbackContext +from telegram.ext import CallbackContext, CommandHandler from core.plugin import Plugin, handler -from core.sign import SignServices -from utils.decorators.admins import bot_admins_rights_check +from core.services.sign.services import SignServices from utils.log import logger class SignStatus(Plugin): - def __init__(self, sign_service: SignServices = None): + def __init__(self, sign_service: SignServices): self.sign_service = sign_service @staticmethod @@ -21,11 +20,10 @@ async def get_sign_status(sign_service: SignServices) -> str: text = f"自动签到统计信息\n\n总人数:{len(sign_db)}\n" return text + "\n".join(f"{name}: {value}" for name, value in zip(names, values)) - @handler(CommandHandler, command="sign_status", block=False) - @bot_admins_rights_check + @handler(CommandHandler, command="sign_status", block=False, admin=True) async def sign_status(self, update: Update, _: CallbackContext): user = update.effective_user - logger.info(f"用户 {user.full_name}[{user.id}] sign_status 命令请求") + logger.info("用户 %s[%s] sign_status 命令请求", user.full_name, user.id) message = update.effective_message text = await self.get_sign_status(self.sign_service) await message.reply_text(text, parse_mode="html", quote=True) diff --git a/plugins/app/__init__.py b/plugins/app/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/system/inline.py b/plugins/app/inline.py similarity index 88% rename from plugins/system/inline.py rename to plugins/app/inline.py index 5a9aae31..a1449e38 100644 --- a/plugins/system/inline.py +++ b/plugins/app/inline.py @@ -1,23 +1,22 @@ import asyncio -from typing import cast, Dict, Awaitable, List +from typing import Awaitable, Dict, List, cast from uuid import uuid4 from telegram import ( + InlineQuery, InlineQueryResultArticle, + InlineQueryResultCachedPhoto, InputTextMessageContent, Update, - InlineQuery, - InlineQueryResultCachedPhoto, ) from telegram.constants import ParseMode from telegram.error import BadRequest from telegram.ext import CallbackContext, InlineQueryHandler -from core.base.assets import AssetsService, AssetsCouldNotFound -from core.plugin import handler, Plugin -from core.search.services import SearchServices -from core.wiki import WikiService -from utils.decorators.error import error_callable +from core.dependence.assets import AssetsCouldNotFound, AssetsService +from core.plugin import Plugin, handler +from core.services.search.services import SearchServices +from core.services.wiki.services import WikiService from utils.log import logger @@ -26,9 +25,9 @@ class Inline(Plugin): def __init__( self, - wiki_service: WikiService = None, - assets_service: AssetsService = None, - search_service: SearchServices = None, + wiki_service: WikiService, + assets_service: AssetsService, + search_service: SearchServices, ): self.assets_service = assets_service self.wiki_service = wiki_service @@ -37,7 +36,7 @@ def __init__( self.refresh_task: List[Awaitable] = [] self.search_service = search_service - async def __async_init__(self): + async def initialize(self): # todo: 整合进 wiki 或者单独模块 从Redis中读取 async def task_weapons(): logger.info("Inline 模块正在获取武器列表") @@ -73,8 +72,7 @@ async def task_characters(): self.refresh_task.append(asyncio.create_task(task_characters())) @handler(InlineQueryHandler, block=False) - @error_callable - async def inline_query(self, update: Update, context: CallbackContext) -> None: + async def inline_query(self, update: Update, _: CallbackContext) -> None: user = update.effective_user ilq = cast(InlineQuery, update.inline_query) query = ilq.query @@ -100,7 +98,7 @@ async def inline_query(self, update: Update, context: CallbackContext) -> None: ) ) else: - if "查看武器列表并查询" == args[0]: + if args[0] == "查看武器列表并查询": for weapon in self.weapons_list: name = weapon["name"] icon = weapon["icon"] @@ -111,11 +109,11 @@ async def inline_query(self, update: Update, context: CallbackContext) -> None: description=f"查看武器列表并查询 {name}", thumb_url=icon, input_message_content=InputTextMessageContent( - f"/weapon@{context.bot.username} {name}", parse_mode=ParseMode.MARKDOWN_V2 + f"武器查询{name}", parse_mode=ParseMode.MARKDOWN_V2 ), ) ) - elif "查看角色攻略列表并查询" == args[0]: + elif args[0] == "查看角色攻略列表并查询": for character in self.characters_list: name = character["name"] icon = character["icon"] @@ -126,11 +124,11 @@ async def inline_query(self, update: Update, context: CallbackContext) -> None: description=f"查看角色攻略列表并查询 {name}", thumb_url=icon, input_message_content=InputTextMessageContent( - f"/strategy@{context.bot.username} {name}", parse_mode=ParseMode.MARKDOWN_V2 + f"角色攻略查询{name}", parse_mode=ParseMode.MARKDOWN_V2 ), ) ) - elif "查看角色培养素材列表并查询" == args[0]: + elif args[0] == "查看角色培养素材列表并查询": characters_list = await self.wiki_service.get_characters_name_list() for role_name in characters_list: results_list.append( @@ -139,7 +137,7 @@ async def inline_query(self, update: Update, context: CallbackContext) -> None: title=role_name, description=f"查看角色培养素材列表并查询 {role_name}", input_message_content=InputTextMessageContent( - f"/material@{context.bot.username} {role_name}", parse_mode=ParseMode.MARKDOWN_V2 + f"角色培养素材查询{role_name}", parse_mode=ParseMode.MARKDOWN_V2 ), ) ) diff --git a/plugins/system/start.py b/plugins/app/start.py similarity index 63% rename from plugins/system/start.py rename to plugins/app/start.py index df2afff9..d729357e 100644 --- a/plugins/system/start.py +++ b/plugins/app/start.py @@ -1,38 +1,25 @@ from typing import Optional -from genshin import Region, GenshinException from telegram import Update, ReplyKeyboardRemove, Message, User, WebAppInfo, ReplyKeyboardMarkup, KeyboardButton from telegram.constants import ChatAction from telegram.ext import CallbackContext, CommandHandler from telegram.helpers import escape_markdown -from core.base.redisdb import RedisDB from core.config import config -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError from core.plugin import handler, Plugin -from core.user import UserService -from core.user.error import UserNotFoundError -from modules.apihelper.client.components.verify import Verify -from modules.apihelper.error import ResponseException, APIHelperException -from plugins.genshin.sign import SignSystem, NeedChallenge -from plugins.genshin.verification import VerificationSystem -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from plugins.tools.challenge import ChallengeSystem, ChallengeSystemException +from plugins.tools.genshin import PlayerNotFoundError, CookiesNotFoundError, GenshinHelper +from plugins.tools.sign import SignSystem, NeedChallenge from utils.log import logger class StartPlugin(Plugin): - def __init__(self, user_service: UserService = None, cookies_service: CookiesService = None, redis: RedisDB = None): - self.cookies_service = cookies_service - self.user_service = user_service - self.sign_system = SignSystem(redis) - self.verification_system = VerificationSystem(redis) + def __init__(self, sign_system: SignSystem, challenge_system: ChallengeSystem, genshin_helper: GenshinHelper): + self.challenge_system = challenge_system + self.sign_system = sign_system + self.genshin_helper = genshin_helper @handler.command("start", block=False) - @error_callable - @restricts() async def start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message @@ -58,7 +45,7 @@ async def start(self, update: Update, context: CallbackContext) -> None: await self.process_validate(message, user, bot_username=context.bot.username) elif args[0] == "sign": logger.info("用户 %s[%s] 通过start命令 获取签到信息", user.full_name, user.id) - await self.gen_sign_button(message, user) + await self.get_sign_button(message, user) elif args[0].startswith("challenge_"): _data = args[0].split("_") _command = _data[1] @@ -73,40 +60,24 @@ async def start(self, update: Update, context: CallbackContext) -> None: await message.reply_markdown_v2(f"你好 {user.mention_markdown_v2()} {escape_markdown('!我是派蒙 !')}") @staticmethod - @restricts() async def unknown_command(update: Update, _: CallbackContext) -> None: await update.effective_message.reply_text("前面的区域,以后再来探索吧!") @staticmethod - @restricts() async def emergency_food(update: Update, _: CallbackContext) -> None: await update.effective_message.reply_text("派蒙才不是应急食品!") @handler(CommandHandler, command="ping", block=False) - @restricts() async def ping(self, update: Update, _: CallbackContext) -> None: await update.effective_message.reply_text("online! ヾ(✿゚▽゚)ノ") @handler(CommandHandler, command="reply_keyboard_remove", block=False) - @restricts() async def reply_keyboard_remove(self, update: Update, _: CallbackContext) -> None: await update.message.reply_text("移除远程键盘成功", reply_markup=ReplyKeyboardRemove()) - async def gen_sign_button(self, message: Message, user: User): - try: - client = await get_genshin_client(user.id) - await message.reply_chat_action(ChatAction.TYPING) - button = await self.sign_system.get_challenge_button(client.uid, user.id, callback=False) - if not button: - await message.reply_text("验证请求已过期。", allow_sending_without_reply=True) - return - await message.reply_text("请尽快点击下方按钮进行验证。", allow_sending_without_reply=True, reply_markup=button) - except (UserNotFoundError, CookiesNotFoundError): - logger.warning("用户 %s[%s] 账号信息未找到", user.full_name, user.id) - async def process_sign_validate(self, message: Message, user: User, validate: str): try: - client = await get_genshin_client(user.id) + client = await self.genshin_helper.get_genshin_client(user.id) await message.reply_chat_action(ChatAction.TYPING) _, challenge = await self.sign_system.get_challenge(client.uid) if not challenge: @@ -114,57 +85,29 @@ async def process_sign_validate(self, message: Message, user: User, validate: st return sign_text = await self.sign_system.start_sign(client, challenge=challenge, validate=validate) await message.reply_text(sign_text, allow_sending_without_reply=True) - except (UserNotFoundError, CookiesNotFoundError): + except (PlayerNotFoundError, CookiesNotFoundError): logger.warning("用户 %s[%s] 账号信息未找到", user.full_name, user.id) except NeedChallenge: await message.reply_text("回调错误,请重新签到", allow_sending_without_reply=True) async def process_validate(self, message: Message, user: User, bot_username: Optional[str] = None): - try: - client = await get_genshin_client(user.id) - if client.region != Region.CHINESE: - await message.reply_text("非法用户") - return - except UserNotFoundError: - await message.reply_text("用户未找到") - return - except CookiesNotFoundError: - await message.reply_text("检测到用户为UID绑定,无需认证") - return - try: - await client.get_genshin_notes() - except GenshinException as exc: - if exc.retcode != 1034: - raise exc - else: - await message.reply_text("账户正常,无需认证") - return await message.reply_text( "由于官方对第三方工具限制以及账户安全的考虑,频繁使用第三方工具会导致账号被风控并要求用过验证才能进行访问。\n" "如果出现频繁验证请求,建议暂停使用本Bot在内的第三方工具查询功能。\n" "在暂停使用期间依然出现频繁认证,建议修改密码以保护账号安全。" ) - verification = Verify(cookies=client.cookie_manager.cookies) try: - data = await verification.create() - challenge = data["challenge"] - gt = data["gt"] - logger.success("用户 %s[%s] 创建验证成功\ngt:%s\nchallenge%s", user.full_name, user.id, gt, challenge) - except ResponseException as exc: - logger.warning("用户 %s[%s] 创建验证失效 API返回 [%s]%s", user.full_name, user.id, exc.code, exc.message) - await message.reply_text(f"验证失败 错误信息为 [{exc.code}]{exc.message}") + uid, gt, challenge = await self.challenge_system.create_challenge(user.id, ajax=True) + except ChallengeSystemException as exc: + await message.reply_text(exc.message) return - try: - validate = await verification.ajax(referer="https://webstatic.mihoyo.com/", gt=gt, challenge=challenge) - if validate: - await verification.verify(challenge, validate) - logger.success("用户 %s[%s] 通过 ajax 验证", user.full_name, user.id) - await message.reply_text("验证成功") - return - except APIHelperException as exc: - logger.warning("用户 %s[%s] ajax 验证失效 错误信息为 %s", user.full_name, user.id, repr(exc)) - await self.verification_system.set_challenge(client.uid, gt, challenge) - url = f"{config.pass_challenge_user_web}/webapp?username={bot_username}&command=verify>={gt}&challenge={challenge}&uid={client.uid}" + if gt == "ajax": + await message.reply_text("验证成功") + return + url = ( + f"{config.pass_challenge_user_web}/webapp?" + f"username={bot_username}&command=verify>={gt}&challenge={challenge}&uid={uid}" + ) await message.reply_text( "请尽快在10秒内完成手动验证\n或发送 /web_cancel 取消操作", reply_markup=ReplyKeyboardMarkup.from_button( @@ -174,3 +117,15 @@ async def process_validate(self, message: Message, user: User, bot_username: Opt ) ), ) + + async def get_sign_button(self, message: Message, user: User): + try: + client = await self.genshin_helper.get_genshin_client(user.id) + await message.reply_chat_action(ChatAction.TYPING) + button = await self.sign_system.get_challenge_button(client.uid, user.id, callback=False) + if not button: + await message.reply_text("验证请求已过期。", allow_sending_without_reply=True) + return + await message.reply_text("请尽快点击下方按钮进行验证。", allow_sending_without_reply=True, reply_markup=button) + except (PlayerNotFoundError, CookiesNotFoundError): + logger.warning("用户 %s[%s] 账号信息未找到", user.full_name, user.id) diff --git a/plugins/app/webapp.py b/plugins/app/webapp.py new file mode 100644 index 00000000..376f92f0 --- /dev/null +++ b/plugins/app/webapp.py @@ -0,0 +1,80 @@ +from typing import Optional + +from pydantic import BaseModel +from telegram import ReplyKeyboardRemove, Update +from telegram.ext import CallbackContext, filters + +from core.plugin import Plugin, handler +from plugins.tools.challenge import ChallengeSystem, ChallengeSystemException +from utils.log import logger + + +class WebAppData(BaseModel): + path: str + data: Optional[dict] + code: int + message: str + + +class WebAppDataException(Exception): + def __init__(self, data): + self.data = data + super().__init__() + + +class WebApp(Plugin): + def __init__( + self, + challenge_system: ChallengeSystem, + ): + self.challenge_system = challenge_system + + @staticmethod + def de_web_app_data(data: str) -> WebAppData: + try: + return WebAppData.parse_raw(data) + except Exception as exc: + raise WebAppDataException(data) from exc + + @handler.message(filters=filters.StatusUpdate.WEB_APP_DATA, block=False) + async def app(self, update: Update, _: CallbackContext): + user = update.effective_user + message = update.effective_message + web_app_data = message.web_app_data + if web_app_data: + logger.info("用户 %s[%s] 触发 WEB_APP_DATA 请求", user.full_name, user.id) + result = self.de_web_app_data(web_app_data.data) + logger.debug( + "path[%s]\ndata[%s]\ncode[%s]\nmessage[%s]", result.path, result.data, result.code, result.message + ) + if result.code == 0: + if result.path == "verify": + validate = result.data.get("geetest_validate") + if validate is not None: + try: + status = await self.challenge_system.pass_challenge(user.id, validate=validate) + except ChallengeSystemException as exc: + await message.reply_text(exc.message, reply_markup=ReplyKeyboardRemove()) + return + if status: + await message.reply_text("验证通过", reply_markup=ReplyKeyboardRemove()) + await message.reply_text("非法请求", reply_markup=ReplyKeyboardRemove()) + return + else: + logger.warning( + "用户 %s[%s] WEB_APP_DATA 请求错误 [%s]%s", user.full_name, user.id, result.code, result.message + ) + if result.path == "verify": + await message.reply_text( + f"验证过程中出现问题 {result.message}\n如果继续遇到该问题,请打开米游社→我的角色中尝试手动通过验证", + reply_markup=ReplyKeyboardRemove(), + ) + else: + await message.reply_text(f"WebApp返回错误 {result.message}", reply_markup=ReplyKeyboardRemove()) + else: + logger.warning("用户 %s[%s] WEB_APP_DATA 非法数据", user.full_name, user.id) + + @handler.command("web_cancel", block=False) + async def web_cancel(self, update: Update, _: CallbackContext) -> None: + message = update.effective_message + await message.reply_text("取消操作", reply_markup=ReplyKeyboardRemove()) diff --git a/plugins/genshin/__init__.py b/plugins/genshin/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/genshin/abyss.py b/plugins/genshin/abyss.py index 5066680a..4a29d1e9 100644 --- a/plugins/genshin/abyss.py +++ b/plugins/genshin/abyss.py @@ -5,7 +5,6 @@ from functools import lru_cache, partial from typing import Any, Coroutine, List, Match, Optional, Tuple, Union -import ujson as json from arkowrapper import ArkoWrapper from genshin import Client, GenshinException from pytz import timezone @@ -14,21 +13,23 @@ from telegram.ext import CallbackContext, filters from telegram.helpers import create_deep_linked_url -from core.base.assets import AssetsService -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError, TooManyRequestPublicCookies -from core.cookies.services import CookiesService +from core.dependence.assets import AssetsService from core.plugin import Plugin, handler -from core.template import TemplateService -from core.template.models import RenderGroupResult, RenderResult -from core.user import UserService -from core.user.error import UserNotFoundError +from core.services.cookies.error import TooManyRequestPublicCookies +from core.services.template.models import RenderGroupResult, RenderResult +from core.services.template.services import TemplateService from metadata.genshin import game_id_to_role_id -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import async_re_sub, get_genshin_client, get_public_genshin_client +from plugins.tools.genshin import GenshinHelper, CookiesNotFoundError, PlayerNotFoundError +from utils.helpers import async_re_sub from utils.log import logger +try: + import ujson as jsonlib + +except ImportError: + import json as jsonlib + + TZ = timezone("Asia/Shanghai") cmd_pattern = r"^/abyss\s*((?:\d+)|(?:all))?\s*(pre)?" msg_pattern = r"^深渊数据((?:查询)|(?:总览))(上期)?\D?(\d*)?.*?$" @@ -56,9 +57,8 @@ def get_args(text: str) -> Tuple[int, bool, bool]: except ValueError: floor = 0 return floor, result[0] == "all", bool(result[1]) - else: - result = re.match(msg_pattern, text).groups() - return int(result[2] or 0), result[0] == "总览", result[1] == "上期" + result = re.match(msg_pattern, text).groups() + return int(result[2] or 0), result[0] == "总览", result[1] == "上期" class AbyssUnlocked(Exception): @@ -73,28 +73,25 @@ class AbyssNotFoundError(Exception): """如果查询别人,是无法找到队伍详细,只有数据统计""" -class Abyss(Plugin, BasePlugin): +class AbyssPlugin(Plugin): """深渊数据查询""" def __init__( self, - user_service: UserService = None, - cookies_service: CookiesService = None, - template_service: TemplateService = None, - assets_service: AssetsService = None, + template: TemplateService, + helper: GenshinHelper, + assets_service: AssetsService, ): - self.template_service = template_service - self.cookies_service = cookies_service - self.user_service = user_service + self.template_service = template + self.helper = helper self.assets_service = assets_service @handler.command("abyss", block=False) @handler.message(filters.Regex(msg_pattern), block=False) - @restricts() - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message + uid: Optional[int] = None # 若查询帮助 if (message.text.startswith("/") and "help" in message.text) or "帮助" in message.text: @@ -107,7 +104,7 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: "深渊数据查询\n深渊数据查询上期第12层\n深渊数据总览上期", parse_mode=ParseMode.HTML, ) - logger.info(f"用户 {user.full_name}[{user.id}] 查询[bold]深渊挑战数据[/bold]帮助", extra={"markup": True}) + logger.info("用户 %s[%s] 查询[bold]深渊挑战数据[/bold]帮助", user.full_name, user.id, extra={"markup": True}) return # 解析参数 @@ -116,42 +113,45 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: if floor > 12 or floor < 0: reply_msg = await message.reply_text("深渊层数输入错误,请重新输入。支持的参数为: 1-12 或 all") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_msg.chat_id, reply_msg.message_id, 10) - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) + self.add_delete_message_job(reply_msg) + self.add_delete_message_job(message) return - elif 0 < floor < 9: + if 0 < floor < 9: previous = False logger.info( - f"用户 {user.full_name}[{user.id}] [bold]深渊挑战数据[/bold]请求: " - f"floor={floor} total={total} previous={previous}", + "用户 %s[%s] [bold]深渊挑战数据[/bold]请求: floor=%s total=%s previous=%s", + user.full_name, + user.id, + floor, + total, + previous, extra={"markup": True}, ) try: try: - client = await get_genshin_client(user.id) + client = await self.helper.get_genshin_client(user.id) await client.get_record_cards() uid = client.uid except CookiesNotFoundError: - client, uid = await get_public_genshin_client(user.id) - except UserNotFoundError: # 若未找到账号 + client, uid = await self.helper.get_public_genshin_client(user.id) + except PlayerNotFoundError: # 若未找到账号 buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message.chat_id) + self.add_delete_message_job(message.chat_id) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return except TooManyRequestPublicCookies: - reply_msg = await message.reply_text("查询次数太多,请您稍后重试") + reply_message = await message.reply_text("查询次数太多,请您稍后重试") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_msg.chat_id, reply_msg.message_id, 10) - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) + self.add_delete_message_job(reply_message.chat_id) + self.add_delete_message_job(message.chat_id) return async def reply_message_func(content: str) -> None: @@ -170,10 +170,9 @@ async def reply_message_func(content: str) -> None: try: images = await self.get_rendered_pic(client, uid, floor, total, previous) except GenshinException as exc: - if exc.retcode == 1034: - if client.uid != uid: - await message.reply_text("出错了呜呜呜 ~ 请稍后重试") - return + if exc.retcode == 1034 and client.uid != uid: + await message.reply_text("出错了呜呜呜 ~ 请稍后重试") + return raise exc except AbyssUnlocked: # 若深渊未解锁 await reply_message_func("还未解锁深渊哦~") @@ -201,7 +200,7 @@ async def reply_message_func(content: str) -> None: if reply_text is not None: await reply_text.delete() - logger.info(f"用户 {user.full_name}[{user.id}] [bold]深渊挑战数据[/bold]: 成功发送图片", extra={"markup": True}) + logger.info("用户 %s[%s] [bold]深渊挑战数据[/bold]: 成功发送图片", user.full_name, user.id, extra={"markup": True}) async def get_rendered_pic( self, client: Client, uid: int, floor: int, total: bool, previous: bool @@ -275,7 +274,7 @@ def json_encoder(value): if total: avatars = await client.get_genshin_characters(uid, lang="zh-cn") render_data["avatar_data"] = {i.id: i.constellation for i in avatars} - data = json.loads(result) + data = jsonlib.loads(result) render_data["data"] = data render_inputs: List[Tuple[int, Coroutine[Any, Any, RenderResult]]] = [] @@ -312,39 +311,38 @@ def floor_task(floor_index: int): return await asyncio.gather(*render_group_inputs) - elif floor < 1: - render_data["data"] = json.loads(result) + if floor < 1: + render_data["data"] = jsonlib.loads(result) return [ await self.template_service.render( "genshin/abyss/overview.html", render_data, viewport={"width": 750, "height": 580} ) ] + num_dic = { + "0": "", + "1": "一", + "2": "二", + "3": "三", + "4": "四", + "5": "五", + "6": "六", + "7": "七", + "8": "八", + "9": "九", + } + if num := num_dic.get(str(floor)): + render_data["floor-num"] = num else: - num_dic = { - "0": "", - "1": "一", - "2": "二", - "3": "三", - "4": "四", - "5": "五", - "6": "六", - "7": "七", - "8": "八", - "9": "九", - } - if num := num_dic.get(str(floor)): - render_data["floor-num"] = num - else: - render_data["floor-num"] = f"十{num_dic.get(str(floor % 10))}" - floors = json.loads(result)["floors"] - if (floor_data := list(filter(lambda x: x["floor"] == floor, floors))) is None: - return None - avatars = await client.get_genshin_characters(uid, lang="zh-cn") - render_data["avatar_data"] = {i.id: i.constellation for i in avatars} - render_data["floor"] = floor_data[0] - render_data["total_stars"] = f"{floor_data[0]['stars']}/{floor_data[0]['max_stars']}" - return [ - await self.template_service.render( - "genshin/abyss/floor.html", render_data, viewport={"width": 690, "height": 500} - ) - ] + render_data["floor-num"] = f"十{num_dic.get(str(floor % 10))}" + floors = jsonlib.loads(result)["floors"] + if (floor_data := list(filter(lambda x: x["floor"] == floor, floors))) is None: + return None + avatars = await client.get_genshin_characters(uid, lang="zh-cn") + render_data["avatar_data"] = {i.id: i.constellation for i in avatars} + render_data["floor"] = floor_data[0] + render_data["total_stars"] = f"{floor_data[0]['stars']}/{floor_data[0]['max_stars']}" + return [ + await self.template_service.render( + "genshin/abyss/floor.html", render_data, viewport={"width": 690, "height": 500} + ) + ] diff --git a/plugins/genshin/abyss_team.py b/plugins/genshin/abyss_team.py index 24f987da..bc05e435 100644 --- a/plugins/genshin/abyss_team.py +++ b/plugins/genshin/abyss_team.py @@ -1,54 +1,50 @@ -from telegram import Update, InlineKeyboardMarkup, InlineKeyboardButton +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.constants import ChatAction -from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters +from telegram.ext import CallbackContext, filters from telegram.helpers import create_deep_linked_url -from core.base.assets import AssetsService -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError +from core.dependence.assets import AssetsService from core.plugin import Plugin, handler -from core.template import TemplateService -from core.user import UserService -from core.user.error import UserNotFoundError +from core.services.template.services import TemplateService from metadata.shortname import roleToId from modules.apihelper.client.components.abyss import AbyssTeam as AbyssTeamClient -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from plugins.tools.genshin import GenshinHelper, CookiesNotFoundError, PlayerNotFoundError from utils.log import logger +__all__ = ("AbyssTeamPlugin",) -class AbyssTeam(Plugin, BasePlugin): + +class AbyssTeamPlugin(Plugin): """深境螺旋推荐配队查询""" def __init__( - self, user_service: UserService = None, template_service: TemplateService = None, assets: AssetsService = None + self, + template: TemplateService, + helper: GenshinHelper, + assets_service: AssetsService, ): - self.template_service = template_service - self.user_service = user_service - self.assets_service = assets + self.template_service = template + self.helper = helper self.team_data = AbyssTeamClient() + self.assets_service = assets_service - @handler(CommandHandler, command="abyss_team", block=False) - @handler(MessageHandler, filters=filters.Regex("^深渊推荐配队(.*)"), block=False) - @restricts() - @error_callable + @handler.command("abyss_team", block=False) + @handler.message(filters.Regex("^深渊推荐配队(.*)"), block=False) async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message - logger.info(f"用户 {user.full_name}[{user.id}] 查深渊推荐配队命令请求") + logger.info("用户 %s[%s] 查深渊推荐配队命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id) - except (CookiesNotFoundError, UserNotFoundError): + client = await self.helper.get_genshin_client(user.id) + except (CookiesNotFoundError, PlayerNotFoundError): buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return diff --git a/plugins/genshin/avatar_list.py b/plugins/genshin/avatar_list.py index a2d1df03..b251aa01 100644 --- a/plugins/genshin/avatar_list.py +++ b/plugins/genshin/avatar_list.py @@ -1,33 +1,30 @@ """练度统计""" import asyncio -from typing import Iterable, List, Optional, Sequence +from typing import List, Optional, Sequence from aiohttp import ClientConnectorError from arkowrapper import ArkoWrapper from enkanetwork import Assets as EnkaAssets, EnkaNetworkAPI, VaildateUIDError, HTTPException, EnkaPlayerNotFound from genshin import Client, GenshinException, InvalidCookies from genshin.models import CalculatorCharacterDetails, CalculatorTalent, Character -from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Message, Update, User +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update, User from telegram.constants import ChatAction, ParseMode from telegram.ext import CallbackContext, filters from telegram.helpers import create_deep_linked_url -from core.base.assets import AssetsService -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin from core.config import config -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService +from core.dependence.assets import AssetsService +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.template import TemplateService -from core.template.models import FileType -from core.user.error import UserNotFoundError +from core.services.cookies import CookiesService +from core.services.players import PlayersService +from core.services.players.services import PlayerInfoService +from core.services.template.models import FileType +from core.services.template.services import TemplateService from metadata.genshin import AVATAR_DATA, NAMECARD_DATA from modules.wiki.base import Model -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from plugins.tools.genshin import CookiesNotFoundError, GenshinHelper, PlayerNotFoundError, CharacterDetails from utils.enkanetwork import RedisCache -from utils.helpers import get_genshin_client from utils.log import logger from utils.patch.aiohttp import AioHttpTimeoutException @@ -49,18 +46,22 @@ class AvatarData(Model): def sum_of_skills(self) -> int: total_level = 0 - for skilldata in self.skills: - total_level += skilldata.skill.level + for skill_data in self.skills: + total_level += skill_data.skill.level return total_level -class AvatarListPlugin(Plugin, BasePlugin): +class AvatarListPlugin(Plugin): def __init__( self, + player_service: PlayersService = None, cookies_service: CookiesService = None, assets_service: AssetsService = None, template_service: TemplateService = None, redis: RedisDB = None, + helper: GenshinHelper = None, + character_details: CharacterDetails = None, + player_info_service: PlayerInfoService = None, ) -> None: self.cookies_service = cookies_service self.assets_service = assets_service @@ -68,31 +69,36 @@ def __init__( self.enka_client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent) self.enka_client.set_cache(RedisCache(redis.client, key="plugin:avatar_list:enka_network", ttl=60 * 60 * 3)) self.enka_assets = EnkaAssets(lang="chs") + self.helper = helper + self.character_details = character_details + self.player_service = player_service + self.player_info_service = player_info_service - async def get_user_client(self, user: User, message: Message, context: CallbackContext) -> Optional[Client]: + async def get_user_client(self, update: Update, context: CallbackContext) -> Optional[Client]: + message = update.effective_message + user = update.effective_user try: - return await get_genshin_client(user.id) - except UserNotFoundError: # 若未找到账号 + return await self.helper.get_genshin_client(user.id) + except PlayerNotFoundError: # 若未找到账号 buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) except CookiesNotFoundError: buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): - reply_msg = await message.reply_text( + reply_message = await message.reply_text( "此功能需要绑定cookie后使用,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons), parse_mode=ParseMode.HTML, ) - self._add_delete_message_job(context, reply_msg.chat_id, reply_msg.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text( "此功能需要绑定cookie后使用,请先私聊派蒙进行绑定", @@ -101,21 +107,8 @@ async def get_user_client(self, user: User, message: Message, context: CallbackC ) async def get_avatar_data(self, character: Character, client: Client) -> Optional["AvatarData"]: - for _ in range(5): - try: - detail = await client.get_character_details(character) - except Exception as exc: # pylint: disable=W0703 - if isinstance(exc, GenshinException) and "Too Many Requests" in exc.msg: - await asyncio.sleep(0.2) - continue - if character.name == "旅行者": - logger.debug("解析旅行者数据时遇到了错误:%s", str(exc)) - return None - raise exc - else: - break - else: - logger.warning("解析[bold]%s[/]的数据时遇到了 Too Many Requests 错误", character.name, extra={"markup": True}) + detail = await self.character_details.get_character_details(client, character) + if detail is None: return None if character.id == 10000005: # 针对男草主 talents = [] @@ -197,8 +190,9 @@ async def get_final_data(self, client: Client, characters: Sequence[Character], ArkoWrapper(choices) .map(lambda x: next(filter(lambda y: y["name"].split("·")[0] == x.name, NAMECARD_DATA.values()), None)) .filter(lambda x: x) - .map(lambda x: x["id"]) + .map(lambda x: int(x["id"])) ) + # noinspection PyTypeChecker name_card = (await self.assets_service.namecard(name_card_choices[0]).navbar()).as_uri() avatar = (await self.assets_service.avatar(cid := choices[0].id).icon()).as_uri() nickname = update.effective_user.full_name @@ -208,21 +202,32 @@ async def get_final_data(self, client: Client, characters: Sequence[Character], rarity = {k: v["rank"] for k, v in AVATAR_DATA.items()}[str(cid)] return name_card, avatar, nickname, rarity - async def get_default_final_data(self, characters: Sequence[Character], update: Update): - nickname = update.effective_user.full_name - rarity = 5 - # 须弥·正明 - name_card = (await self.assets_service.namecard(210132).navbar()).as_uri() - if traveller := next(filter(lambda x: x.id in [10000005, 10000007], characters), None): - avatar = (await self.assets_service.avatar(traveller.id).icon()).as_uri() - else: - avatar = (await self.assets_service.avatar(10000005).icon()).as_uri() + async def get_default_final_data(self, player_id: int, characters: Sequence[Character], user: User): + player = await self.player_service.get(user.id, player_id) + player_info = await self.player_info_service.get(player) + nickname = user.full_name + name_card: Optional[str] = None + avatar: Optional[str] = None + rarity: int = 5 + if player_info is not None: + if player_info.nickname is not None: + nickname = player_info.nickname + if player_info.name_card is not None: + name_card = (await self.assets_service.namecard(player_info.name_card).navbar()).as_uri() + if player_info.hand_image is not None: + avatar = (await self.assets_service.avatar(player_info.hand_image).icon()).as_uri() + rarity = {k: v["rank"] for k, v in AVATAR_DATA.items()}[str(player_info.hand_image)] + if name_card is not None: # 须弥·正明 + name_card = (await self.assets_service.namecard(210132).navbar()).as_uri() + if avatar is not None: + if traveller := next(filter(lambda x: x.id in [10000005, 10000007], characters), None): + avatar = (await self.assets_service.avatar(traveller.id).icon()).as_uri() + else: + avatar = (await self.assets_service.avatar(10000005).icon()).as_uri() return name_card, avatar, nickname, rarity @handler.command("avatars", filters.Regex(r"^/avatars\s*(?:(\d+)|(all))?$"), block=False) @handler.message(filters.Regex(r"^(全部)?练度统计$"), block=False) - @restricts(30) - @error_callable async def avatar_list(self, update: Update, context: CallbackContext): user = update.effective_user message = update.effective_message @@ -233,7 +238,7 @@ async def avatar_list(self, update: Update, context: CallbackContext): logger.info("用户 %s[%s] [bold]练度统计[/bold]: all=%s", user.full_name, user.id, all_avatars, extra={"markup": True}) - client = await self.get_user_client(user, message, context) + client = await self.get_user_client(update, context) if not client: return @@ -251,22 +256,22 @@ async def avatar_list(self, update: Update, context: CallbackContext): logger.warning("用户 %s[%s] 无法请求角色数数据 API返回信息为 [%s]%s", user.full_name, user.id, exc.retcode, exc.original) reply_message = await message.reply_text("出错了呜呜呜 ~ 当前访问令牌无法请求角色数数据,请尝试重新获取Cookie。") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return except GenshinException as e: await notice.delete() if e.retcode == -502002: reply_message = await message.reply_html("请先在米游社中使用一次养成计算器后再使用此功能~") - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 20) + self.add_delete_message_job(reply_message, delay=20) return raise e try: name_card, avatar, nickname, rarity = await self.get_final_data(client, characters, update) - except Exception as exc: + except Exception as exc: # pylint: disable=W0703 logger.error("卡片信息请求失败 %s", str(exc)) - name_card, avatar, nickname, rarity = await self.get_default_final_data(characters, update) + name_card, avatar, nickname, rarity = await self.get_default_final_data(client.uid, characters, user) render_data = { "uid": client.uid, # 玩家uid @@ -291,7 +296,7 @@ async def avatar_list(self, update: Update, context: CallbackContext): file_type=FileType.DOCUMENT if as_document else FileType.PHOTO, ttl=30 * 24 * 60 * 60, ) - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 5) + self.add_delete_message_job(notice, delay=5) if as_document: await image.reply_document(message, filename="练度统计.png") else: diff --git a/plugins/genshin/birthday.py b/plugins/genshin/birthday.py index 000cad97..072d45c0 100644 --- a/plugins/genshin/birthday.py +++ b/plugins/genshin/birthday.py @@ -5,28 +5,21 @@ from genshin import Client, GenshinException from genshin.client.routes import Route from genshin.utility import recognize_genshin_server -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import Update, InlineKeyboardMarkup, InlineKeyboardButton from telegram.constants import ParseMode -from telegram.ext import CommandHandler, CallbackContext, MessageHandler -from telegram.ext import filters +from telegram.ext import filters, MessageHandler, CommandHandler, CallbackContext from telegram.helpers import create_deep_linked_url -from core.baseplugin import BasePlugin -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError +from core.basemodel import RegionEnum from core.plugin import Plugin, handler -from core.user import UserService -from core.user.error import UserNotFoundError +from core.services.cookies import CookiesService +from core.services.users.services import UserService from metadata.genshin import AVATAR_DATA from metadata.shortname import roleToId, roleToName from modules.apihelper.client.components.calendar import Calendar -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from plugins.tools.genshin import GenshinHelper from utils.genshin import fetch_hk4e_token_by_cookie, recognize_genshin_game_biz -from utils.helpers import get_genshin_client from utils.log import logger -from utils.models.base import RegionEnum BIRTHDAY_URL = Route( "https://hk4e-api.mihoyo.com/event/birthdaystar/account/post_my_draw", @@ -40,18 +33,20 @@ def rm_starting_str(string, starting): return string -class BirthdayPlugin(Plugin, BasePlugin): +class BirthdayPlugin(Plugin): """Birthday.""" def __init__( self, - user_service: UserService = None, - cookie_service: CookiesService = None, + user_service: UserService, + helper: GenshinHelper, + cookie_service: CookiesService, ): """Load Data.""" self.birthday_list = {} self.user_service = user_service self.cookie_service = cookie_service + self.helper = helper async def __async_init__(self): self.birthday_list = await Calendar.async_gen_birthday_list() @@ -65,10 +60,8 @@ async def get_today_birthday(self) -> List[str]: ) return (self.birthday_list.get(key, [])).copy() - @handler(CommandHandler, command="birthday", block=False) - @restricts() - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> None: + @handler.command(command="birthday", block=False) + async def command_start(self, update: Update, _: CallbackContext) -> None: message = update.effective_message user = update.effective_user key = ( @@ -76,10 +69,11 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: + "_" + rm_starting_str(datetime.now().strftime("%d"), "0") ) - args = get_args(context) + args = self.get_args() + if len(args) >= 1: msg = args[0] - logger.info(f"用户 {user.full_name}[{user.id}] 查询角色生日命令请求 || 参数 {msg}") + logger.info("用户 %s[%s] 查询角色生日命令请求 || 参数 %s", user.full_name, user.id, msg) if re.match(r"\d{1,2}.\d{1,2}", msg): try: month = rm_starting_str(re.findall(r"\d+", msg)[0], "0") @@ -91,9 +85,7 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: except IndexError: text = "请输入正确的日期格式,如1-1,或输入正确的角色名称。" reply_message = await message.reply_text(text) - if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + else: try: if msg == "派蒙": @@ -106,22 +98,19 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: birthday = AVATAR_DATA[aid]["birthday"] text = f"{name} 的生日是 {birthday[0]}月{birthday[1]}日 哦~" reply_message = await message.reply_text(text) - if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + except KeyError: reply_message = await message.reply_text("请输入正确的日期格式,如1-1,或输入正确的角色名称。") - if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + else: - logger.info(f"用户 {user.full_name}[{user.id}] 查询今日角色生日列表") + logger.info("用户 %s[%s] 查询今日角色生日列表", user.full_name, user.id) today_list = await self.get_today_birthday() text = f"今天是 {'、'.join(today_list)} 的生日哦~" if today_list else "今天没有角色过生日哦~" reply_message = await message.reply_text(text) - if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + + if filters.ChatType.GROUPS.filter(reply_message): + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) @staticmethod async def get_card(client: Client, role_id: int) -> None: @@ -147,8 +136,6 @@ def role_to_id(name: str) -> Optional[int]: @handler(CommandHandler, command="birthday_card", block=False) @handler(MessageHandler, filters=filters.Regex("^领取角色生日画片$"), block=False) - @restricts() - @error_callable async def command_birthday_card_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user @@ -157,31 +144,11 @@ async def command_birthday_card_start(self, update: Update, context: CallbackCon if not today_list: reply_message = await message.reply_text("今天没有角色过生日哦~") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return - try: - client = await get_genshin_client(user.id) - if client.region == RegionEnum.HOYOLAB: - text = "此功能当前只支持国服账号哦~" - else: - await fetch_hk4e_token_by_cookie(client) - for name in today_list.copy(): - if role_id := self.role_to_id(name): - try: - await self.get_card(client, role_id) - except GenshinException as e: - if e.retcode in {-512008, -512009}: # 未过生日、已领取过 - today_list.remove(name) - if today_list: - text = f"成功领取了 {'、'.join(today_list)} 的生日画片~" - else: - text = "没有领取到生日画片哦 ~ 可能是已经领取过了" - reply_message = await message.reply_text(text) - if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) - except (UserNotFoundError, CookiesNotFoundError): + client = await self.helper.get_genshin_client(user.id) + if client is None: buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_msg = await message.reply_text( @@ -189,11 +156,31 @@ async def command_birthday_card_start(self, update: Update, context: CallbackCon reply_markup=InlineKeyboardMarkup(buttons), parse_mode=ParseMode.HTML, ) - self._add_delete_message_job(context, reply_msg.chat_id, reply_msg.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_msg.chat_id, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text( "此功能需要绑定cookie后使用,请先私聊派蒙进行绑定", parse_mode=ParseMode.HTML, reply_markup=InlineKeyboardMarkup(buttons), ) + return + if client.region == RegionEnum.HOYOLAB: + text = "此功能当前只支持国服账号哦~" + else: + await fetch_hk4e_token_by_cookie(client) + for name in today_list.copy(): + if role_id := self.role_to_id(name): + try: + await self.get_card(client, role_id) + except GenshinException as e: + if e.retcode in {-512008, -512009}: # 未过生日、已领取过 + today_list.remove(name) + if today_list: + text = f"成功领取了 {'、'.join(today_list)} 的生日画片~" + else: + text = "没有领取到生日画片哦 ~ 可能是已经领取过了" + reply_message = await message.reply_text(text) + if filters.ChatType.GROUPS.filter(reply_message): + self.add_delete_message_job(message.chat_id) + self.add_delete_message_job(reply_message.chat_id) diff --git a/plugins/genshin/calendar.py b/plugins/genshin/calendar.py index 19176ad3..3a9d3f78 100644 --- a/plugins/genshin/calendar.py +++ b/plugins/genshin/calendar.py @@ -5,14 +5,11 @@ from telegram.constants import ChatAction from telegram.ext import CallbackContext, MessageHandler, filters -from core.base.assets import AssetsService -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin +from core.dependence.assets import AssetsService +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.template import TemplateService +from core.services.template.services import TemplateService from modules.apihelper.client.components.calendar import Calendar -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger try: @@ -21,14 +18,14 @@ import json as jsonlib -class CalendarPlugin(Plugin, BasePlugin): +class CalendarPlugin(Plugin): """活动日历查询""" def __init__( self, - template_service: TemplateService = None, - assets_service: AssetsService = None, - redis: RedisDB = None, + template_service: TemplateService, + assets_service: AssetsService, + redis: RedisDB, ): self.template_service = template_service self.assets_service = assets_service @@ -46,8 +43,6 @@ async def _fetch_data(self) -> Dict: @handler.command("calendar", block=False) @handler(MessageHandler, filters=filters.Regex(r"^(活动)+(日历|日历列表)$"), block=False) - @restricts() - @error_callable async def command_start(self, update: Update, _: CallbackContext) -> None: user = update.effective_user message = update.effective_message diff --git a/plugins/genshin/daily/material.py b/plugins/genshin/daily/material.py index 311bce22..bec87ccd 100644 --- a/plugins/genshin/daily/material.py +++ b/plugins/genshin/daily/material.py @@ -24,19 +24,12 @@ from telegram.error import RetryAfter, TimedOut from telegram.ext import CallbackContext -from core.base.assets import AssetsCouldNotFound, AssetsService, AssetsServiceType -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError +from core.dependence.assets import AssetsCouldNotFound, AssetsService, AssetsServiceType from core.plugin import Plugin, handler -from core.template import TemplateService -from core.template.models import FileType, RenderGroupResult -from core.user.error import UserNotFoundError +from core.services.template.models import FileType, RenderGroupResult +from core.services.template.services import TemplateService from metadata.genshin import AVATAR_DATA, HONEY_DATA -from utils.bot import get_args -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from plugins.tools.genshin import GenshinHelper, PlayerNotFoundError, CookiesNotFoundError, CharacterDetails from utils.log import logger INTERVAL = 1 @@ -99,18 +92,26 @@ def all_substrings(string: str) -> Iterator[str]: return result -class DailyMaterial(Plugin, BasePlugin): +class DailyMaterial(Plugin): """每日素材表""" data: DATA_TYPE locks: Tuple[Lock] = (Lock(), Lock()) - def __init__(self, assets: AssetsService, template_service: TemplateService): + def __init__( + self, + assets: AssetsService, + template_service: TemplateService, + helper: GenshinHelper, + character_details: CharacterDetails, + ): self.assets_service = assets self.template_service = template_service + self.helper = helper + self.character_details = character_details self.client = AsyncClient() - async def __async_init__(self): + async def initialize(self): """插件在初始化时,会检查一下本地是否缓存了每日素材的数据""" data = None @@ -128,30 +129,10 @@ async def task_daily(): data = json.loads(await file.read()) self.data = data - @staticmethod - async def _get_skills_data(client: Client, character: Character) -> Optional[List[int]]: - """获取角色技能的数据""" - for _ in range(5): - try: - detail = await client.get_character_details(character) - except Exception as e: # pylint: disable=W0703 - if isinstance(e, GenshinException): - # 如果是 Too Many Requests 异常,则等待一段时间后重试 - if "Too Many Requests" in e.msg: - await asyncio.sleep(0.2) - continue - # 如果是其他异常,则直接抛出 - raise e - else: - break - else: - # 如果重试了5次都失败了,则直接返回 None - logger.warning( - "daily_material 解析角色 id 为 [bold]%s[/]的数据时遇到了 Too Many Requests 错误", character.id, extra={"markup": True} - ) + async def _get_skills_data(self, client: Client, character: Character) -> Optional[List[int]]: + detail = await self.character_details.get_character_details(client, character) + if detail is None: return None - # 不用针对旅行者、草主进行特殊处理,因为输入数据不会有旅行者。 - # 不用计算命座加成,因为这个是展示天赋升级情况,10 级为最高。计算命座会引起混淆。 talents = [t for t in detail.talents if t.type in ["attack", "skill", "burst"]] return [t.level for t in talents] @@ -160,7 +141,7 @@ async def _get_data_from_user(self, user: User) -> Tuple[Optional[Client], Dict[ user_data = {"avatar": [], "weapon": []} try: logger.debug("尝试获取已绑定的原神账号") - client = await get_genshin_client(user.id) + client = await self.helper.get_genshin_client(user.id) logger.debug("获取账号数据成功: UID=%s", client.uid) characters = await client.get_genshin_characters(client.uid) for character in characters: @@ -195,7 +176,7 @@ async def _get_data_from_user(self, user: User) -> Tuple[Optional[Client], Dict[ c_path=(await self.assets_service.avatar(cid).side()).as_uri(), ) ) - except (UserNotFoundError, CookiesNotFoundError): + except (PlayerNotFoundError, CookiesNotFoundError): logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) except InvalidCookies: logger.info("用户 %s[%s] 所绑定的账号信息已失效", user.full_name, user.id) @@ -206,12 +187,10 @@ async def _get_data_from_user(self, user: User) -> Tuple[Optional[Client], Dict[ return None, user_data @handler.command("daily_material", block=False) - @restricts(restricts_time_of_groups=20, without_overlapping=True) - @error_callable async def daily_material(self, update: Update, context: CallbackContext): user = update.effective_user message = update.effective_message - args = get_args(context) + args = self.get_args(context) now = datetime.now() try: @@ -235,7 +214,7 @@ async def daily_material(self, update: Update, context: CallbackContext): if self.locks[0].locked(): # 若检测到了第一个锁:正在下载每日素材表的数据 notice = await message.reply_text("派蒙正在摘抄每日素材表,以后再来探索吧~") - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 5) + self.add_delete_message_job(notice, delay=5) return if self.locks[1].locked(): # 若检测到了第二个锁:正在下载角色、武器、材料的图标 @@ -281,7 +260,7 @@ async def daily_material(self, update: Update, context: CallbackContext): except GenshinException as e: if e.retcode == -502002: calculator_sync = False # 发现角色养成计算器没启用 设置状态为 False 并防止下次继续获取 - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 5) + self.add_delete_message_job(notice, delay=5) await notice.edit_text( "获取角色天赋信息失败,如果想要显示角色天赋信息,请先在米游社/HoYoLab中使用一次养成计算器后再使用此功能~", parse_mode=ParseMode.HTML, @@ -350,7 +329,7 @@ async def daily_material(self, update: Update, context: CallbackContext): ), ) - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 5) + self.add_delete_message_job(notice, delay=5) await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) character_img_data.filename = f"{title}可培养角色.png" @@ -361,7 +340,6 @@ async def daily_material(self, update: Update, context: CallbackContext): logger.debug("角色、武器培养素材图发送成功") @handler.command("refresh_daily_material", block=False) - @bot_admins_rights_check async def refresh(self, update: Update, context: CallbackContext): user = update.effective_user message = update.effective_message @@ -369,11 +347,11 @@ async def refresh(self, update: Update, context: CallbackContext): logger.info("用户 {%s}[%s] 刷新[bold]每日素材[/]缓存命令", user.full_name, user.id, extra={"markup": True}) if self.locks[0].locked(): notice = await message.reply_text("派蒙还在抄每日素材表呢,我有在好好工作哦~") - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 10) + self.add_delete_message_job(notice, delay=10) return if self.locks[1].locked(): notice = await message.reply_text("派蒙正在搬运每日素材图标,在努力工作呢!") - self._add_delete_message_job(context, notice.chat_id, notice.message_id, 10) + self.add_delete_message_job(notice, delay=10) return async with self.locks[1]: # 锁住第二把锁 notice = await message.reply_text("派蒙正在重新摘抄每日素材表,请稍等~", parse_mode=ParseMode.HTML) diff --git a/plugins/genshin/daily_note.py b/plugins/genshin/daily_note.py index e88cd168..021fd1ab 100644 --- a/plugins/genshin/daily_note.py +++ b/plugins/genshin/daily_note.py @@ -1,52 +1,45 @@ import datetime -import os +from datetime import datetime from typing import Optional +import genshin from genshin import DataNotPublic -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.constants import ChatAction -from telegram.ext import CommandHandler, MessageHandler, ConversationHandler, filters, CallbackContext +from telegram.ext import ConversationHandler, filters, CallbackContext from telegram.helpers import create_deep_linked_url -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService from core.plugin import Plugin, handler -from core.template.services import RenderResult, TemplateService -from core.user.error import UserNotFoundError -from core.user.services import UserService -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from core.services.template.models import RenderResult +from core.services.template.services import TemplateService +from plugins.tools.genshin import GenshinHelper, CookiesNotFoundError, PlayerNotFoundError from utils.log import logger +__all__ = ("DailyNotePlugin",) -class DailyNote(Plugin, BasePlugin): + +class DailyNotePlugin(Plugin): """每日便签""" def __init__( self, - user_service: UserService = None, - cookies_service: CookiesService = None, - template_service: TemplateService = None, + template: TemplateService, + helper: GenshinHelper, ): - self.template_service = template_service - self.cookies_service = cookies_service - self.user_service = user_service - self.current_dir = os.getcwd() + self.template_service = template + self.helper = helper - async def _get_daily_note(self, client) -> RenderResult: + async def _get_daily_note(self, client: genshin.Client) -> RenderResult: daily_info = await client.get_genshin_notes(client.uid) - day = datetime.datetime.now().strftime("%m-%d %H:%M") + " 星期" + "一二三四五六日"[datetime.datetime.now().weekday()] + + day = datetime.now().strftime("%m-%d %H:%M") + " 星期" + "一二三四五六日"[datetime.now().weekday()] resin_recovery_time = ( daily_info.resin_recovery_time.strftime("%m-%d %H:%M") if daily_info.max_resin - daily_info.current_resin else None ) realm_recovery_time = ( - (datetime.datetime.now().astimezone() + daily_info.remaining_realm_currency_recovery_time).strftime( - "%m-%d %H:%M" - ) + (datetime.now().astimezone() + daily_info.remaining_realm_currency_recovery_time).strftime("%m-%d %H:%M") if daily_info.max_realm_currency - daily_info.current_realm_currency else None ) @@ -58,13 +51,15 @@ async def _get_daily_note(self, client) -> RenderResult: else: remained_time = i.remaining_time if remained_time: - remained_time = (datetime.datetime.now().astimezone() + remained_time).strftime("%m-%d %H:%M") + remained_time = (datetime.now().astimezone() + remained_time).strftime("%m-%d %H:%M") + transformer, transformer_ready, transformer_recovery_time = False, None, None if daily_info.remaining_transformer_recovery_time is not None: transformer = True transformer_ready = daily_info.remaining_transformer_recovery_time.total_seconds() == 0 transformer_recovery_time = daily_info.transformer_recovery_time.strftime("%m-%d %H:%M") - daily_data = { + + render_data = { "uid": client.uid, "day": day, "resin_recovery_time": resin_recovery_time, @@ -87,38 +82,49 @@ async def _get_daily_note(self, client) -> RenderResult: "transformer_recovery_time": transformer_recovery_time, } render_result = await self.template_service.render( - "genshin/daily_note/daily_note.html", daily_data, {"width": 600, "height": 548}, full_page=False, ttl=8 * 60 + "genshin/daily_note/daily_note.html", + render_data, + {"width": 600, "height": 548}, + full_page=False, + ttl=8 * 60, ) return render_result - @handler(CommandHandler, command="dailynote", block=False) - @handler(MessageHandler, filters=filters.Regex("^当前状态(.*)"), block=False) - @restricts(30) - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> Optional[int]: - user = update.effective_user + @handler.command("dailynote", block=False) + @handler.message(filters.Regex("^当前状态(.*)"), block=False) + async def command_start(self, update: Update, _: CallbackContext) -> Optional[int]: message = update.effective_message - logger.info(f"用户 {user.full_name}[{user.id}] 查询游戏状态命令请求") + user = update.effective_user + logger.info("用户 %s[%s] 每日便签命令请求", user.full_name, user.id) + try: - client = await get_genshin_client(user.id) + # 获取当前用户的 genshin.Client + client = await self.helper.get_genshin_client(user.id) + # 渲染 render_result = await self._get_daily_note(client) - except (UserNotFoundError, CookiesNotFoundError): - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] + except (CookiesNotFoundError, PlayerNotFoundError): + buttons = [ + [ + InlineKeyboardButton( + "点我绑定账号", url=create_deep_linked_url(self.application.bot.username, "set_cookie") + ) + ] + ] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return except DataNotPublic: reply_message = await message.reply_text("查询失败惹,可能是便签功能被禁用了?请尝试通过米游社或者 hoyolab 获取一次便签信息后重试。") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 300) - self._add_delete_message_job(context, message.chat_id, message.message_id, 300) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return ConversationHandler.END + await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) await render_result.reply_photo(message, filename=f"{client.uid}.png", allow_sending_without_reply=True) diff --git a/plugins/genshin/help.py b/plugins/genshin/help.py index cd7edaa1..51703837 100644 --- a/plugins/genshin/help.py +++ b/plugins/genshin/help.py @@ -1,13 +1,13 @@ from telegram import Update from telegram.constants import ChatAction -from telegram.ext import CommandHandler, CallbackContext +from telegram.ext import CallbackContext from core.plugin import Plugin, handler -from core.template import TemplateService -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from core.services.template.services import TemplateService from utils.log import logger +__all__ = ("HelpPlugin",) + class HelpPlugin(Plugin): def __init__(self, template_service: TemplateService = None): @@ -15,17 +15,15 @@ def __init__(self, template_service: TemplateService = None): raise ModuleNotFoundError self.template_service = template_service - @handler(CommandHandler, command="help", block=False) - @error_callable - @restricts() - async def start(self, update: Update, context: CallbackContext): - user = update.effective_user + @handler.command(command="help", block=False) + async def start(self, update: Update, _: CallbackContext): message = update.effective_message + user = update.effective_user logger.info("用户 %s[%s] 发出help命令", user.full_name, user.id) await message.reply_chat_action(ChatAction.TYPING) render_result = await self.template_service.render( "bot/help/help.html", - {"bot_username": context.bot.username}, + {"bot_username": self.application.bot.username}, {"width": 1280, "height": 900}, ttl=30 * 24 * 60 * 60, ) diff --git a/plugins/genshin/help_raw.py b/plugins/genshin/help_raw.py new file mode 100644 index 00000000..91211210 --- /dev/null +++ b/plugins/genshin/help_raw.py @@ -0,0 +1,38 @@ +import os +from typing import Optional + +import aiofiles +from bs4 import BeautifulSoup +from telegram import Update +from telegram.ext import CallbackContext + +from core.plugin import Plugin, handler +from utils.log import logger + +__all__ = ("HelpRawPlugin",) + + +class HelpRawPlugin(Plugin): + def __init__(self): + self.help_raw: Optional[str] = None + + async def initialize(self): + file_path = os.path.join(os.getcwd(), "resources", "bot", "help", "help.html") # resources/bot/help/help.html + async with aiofiles.open(file_path, mode="r", encoding="utf-8") as f: + html_content = await f.read() + soup = BeautifulSoup(html_content, "lxml") + command_div = soup.find_all("div", _class="command") + for div in command_div: + command_name_div = div.find("div", _class="command_name") + if command_name_div: + command_description_div = div.find("div", _class="command-description") + if command_description_div: + self.help_raw += f"/{command_name_div.text} - {command_description_div}" + + @handler.command(command="help_raw", block=False) + async def start(self, update: Update, _: CallbackContext): + if self.help_raw is not None: + message = update.effective_message + user = update.effective_user + logger.info("用户 %s[%s] 发出 help_raw 命令", user.full_name, user.id) + await message.reply_text(self.help_raw, allow_sending_without_reply=True) diff --git a/plugins/genshin/hilichurls.py b/plugins/genshin/hilichurls.py index dadfa82c..3ecd73bc 100644 --- a/plugins/genshin/hilichurls.py +++ b/plugins/genshin/hilichurls.py @@ -1,14 +1,11 @@ -from os import sep +from typing import Dict +from aiofiles import open as async_open from telegram import Update -from telegram.ext import CommandHandler, CallbackContext -from telegram.ext import filters +from telegram.ext import CallbackContext, filters -from core.baseplugin import BasePlugin from core.plugin import Plugin, handler -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from utils.const import RESOURCE_DIR from utils.log import logger try: @@ -17,37 +14,39 @@ except ImportError: import json as jsonlib +__all__ = ("HilichurlsPlugin",) -class HilichurlsPlugin(Plugin, BasePlugin): + +class HilichurlsPlugin(Plugin): """丘丘语字典.""" - def __init__(self): + hilichurls_dictionary: Dict[str, str] + + async def initialize(self) -> None: """加载数据文件.数据整理自 https://wiki.biligame.com/ys By @zhxycn.""" - with open(f"resources{sep}json{sep}hilichurls_dictionary.json", "r", encoding="utf8") as f: - self.hilichurls_dictionary = jsonlib.load(f) + async with async_open(RESOURCE_DIR / "json/hilichurls_dictionary.json", encoding="utf-8") as file: + self.hilichurls_dictionary = jsonlib.loads(await file.read()) - @handler(CommandHandler, command="hilichurls", block=False) - @restricts() - @error_callable + @handler.command(command="hilichurls", block=False) async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) if len(args) >= 1: msg = args[0] else: reply_message = await message.reply_text("请输入要查询的丘丘语。") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return search = str.casefold(msg) # 忽略大小写以方便查询 if search not in self.hilichurls_dictionary: reply_message = await message.reply_text(f"在丘丘语字典中未找到 {msg}。") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return - logger.info(f"用户 {user.full_name}[{user.id}] 查询丘丘语字典命令请求 || 参数 {msg}") + logger.info("用户 %s[%s] 查询今日角色生日列表 查询丘丘语字典命令请求 || 参数 %s", user.full_name, user.id, msg) result = self.hilichurls_dictionary[f"{search}"] await message.reply_markdown_v2(f"丘丘语: `{search}`\n\n`{result}`") diff --git a/plugins/genshin/ledger.py b/plugins/genshin/ledger.py index 1336b69f..68d5a9cf 100644 --- a/plugins/genshin/ledger.py +++ b/plugins/genshin/ledger.py @@ -2,75 +2,38 @@ import re from datetime import datetime, timedelta -from genshin import GenshinException, DataNotPublic, InvalidCookies -from telegram import Update, InlineKeyboardMarkup, InlineKeyboardButton +from genshin import DataNotPublic, InvalidCookies, GenshinException +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.constants import ChatAction -from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters +from telegram.ext import filters, CallbackContext from telegram.helpers import create_deep_linked_url -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService from core.plugin import Plugin, handler -from core.template.services import RenderResult, TemplateService -from core.user.error import UserNotFoundError -from core.user.services import UserService -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from core.services.cookies import CookiesService +from core.services.template.models import RenderResult +from core.services.template.services import TemplateService +from plugins.tools.genshin import CookiesNotFoundError, GenshinHelper, PlayerNotFoundError from utils.log import logger +__all__ = ("LedgerPlugin",) -def get_now() -> datetime: - now = datetime.now() - return (now - timedelta(days=1)) if now.day == 1 and now.hour <= 4 else now - -def check_ledger_month(context: CallbackContext) -> int: - now_time = get_now() - month = now_time.month - args = get_args(context) - if len(args) >= 1: - month = args[0].replace("月", "") - if re_data := re.findall(r"\d+", str(month)): - month = int(re_data[0]) - else: - num_dict = {"一": 1, "二": 2, "三": 3, "四": 4, "五": 5, "六": 6, "七": 7, "八": 8, "九": 9, "十": 10} - month = sum(num_dict.get(i, 0) for i in str(month)) - # check right - allow_month = [now_time.month] - last_month = now_time.replace(day=1) - timedelta(days=1) - allow_month.append(last_month.month) - last_month = last_month.replace(day=1) - timedelta(days=1) - allow_month.append(last_month.month) - - if month in allow_month: - return month - elif isinstance(month, int): - raise IndexError - return now_time.month - - -class Ledger(Plugin, BasePlugin): - """旅行札记""" +class LedgerPlugin(Plugin): + """旅行札记查询""" def __init__( self, - user_service: UserService = None, - cookies_service: CookiesService = None, - template_service: TemplateService = None, + helper: GenshinHelper, + cookies_service: CookiesService, + template_service: TemplateService, ): self.template_service = template_service self.cookies_service = cookies_service - self.user_service = user_service self.current_dir = os.getcwd() + self.helper = helper async def _start_get_ledger(self, client, month=None) -> RenderResult: - try: - diary_info = await client.get_diary(client.uid, month=month) - except GenshinException as error: - raise error + diary_info = await client.get_diary(client.uid, month=month) color = ["#73a9c6", "#d56565", "#70b2b4", "#bd9a5a", "#739970", "#7a6da7", "#597ea0"] categories = [ { @@ -104,25 +67,46 @@ def format_amount(amount: int) -> str: ) return render_result - @handler(CommandHandler, command="ledger", block=False) - @handler(MessageHandler, filters=filters.Regex("^旅行札记查询(.*)"), block=False) - @restricts() - @error_callable + @handler.command(command="ledger", block=False) + @handler.message(filters=filters.Regex("^旅行札记查询(.*)"), block=False) async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message + + now = datetime.now() + now_time = (now - timedelta(days=1)) if now.day == 1 and now.hour <= 4 else now + month = now_time.month try: - month = check_ledger_month(context) + args = self.get_args(context) + if len(args) >= 1: + month = args[0].replace("月", "") + if re_data := re.findall(r"\d+", str(month)): + month = int(re_data[0]) + else: + num_dict = {"一": 1, "二": 2, "三": 3, "四": 4, "五": 5, "六": 6, "七": 7, "八": 8, "九": 9, "十": 10} + month = sum(num_dict.get(i, 0) for i in str(month)) + # check right + allow_month = [now_time.month] + + last_month = now_time.replace(day=1) - timedelta(days=1) + allow_month.append(last_month.month) + + last_month = last_month.replace(day=1) - timedelta(days=1) + allow_month.append(last_month.month) + + if month not in allow_month and isinstance(month, int): + raise IndexError + month = now_time.month except IndexError: reply_message = await message.reply_text("仅可查询最新三月的数据,请重新输入") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return logger.info("用户 %s[%s] 查询旅行札记", user.full_name, user.id) await message.reply_chat_action(ChatAction.TYPING) try: - client = await get_genshin_client(user.id) + client = await self.helper.get_genshin_client(user.id) try: render_result = await self._start_get_ledger(client, month) except InvalidCookies as exc: # 如果抛出InvalidCookies 判断是否真的玄学过期(或权限不足?) @@ -132,26 +116,31 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: ) reply_message = await message.reply_text("出错了呜呜呜 ~ 当前访问令牌无法请求角色数数据,请尝试重新获取Cookie。") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return - except (UserNotFoundError, CookiesNotFoundError): - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] + except (PlayerNotFoundError, CookiesNotFoundError): + buttons = [ + [ + InlineKeyboardButton( + "点我绑定账号", url=create_deep_linked_url(self.application.bot.username, "set_cookie") + ) + ] + ] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return except DataNotPublic: reply_message = await message.reply_text("查询失败惹,可能是旅行札记功能被禁用了?请先通过米游社或者 hoyolab 获取一次旅行札记后重试。") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return except GenshinException as exc: if exc.retcode == -120: diff --git a/plugins/genshin/map.py b/plugins/genshin/map.py index 6a0878e3..c74952de 100644 --- a/plugins/genshin/map.py +++ b/plugins/genshin/map.py @@ -5,21 +5,17 @@ from telegram.constants import ChatAction from telegram.ext import CommandHandler, MessageHandler, filters, CallbackContext, CallbackQueryHandler -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin from core.config import config +from core.dependence.redisdb import RedisDB from core.plugin import handler, Plugin from modules.apihelper.client.components.map import MapHelper, MapException -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger -class Map(Plugin, BasePlugin): +class Map(Plugin): """资源点查询""" - def __init__(self, redis: RedisDB = None): + def __init__(self, redis: RedisDB): self.cache = redis.client self.cache_photo_key = "plugin:map:photo:" self.cache_doc_key = "plugin:map:doc:" @@ -120,8 +116,6 @@ def gen_caption(self, map_id: Union[int, str], name: str) -> str: @handler(CommandHandler, command="map", block=False) @handler(MessageHandler, filters=filters.Regex("^(?P.*)(在哪里|在哪|哪里有|哪儿有|哪有|在哪儿)$"), block=False) @handler(MessageHandler, filters=filters.Regex("^(哪里有|哪儿有|哪有)(?P.*)$"), block=False) - @error_callable - @restricts(restricts_time=20) async def command_start(self, update: Update, context: CallbackContext): message = update.effective_message args = context.args @@ -167,8 +161,6 @@ async def command_start(self, update: Update, context: CallbackContext): self.temp_photo = reply_message.photo[-1].file_id @handler(CallbackQueryHandler, pattern=r"^get_map\|", block=False) - @restricts(restricts_time=3, without_overlapping=True) - @error_callable async def get_maps(self, update: Update, _: CallbackContext) -> None: callback_query = update.callback_query user = callback_query.from_user @@ -192,8 +184,7 @@ async def get_map_callback(callback_query_data: str) -> Tuple[int, str, str]: except MapException as e: await message.reply_text(e.message) - @handler.command("refresh_map") - @bot_admins_rights_check + @handler.command("refresh_map", admin=True) async def refresh_map(self, update: Update, _: CallbackContext): message = update.effective_message msg = await message.reply_text("正在刷新地图数据,请耐心等待...") diff --git a/plugins/genshin/material.py b/plugins/genshin/material.py index 66b076db..d93f3a29 100644 --- a/plugins/genshin/material.py +++ b/plugins/genshin/material.py @@ -1,19 +1,16 @@ -from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update from telegram.constants import ChatAction, ParseMode -from telegram.ext import filters, ConversationHandler, CommandHandler, MessageHandler, CallbackContext +from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters -from core.baseplugin import BasePlugin -from core.game.services import GameMaterialService from core.plugin import Plugin, handler +from core.services.game.services import GameMaterialService from metadata.shortname import roleToName -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import url_to_file from utils.log import logger +__all__ = ("MaterialPlugin",) -class Material(Plugin, BasePlugin): + +class MaterialPlugin(Plugin): """角色培养素材查询""" KEYBOARD = [[InlineKeyboardButton(text="查看角色培养素材列表并查询", switch_inline_query_current_chat="查看角色培养素材列表并查询")]] @@ -23,12 +20,10 @@ def __init__(self, game_material_service: GameMaterialService = None): @handler(CommandHandler, command="material", block=False) @handler(MessageHandler, filters=filters.Regex("^角色培养素材查询(.*)"), block=False) - @restricts(return_data=ConversationHandler.END) - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) if len(args) >= 1: character_name = args[0] else: @@ -36,8 +31,8 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: "请回复你要查询的培养素材的角色名", reply_markup=InlineKeyboardMarkup(self.KEYBOARD) ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return character_name = roleToName(character_name) url = await self.game_material_service.get_material(character_name) @@ -46,12 +41,12 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: f"没有找到 {character_name} 的培养素材", reply_markup=InlineKeyboardMarkup(self.KEYBOARD) ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return - logger.info(f"用户 {user.full_name}[{user.id}] 查询角色培养素材命令请求 || 参数 {character_name}") + logger.info("用户 %s[%s] 查询角色培养素材命令请求 || 参数 %s", user.full_name, user.id, character_name) await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) - file_path = await url_to_file(url, return_path=True) + file_path = await self.download_resource(url, return_path=True) caption = "From 米游社 " f"查看 [原图]({url})" await message.reply_photo( photo=open(file_path, "rb"), diff --git a/plugins/genshin/pay_log.py b/plugins/genshin/pay_log.py index 11da410a..20fa6bd7 100644 --- a/plugins/genshin/pay_log.py +++ b/plugins/genshin/pay_log.py @@ -1,46 +1,39 @@ -import contextlib - import genshin from telegram import Update, User, InlineKeyboardButton, InlineKeyboardMarkup from telegram.constants import ChatAction from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters, ConversationHandler from telegram.helpers import create_deep_linked_url -from core.baseplugin import BasePlugin -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError +from core.basemodel import RegionEnum from core.plugin import Plugin, handler, conversation -from core.template import TemplateService -from core.user import UserService -from core.user.error import UserNotFoundError +from core.services.cookies import CookiesService +from core.services.players.services import PlayersService +from core.services.template.services import TemplateService from modules.gacha_log.helpers import from_url_get_authkey from modules.pay_log.error import PayLogNotFound, PayLogAccountNotFound, PayLogInvalidAuthkey, PayLogAuthkeyTimeout from modules.pay_log.log import PayLog -from utils.bot import get_args -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from plugins.tools.genshin import GenshinHelper, PlayerNotFoundError from utils.genshin import get_authkey_by_stoken -from utils.helpers import get_genshin_client from utils.log import logger -from utils.models.base import RegionEnum INPUT_URL, CONFIRM_DELETE = range(10100, 10102) -class PayLogPlugin(Plugin.Conversation, BasePlugin.Conversation): +class PayLogPlugin(Plugin.Conversation): """充值记录导入/导出/分析""" def __init__( self, - template_service: TemplateService = None, - user_service: UserService = None, - cookie_service: CookiesService = None, + template_service: TemplateService, + players_service: PlayersService, + cookie_service: CookiesService, + helper: GenshinHelper, ): self.template_service = template_service - self.user_service = user_service + self.players_service = players_service self.cookie_service = cookie_service self.pay_log = PayLog() + self.helper = helper async def _refresh_user_data(self, user: User, authkey: str = None) -> str: """刷新用户数据 @@ -50,7 +43,7 @@ async def _refresh_user_data(self, user: User, authkey: str = None) -> str: """ try: logger.debug("尝试获取已绑定的原神账号") - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) new_num = await self.pay_log.get_log_data(user.id, client, authkey) return "更新完成,本次没有新增数据" if new_num == 0 else f"更新完成,本次共新增{new_num}条充值记录" except PayLogNotFound: @@ -61,50 +54,41 @@ async def _refresh_user_data(self, user: User, authkey: str = None) -> str: return "更新数据失败,authkey 无效" except PayLogAuthkeyTimeout: return "更新数据失败,authkey 已经过期" - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) return "派蒙没有找到您所绑定的账号信息,请先私聊派蒙绑定账号" @conversation.entry_point @handler(CommandHandler, command="pay_log_import", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^导入充值记录$") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) logger.info("用户 %s[%s] 导入充值记录命令请求", user.full_name, user.id) authkey = from_url_get_authkey(args[0] if args else "") if not args: - try: - user_info = await self.user_service.get_user_by_id(user.id) - except UserNotFoundError: - user_info = None - if user_info and user_info.region == RegionEnum.HYPERION: - try: - cookies = await self.cookie_service.get_cookies(user_info.user_id, user_info.region) - except CookiesNotFoundError: - cookies = None - if cookies and cookies.cookies and "stoken" in cookies.cookies: + player_info = await self.players_service.get_player(user.id, region=RegionEnum.HYPERION) + if player_info is not None: + cookies = await self.cookie_service.get(user.id, account_id=player_info.account_id) + if cookies is not None and cookies.data and "stoken" in cookies.data: if stuid := next( - (value for key, value in cookies.cookies.items() if key in ["ltuid", "login_uid"]), None + (value for key, value in cookies.data.items() if key in ["ltuid", "login_uid"]), None ): - cookies.cookies["stuid"] = stuid + cookies.data["stuid"] = stuid client = genshin.Client( - cookies=cookies.cookies, + cookies=cookies.data, game=genshin.types.Game.GENSHIN, region=genshin.Region.CHINESE, lang="zh-cn", - uid=user_info.yuanshen_uid, + uid=player_info.player_id, ) - with contextlib.suppress(Exception): - authkey = await get_authkey_by_stoken(client) + authkey = await get_authkey_by_stoken(client) if not authkey: await message.reply_text( "开始导入充值历史记录:请通过 https://paimon.moe/wish/import 获取抽卡记录链接后发送给我" "(非 paimon.moe 导出的文件数据)\n\n" - "> 在绑定 Cookie 时添加 stoken 可能有特殊效果哦(国服)\n" + "> 在绑定 Cookie 时添加 stoken 可能有特殊效果哦(仅限国服)\n" "注意:导入的数据将会与旧数据进行合并。", parse_mode="html", ) @@ -119,15 +103,13 @@ async def command_start(self, update: Update, context: CallbackContext) -> int: return ConversationHandler.END @conversation.state(state=INPUT_URL) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) - @restricts() - @error_callable + @handler.message(filters=~filters.COMMAND, block=False) async def import_data_from_message(self, update: Update, _: CallbackContext) -> int: message = update.effective_message user = update.effective_user - if not message.text: - await message.reply_text("输入错误,请重新输入") - return INPUT_URL + if message.document: + await self.import_from_file(user, message) + return ConversationHandler.END authkey = from_url_get_authkey(message.text) reply = await message.reply_text("小派蒙正在从服务器获取数据,请稍后") await message.reply_chat_action(ChatAction.TYPING) @@ -138,26 +120,15 @@ async def import_data_from_message(self, update: Update, _: CallbackContext) -> @conversation.entry_point @handler(CommandHandler, command="pay_log_delete", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^删除充值记录$") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable - async def command_start_delete(self, update: Update, context: CallbackContext) -> int: + async def command_start_delete(self, update: Update, _: CallbackContext) -> int: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 删除充值记录命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id, need_cookie=False) - context.chat_data["uid"] = client.uid - except UserNotFoundError: + client = await self.helper.get_genshin_client(user.id, need_cookie=False) + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] - if filters.ChatType.GROUPS.filter(message): - reply_message = await message.reply_text( - "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) - ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) - else: - await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) + await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号") return ConversationHandler.END _, status = await self.pay_log.load_history_info(str(user.id), str(client.uid), only_status=True) if not status: @@ -168,8 +139,6 @@ async def command_start_delete(self, update: Update, context: CallbackContext) - @conversation.state(state=CONFIRM_DELETE) @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) - @restricts() - @error_callable async def command_confirm_delete(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user @@ -180,11 +149,10 @@ async def command_confirm_delete(self, update: Update, context: CallbackContext) await message.reply_text("已取消") return ConversationHandler.END - @handler(CommandHandler, command="pay_log_force_delete", block=False) - @bot_admins_rights_check + @handler(CommandHandler, command="pay_log_force_delete", block=False, admin=True) async def command_pay_log_force_delete(self, update: Update, context: CallbackContext): message = update.effective_message - args = get_args(context) + args = self.get_args(context) if not args: await message.reply_text("请指定用户ID") return @@ -192,7 +160,10 @@ async def command_pay_log_force_delete(self, update: Update, context: CallbackCo cid = int(args[0]) if cid < 0: raise ValueError("Invalid cid") - client = await get_genshin_client(cid, need_cookie=False) + client = await self.helper.get_genshin_client(cid, need_cookie=False) + if client is None: + await message.reply_text("该用户暂未绑定账号") + return _, status = await self.pay_log.load_history_info(str(cid), str(client.uid), only_status=True) if not status: await message.reply_text("该用户还没有导入充值记录") @@ -201,21 +172,17 @@ async def command_pay_log_force_delete(self, update: Update, context: CallbackCo await message.reply_text("充值记录已强制删除" if status else "充值记录删除失败") except PayLogNotFound: await message.reply_text("该用户还没有导入充值记录") - except UserNotFoundError: - await message.reply_text("该用户暂未绑定账号") except (ValueError, IndexError): await message.reply_text("用户ID 不合法") @handler(CommandHandler, command="pay_log_export", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^导出充值记录$") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable async def command_start_export(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 导出充值记录命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) await message.reply_chat_action(ChatAction.TYPING) path = self.pay_log.get_file_path(str(user.id), str(client.uid)) if not path.exists(): @@ -229,28 +196,18 @@ async def command_start_export(self, update: Update, context: CallbackContext) - await message.reply_text("派蒙没有找到你的充值记录,快来私聊派蒙导入吧~", reply_markup=InlineKeyboardMarkup(buttons)) except PayLogAccountNotFound: await message.reply_text("导出失败,可能文件包含的祈愿记录所属 uid 与你当前绑定的 uid 不同") - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] - if filters.ChatType.GROUPS.filter(message): - reply_message = await message.reply_text( - "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) - ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) - else: - await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) + await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号") @handler(CommandHandler, command="pay_log", block=False) @handler(MessageHandler, filters=filters.Regex("^充值记录$"), block=False) - @restricts() - @error_callable async def command_start_analysis(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 充值记录统计命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) await message.reply_chat_action(ChatAction.TYPING) data = await self.pay_log.get_analysis(user.id, client) await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) @@ -263,14 +220,15 @@ async def command_start_analysis(self, update: Update, context: CallbackContext) [InlineKeyboardButton("点我导入", url=create_deep_linked_url(context.bot.username, "pay_log_import"))] ] await message.reply_text("派蒙没有找到你的充值记录,快来点击按钮私聊派蒙导入吧~", reply_markup=InlineKeyboardMarkup(buttons)) - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) diff --git a/plugins/genshin/player_cards.py b/plugins/genshin/player_cards.py index 7794a257..9772c80a 100644 --- a/plugins/genshin/player_cards.py +++ b/plugins/genshin/player_cards.py @@ -1,7 +1,6 @@ import math from typing import Any, List, Tuple, Union, Optional -import ujson from enkanetwork import ( CharacterInfo, DigitType, @@ -26,32 +25,29 @@ from telegram.ext import CallbackContext, CallbackQueryHandler, CommandHandler, MessageHandler, filters from telegram.helpers import create_deep_linked_url -from core.base.assets import DEFAULT_EnkaAssets -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin from core.config import config +from core.dependence.assets import DEFAULT_EnkaAssets +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.template import TemplateService -from core.user import UserService -from core.user.error import UserNotFoundError +from core.services.players import PlayersService +from core.services.template.services import TemplateService from metadata.shortname import roleToName from modules.playercards.file import PlayerCardsFile from modules.playercards.helpers import ArtifactStatsTheory -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.enkanetwork import RedisCache -from utils.helpers import url_to_file +from utils.helpers import download_resource from utils.log import logger -from utils.models.base import RegionEnum from utils.patch.aiohttp import AioHttpTimeoutException +try: + import ujson as jsonlib +except ImportError: + import json as jsonlib -class PlayerCards(Plugin, BasePlugin): - def __init__( - self, user_service: UserService = None, template_service: TemplateService = None, redis: RedisDB = None - ): - self.user_service = user_service + +class PlayerCards(Plugin): + def __init__(self, player_service: PlayersService, template_service: TemplateService, redis: RedisDB): + self.player_service = player_service self.client = EnkaNetworkAPI(lang="chs", user_agent=config.enka_network_api_agent, cache=False) self.cache = RedisCache(redis.client, key="plugin:player_cards:enka_network") self.player_cards_file = PlayerCardsFile() @@ -65,7 +61,7 @@ async def _fetch_user(self, uid) -> Union[EnkaNetworkResponse, str]: return EnkaNetworkResponse.parse_obj(data) user = await self.client.http.fetch_user_by_uid(uid) data = user["content"].decode("utf-8", "surrogatepass") # type: ignore - data = ujson.loads(data) + data = jsonlib.loads(data) data = await self.player_cards_file.merge_info(uid, data) await self.cache.set(uid, data) return EnkaNetworkResponse.parse_obj(data) @@ -93,32 +89,25 @@ async def _fetch_user(self, uid) -> Union[EnkaNetworkResponse, str]: @handler(CommandHandler, command="player_card", block=False) @handler(MessageHandler, filters=filters.Regex("^角色卡片查询(.*)"), block=False) - @restricts(restricts_time_of_groups=20, without_overlapping=True) - @error_callable async def player_cards(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message - args = get_args(context) + args = self.get_args(context) await message.reply_chat_action(ChatAction.TYPING) - try: - user_info = await self.user_service.get_user_by_id(user.id) - if user_info.region == RegionEnum.HYPERION: - uid = user_info.yuanshen_uid - else: - uid = user_info.genshin_uid - except UserNotFoundError: + player_info = await self.player_service.get_player(user.id) + if player_info is None: buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return - data = await self._fetch_user(uid) + data = await self._fetch_user(player_info.player_id) if isinstance(data, str): await message.reply_text(data) return @@ -127,10 +116,16 @@ async def player_cards(self, update: Update, context: CallbackContext) -> None: return if len(args) == 1: character_name = roleToName(args[0]) - logger.info(f"用户 {user.full_name}[{user.id}] 角色卡片查询命令请求 || character_name[{character_name}] uid[{uid}]") + logger.info( + "用户 %s[%s] 角色卡片查询命令请求 || character_name[%s] uid[%s]", + user.full_name, + user.id, + character_name, + player_info.player_id, + ) else: - logger.info(f"用户 {user.full_name}[{user.id}] 角色卡片查询命令请求") - buttons = self.gen_button(data, user.id, uid) + logger.info("用户 %s[%s] 角色卡片查询命令请求", user.full_name, user.id) + buttons = self.gen_button(data, user.id, player_info.player_id) if isinstance(self.temp_photo, str): photo = self.temp_photo else: @@ -148,12 +143,12 @@ async def player_cards(self, update: Update, context: CallbackContext) -> None: await message.reply_text(f"角色展柜中未找到 {character_name} ,请检查角色是否存在于角色展柜中,或者等待角色数据更新后重试") return await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) - render_result = await RenderTemplate(uid, characters, self.template_service).render() # pylint: disable=W0631 - await render_result.reply_photo(message, filename=f"player_card_{uid}_{character_name}.png") + render_result = await RenderTemplate( + player_info.player_id, characters, self.template_service + ).render() # pylint: disable=W0631 + await render_result.reply_photo(message, filename=f"player_card_{player_info.player_id}_{character_name}.png") @handler(CallbackQueryHandler, pattern=r"^get_player_card\|", block=False) - @restricts(restricts_time=3, without_overlapping=True) - @error_callable async def get_player_cards(self, update: Update, _: CallbackContext) -> None: callback_query = update.callback_query user = callback_query.from_user @@ -164,7 +159,7 @@ async def get_player_card_callback(callback_query_data: str) -> Tuple[str, int, _user_id = int(_data[1]) _uid = int(_data[2]) _result = _data[3] - logger.debug(f"callback_query_data函数返回 result[{_result}] user_id[{_user_id}] uid[{_uid}]") + logger.debug("callback_query_data函数返回 result[%s] user_id[%s] uid[%s]", _result, _user_id, _uid) return _result, _user_id, _uid result, user_id, uid = await get_player_card_callback(callback_query.data) @@ -421,25 +416,26 @@ async def cache_images(self) -> None: # TODO: 并发下载所有资源 c = self.character # 角色 - c.image.banner.url = await url_to_file(c.image.banner.url) + c.image.banner.url = await download_resource(c.image.banner.url) # 技能 for item in c.skills: - item.icon.url = await url_to_file(item.icon.url) + item.icon.url = await download_resource(item.icon.url) # 命座 for item in c.constellations: - item.icon.url = await url_to_file(item.icon.url) + item.icon.url = await download_resource(item.icon.url) # 装备,包括圣遗物和武器 for item in c.equipments: - item.detail.icon.url = await url_to_file(item.detail.icon.url) + item.detail.icon.url = await download_resource(item.detail.icon.url) - def find_weapon(self) -> Union[Equipments, None]: + def find_weapon(self) -> Optional[Equipments]: """在 equipments 数组中找到武器,equipments 数组包含圣遗物和武器""" for item in self.character.equipments: if item.type == EquipmentsType.WEAPON: return item + return None def find_artifacts(self) -> List[Artifact]: """在 equipments 数组中找到圣遗物,并转换成带有分数的 model。equipments 数组包含圣遗物和武器""" diff --git a/plugins/genshin/quiz.py b/plugins/genshin/quiz.py index 1bdd320f..594b59b5 100644 --- a/plugins/genshin/quiz.py +++ b/plugins/genshin/quiz.py @@ -1,38 +1,36 @@ import random -from telegram import Update, Poll +from telegram import Poll, Update from telegram.constants import ChatAction from telegram.error import BadRequest -from telegram.ext import CallbackContext, CommandHandler, filters +from telegram.ext import filters, CallbackContext -from core.admin import BotAdminService -from core.baseplugin import BasePlugin from core.plugin import Plugin, handler -from core.quiz import QuizService -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from core.services.quiz.services import QuizService +from core.services.users.services import UserService from utils.log import logger +__all__ = ("QuizPlugin",) -class QuizPlugin(Plugin, BasePlugin): + +class QuizPlugin(Plugin): """派蒙的十万个为什么""" - def __init__(self, quiz_service: QuizService = None, bot_admin_service: BotAdminService = None): - self.bot_admin_service = bot_admin_service + def __init__(self, quiz_service: QuizService = None, user_service: UserService = None): + self.user_service = user_service self.quiz_service = quiz_service self.time_out = 120 - @handler(CommandHandler, command="quiz", block=False) - @restricts(restricts_time_of_groups=20) - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> None: - user = update.effective_user + @handler.message(filters=filters.Regex("来一道题")) + @handler.command(command="quiz", block=False) + async def command_start(self, update: Update, _: CallbackContext) -> None: message = update.effective_message - chat = message.chat + user = update.effective_user + chat = update.effective_chat await message.reply_chat_action(ChatAction.TYPING) question_id_list = await self.quiz_service.get_question_id_list() if filters.ChatType.GROUPS.filter(message): - logger.info(f"用户 {user.full_name}[{user.id}] 在群 {chat.title}[{chat.id}] 发送挑战问题命令请求") + logger.info("用户 %s[%s] 在群 %s[%s] 发送挑战问题命令请求", user.full_name, user.id, chat.title, chat.id) if len(question_id_list) == 0: return None if len(question_id_list) == 0: @@ -47,7 +45,7 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: correct_option = answer.text if correct_option is None: question_id = question["question_id"] - logger.warning(f"Quiz模块 correct_option 异常 question_id[{question_id}] ") + logger.warning("Quiz模块 correct_option 异常 question_id[%s]", question_id) return None random.shuffle(_options) index = _options.index(correct_option) @@ -66,5 +64,5 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: else: raise exc if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, message.chat_id, message.message_id, 300) - self._add_delete_message_job(context, poll_message.chat_id, poll_message.message_id, 300) + self.add_delete_message_job(message, delay=300) + self.add_delete_message_job(poll_message, delay=300) diff --git a/plugins/genshin/reg_time.py b/plugins/genshin/reg_time.py index 749d36ff..863e66b0 100644 --- a/plugins/genshin/reg_time.py +++ b/plugins/genshin/reg_time.py @@ -1,26 +1,20 @@ from datetime import datetime - from genshin import Client, GenshinException, InvalidCookies from genshin.client.routes import InternationalRoute # noqa F401 from genshin.utility import recognize_genshin_server, get_ds_headers from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.constants import ParseMode -from telegram.ext import CommandHandler, CallbackContext, MessageHandler +from telegram.ext import CallbackContext from telegram.ext import filters from telegram.helpers import create_deep_linked_url -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.user import UserService -from core.user.error import UserNotFoundError -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from core.services.cookies import CookiesService +from core.services.users.services import UserService +from plugins.tools.genshin import GenshinHelper, PlayerNotFoundError, CookiesNotFoundError from utils.genshin import fetch_hk4e_token_by_cookie, recognize_genshin_game_biz -from utils.helpers import get_genshin_client from utils.log import logger try: @@ -35,19 +29,21 @@ ) -class RegTimePlugin(Plugin, BasePlugin): +class RegTimePlugin(Plugin): """查询原神注册时间""" def __init__( self, user_service: UserService = None, cookie_service: CookiesService = None, + helper: GenshinHelper = None, redis: RedisDB = None, ): self.cache = redis.client self.cache_key = "plugin:reg_time:" self.user_service = user_service self.cookie_service = cookie_service + self.helper = helper @staticmethod async def get_reg_time(client: Client) -> str: @@ -78,16 +74,14 @@ async def get_reg_time_from_cache(self, client: Client) -> str: await self.cache.set(f"{self.cache_key}{client.uid}", reg_time) return reg_time - @handler(CommandHandler, command="reg_time", block=False) - @handler(MessageHandler, filters=filters.Regex("^原神账号注册时间$"), block=False) - @restricts() - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> None: + @handler.command("reg_time", block=False) + @handler.message(filters.Regex(r"^原神账号注册时间$"), block=False) + async def reg_time(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 原神注册时间命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id) + client = await self.helper.get_genshin_client(user.id) game_uid = client.uid try: reg_time = await self.get_reg_time_from_cache(client) @@ -96,11 +90,11 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: logger.warning("用户 %s[%s] 无法请求注册时间 API返回信息为 [%s]%s", user.full_name, user.id, exc.retcode, exc.original) reply_message = await message.reply_text("出错了呜呜呜 ~ 当前访问令牌无法请求角色数数据,") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) return await message.reply_text(f"你的原神账号 [{game_uid}] 注册时间为:{reg_time}") - except (UserNotFoundError, CookiesNotFoundError): + except (PlayerNotFoundError, CookiesNotFoundError): buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_msg = await message.reply_text( @@ -108,8 +102,8 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: reply_markup=InlineKeyboardMarkup(buttons), parse_mode=ParseMode.HTML, ) - self._add_delete_message_job(context, reply_msg.chat_id, reply_msg.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_msg, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text( "此功能需要绑定cookie后使用,请先私聊派蒙进行绑定", diff --git a/plugins/genshin/sign.py b/plugins/genshin/sign.py index 9796ca5e..bcb07464 100644 --- a/plugins/genshin/sign.py +++ b/plugins/genshin/sign.py @@ -1,325 +1,50 @@ -import asyncio import datetime -import random -import time -from json import JSONDecodeError from typing import Optional, Tuple -from genshin import Game, GenshinException, AlreadyClaimed, Client -from genshin.utility import recognize_genshin_server -from httpx import AsyncClient, TimeoutException from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.constants import ChatAction from telegram.ext import CommandHandler, CallbackContext, CallbackQueryHandler from telegram.ext import MessageHandler, filters from telegram.helpers import create_deep_linked_url -from core.admin.services import BotAdminService -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin -from core.bot import bot from core.config import config -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService from core.plugin import Plugin, handler -from core.sign.models import Sign as SignUser, SignStatusEnum -from core.sign.services import SignServices -from core.user.error import UserNotFoundError -from core.user.services import UserService -from modules.apihelper.client.components.verify import Verify -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client +from core.services.sign.models import Sign as SignUser, SignStatusEnum +from core.services.sign.services import SignServices +from core.services.users.services import UserAdminService +from plugins.tools.genshin import GenshinHelper, CookiesNotFoundError, PlayerNotFoundError +from plugins.tools.sign import SignSystem, NeedChallenge from utils.log import logger -class NeedChallenge(Exception): - def __init__(self, uid: int, gt: str = "", challenge: str = ""): - super().__init__() - self.uid = uid - self.gt = gt - self.challenge = challenge - - -class SignSystem: - REFERER = ( - "https://webstatic.mihoyo.com/bbs/event/signin-ys/index.html?" - "bbs_auth_required=true&act_id=e202009291139501&utm_source=bbs&utm_medium=mys&utm_campaign=icon" - ) - - def __init__(self, redis: RedisDB): - self.cache = redis.client - self.qname = "plugin:sign:" - self.verify = Verify() - - async def get_challenge(self, uid: int) -> Tuple[Optional[str], Optional[str]]: - data = await self.cache.get(f"{self.qname}{uid}") - if not data: - return None, None - data = data.decode("utf-8").split("|") - return data[0], data[1] - - async def set_challenge(self, uid: int, gt: str, challenge: str): - await self.cache.set(f"{self.qname}{uid}", f"{gt}|{challenge}") - await self.cache.expire(f"{self.qname}{uid}", 10 * 60) - - async def get_challenge_button( - self, uid: int, user_id: int, gt: Optional[str] = None, challenge: Optional[str] = None, callback: bool = True - ) -> Optional[InlineKeyboardMarkup]: - if not config.pass_challenge_user_web: - return None - if challenge and gt: - await self.set_challenge(uid, gt, challenge) - if not challenge or not gt: - gt, challenge = await self.get_challenge(uid) - if not challenge or not gt: - return None - if callback: - data = f"sign|{user_id}|{uid}" - return InlineKeyboardMarkup([[InlineKeyboardButton("请尽快点我进行手动验证", callback_data=data)]]) - else: - url = f"{config.pass_challenge_user_web}?username={bot.app.bot.username}&command=sign>={gt}&challenge={challenge}&uid={uid}" - return InlineKeyboardMarkup([[InlineKeyboardButton("请尽快点我进行手动验证", url=url)]]) - - async def recognize(self, gt: str, challenge: str, referer: str = None) -> Optional[str]: - if not referer: - referer = self.REFERER - if not gt or not challenge: - return None - pass_challenge_params = { - "gt": gt, - "challenge": challenge, - "referer": referer, - } - if config.pass_challenge_app_key: - pass_challenge_params["appkey"] = config.pass_challenge_app_key - headers = { - "Accept": "*/*", - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/107.0.0.0 Safari/537.36", - } - try: - async with AsyncClient(headers=headers) as client: - resp = await client.post( - config.pass_challenge_api, - params=pass_challenge_params, - timeout=60, - ) - logger.debug("recognize 请求返回:%s", resp.text) - data = resp.json() - status = data.get("status") - if status != 0: - logger.error("recognize 解析错误:[%s]%s", data.get("code"), data.get("msg")) - if data.get("code", 0) != 0: - raise RuntimeError - logger.info("recognize 解析成功") - return data["data"]["validate"] - except JSONDecodeError: - logger.warning("recognize 请求 JSON 解析失败") - except TimeoutException as exc: - logger.warning("recognize 请求超时") - raise exc - except KeyError: - logger.warning("recognize 请求数据错误") - except RuntimeError: - logger.warning("recognize 请求失败") - return None - - async def start_sign( - self, - client: Client, - challenge: Optional[str] = None, - validate: Optional[str] = None, - is_sleep: bool = False, - is_raise: bool = False, - title: Optional[str] = "签到结果", - ) -> str: - if is_sleep: - if recognize_genshin_server(client.uid) in ("cn_gf01", "cn_qd01"): - await asyncio.sleep(random.randint(10, 300)) # nosec - else: - await asyncio.sleep(random.randint(0, 3)) # nosec - try: - rewards = await client.get_monthly_rewards(game=Game.GENSHIN, lang="zh-cn") - except GenshinException as error: - logger.warning("UID[%s] 获取签到信息失败,API返回信息为 %s", client.uid, str(error)) - if is_raise: - raise error - return f"获取签到信息失败,API返回信息为 {str(error)}" - try: - daily_reward_info = await client.get_reward_info(game=Game.GENSHIN, lang="zh-cn") # 获取签到信息失败 - except GenshinException as error: - logger.warning("UID[%s] 获取签到状态失败,API返回信息为 %s", client.uid, str(error)) - if is_raise: - raise error - return f"获取签到状态失败,API返回信息为 {str(error)}" - if not daily_reward_info.signed_in: - try: - if validate: - logger.info("UID[%s] 正在尝试通过验证码\nchallenge[%s]\nvalidate[%s]", client.uid, challenge, validate) - request_daily_reward = await client.request_daily_reward( - "sign", - method="POST", - game=Game.GENSHIN, - lang="zh-cn", - challenge=challenge, - validate=validate, - ) - logger.debug("request_daily_reward 返回 %s", request_daily_reward) - if request_daily_reward and request_daily_reward.get("success", 0) == 1: - # 尝试通过 ajax 请求绕过签到 - gt = request_daily_reward.get("gt", "") - challenge = request_daily_reward.get("challenge", "") - logger.warning("UID[%s] 触发验证码\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) - validate = await self.verify.ajax( - referer=self.REFERER, - gt=gt, - challenge=challenge, - ) - if validate: - logger.success("ajax 通过验证成功\nchallenge[%s]\nvalidate[%s]", challenge, validate) - request_daily_reward = await client.request_daily_reward( - "sign", - method="POST", - game=Game.GENSHIN, - lang="zh-cn", - challenge=challenge, - validate=validate, - ) - logger.debug("request_daily_reward 返回 %s", request_daily_reward) - if request_daily_reward and request_daily_reward.get("success", 0) == 1: - logger.warning("UID[%s] 触发验证码\nchallenge[%s]", client.uid, challenge) - raise NeedChallenge( - uid=client.uid, - gt=request_daily_reward.get("gt", ""), - challenge=request_daily_reward.get("challenge", ""), - ) - elif config.pass_challenge_app_key: - # 如果无法绕过 检查配置文件是否配置识别 API 尝试请求绕过 - # 注意 需要重新获取没有进行任何请求的 Challenge - logger.info("UID[%s] 正在使用 recognize 重新请求签到", client.uid) - _request_daily_reward = await client.request_daily_reward( - "sign", - method="POST", - game=Game.GENSHIN, - lang="zh-cn", - ) - logger.debug("request_daily_reward 返回\n%s", _request_daily_reward) - if _request_daily_reward and _request_daily_reward.get("success", 0) == 1: - _gt = _request_daily_reward.get("gt", "") - _challenge = _request_daily_reward.get("challenge", "") - logger.info("UID[%s] 创建验证码\ngt[%s]\nchallenge[%s]", client.uid, _gt, _challenge) - _validate = await self.recognize(_gt, _challenge) - if _validate: - logger.success("recognize 通过验证成功\nchallenge[%s]\nvalidate[%s]", _challenge, _validate) - request_daily_reward = await client.request_daily_reward( - "sign", - method="POST", - game=Game.GENSHIN, - lang="zh-cn", - challenge=_challenge, - validate=_validate, - ) - if request_daily_reward and request_daily_reward.get("success", 0) == 1: - logger.warning("UID[%s] 触发验证码\nchallenge[%s]", client.uid, _challenge) - gt = request_daily_reward.get("gt", "") - challenge = request_daily_reward.get("challenge", "") - logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) - raise NeedChallenge( - uid=client.uid, - gt=gt, - challenge=challenge, - ) - else: - logger.success("UID[%s] 通过 recognize 签到成功", client.uid) - else: - request_daily_reward = await client.request_daily_reward( - "sign", method="POST", game=Game.GENSHIN, lang="zh-cn" - ) - gt = request_daily_reward.get("gt", "") - challenge = request_daily_reward.get("challenge", "") - logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) - raise NeedChallenge(uid=client.uid, gt=gt, challenge=challenge) - else: - request_daily_reward = await client.request_daily_reward( - "sign", method="POST", game=Game.GENSHIN, lang="zh-cn" - ) - gt = request_daily_reward.get("gt", "") - challenge = request_daily_reward.get("challenge", "") - logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) - raise NeedChallenge(uid=client.uid, gt=gt, challenge=challenge) - else: - logger.success("UID[%s] 签到成功", client.uid) - except TimeoutException as error: - logger.warning("UID[%s] 签到请求超时", client.uid) - if is_raise: - raise error - return "签到失败了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ " - except AlreadyClaimed as error: - logger.warning("UID[%s] 已经签到", client.uid) - if is_raise: - raise error - result = "今天旅行者已经签到过了~" - except GenshinException as error: - logger.warning("UID %s 签到失败,API返回信息为 %s", client.uid, str(error)) - if is_raise: - raise error - return f"获取签到状态失败,API返回信息为 {str(error)}" - else: - result = "OK" - else: - logger.info("UID[%s] 已经签到", client.uid) - result = "今天旅行者已经签到过了~" - logger.info("UID[%s] 签到结果 %s", client.uid, result) - reward = rewards[daily_reward_info.claimed_rewards - (1 if daily_reward_info.signed_in else 0)] - today = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) - cn_timezone = datetime.timezone(datetime.timedelta(hours=8)) - now = datetime.datetime.now(cn_timezone) - missed_days = now.day - daily_reward_info.claimed_rewards - if not daily_reward_info.signed_in: - missed_days -= 1 - message = ( - f"#### {title} ####\n" - f"时间:{today} (UTC+8)\n" - f"UID: {client.uid}\n" - f"今日奖励: {reward.name} × {reward.amount}\n" - f"本月漏签次数:{missed_days}\n" - f"签到结果: {result}" - ) - return message - - -class Sign(Plugin, BasePlugin): +class Sign(Plugin): """每日签到""" CHECK_SERVER, COMMAND_RESULT = range(10400, 10402) def __init__( self, - redis: RedisDB = None, - user_service: UserService = None, - cookies_service: CookiesService = None, - sign_service: SignServices = None, - bot_admin_service: BotAdminService = None, + genshin_helper: GenshinHelper, + sign_service: SignServices, + user_admin_service: UserAdminService, + sign_system: SignSystem, ): - self.bot_admin_service = bot_admin_service - self.cookies_service = cookies_service - self.user_service = user_service + self.user_admin_service = user_admin_service self.sign_service = sign_service - self.system = SignSystem(redis) + self.sign_system = sign_system + self.genshin_helper = genshin_helper async def _process_auto_sign(self, user_id: int, chat_id: int, method: str) -> str: try: - await get_genshin_client(user_id) - except (UserNotFoundError, CookiesNotFoundError): + await self.genshin_helper.get_genshin_client(user_id) + except (PlayerNotFoundError, CookiesNotFoundError): return "未查询到账号信息,请先私聊派蒙绑定账号" user: SignUser = await self.sign_service.get_by_user_id(user_id) if user: if method == "关闭": await self.sign_service.remove(user) return "关闭自动签到成功" - elif method == "开启": + if method == "开启": if user.chat_id == chat_id: return "自动签到已经开启过了" user.chat_id = chat_id @@ -340,18 +65,15 @@ async def _process_auto_sign(self, user_id: int, chat_id: int, method: str) -> s @handler(CommandHandler, command="sign", block=False) @handler(MessageHandler, filters=filters.Regex("^每日签到(.*)"), block=False) - @restricts() - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: user = update.effective_user message = update.effective_message - args = get_args(context) + args = self.get_args(context) validate: Optional[str] = None if len(args) >= 1: msg = None if args[0] == "开启自动签到": - admin_list = await self.bot_admin_service.get_admin_list() - if user.id in admin_list: + if await self.user_admin_service.is_admin(user.id): msg = await self._process_auto_sign(user.id, message.chat_id, "开启") else: msg = await self._process_auto_sign(user.id, user.id, "开启") @@ -360,58 +82,61 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: else: validate = args[0] if msg: - logger.info(f"用户 {user.full_name}[{user.id}] 自动签到命令请求 || 参数 {args[0]}") + logger.info("用户 %s[%s] 自动签到命令请求 || 参数 %s", user.full_name, user.id, args[0]) reply_message = await message.reply_text(msg) if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message.chat_id, delay=30) return - logger.info(f"用户 {user.full_name}[{user.id}] 每日签到命令请求") + logger.info("用户 %s[%s] 每日签到命令请求", user.full_name, user.id) if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, message.chat_id, message.message_id) + self.add_delete_message_job(message) try: - client = await get_genshin_client(user.id) + client = await self.genshin_helper.get_genshin_client(user.id) await message.reply_chat_action(ChatAction.TYPING) - _, challenge = await self.system.get_challenge(client.uid) + _, challenge = await self.sign_system.get_challenge(client.uid) if validate: - _, challenge = await self.system.get_challenge(client.uid) + _, challenge = await self.sign_system.get_challenge(client.uid) if challenge: - sign_text = await self.system.start_sign(client, challenge=challenge, validate=validate) + sign_text = await self.sign_system.start_sign(client, challenge=challenge, validate=validate) else: reply_message = await message.reply_text("请求已经过期", allow_sending_without_reply=True) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(reply_message) return else: - sign_text = await self.system.start_sign(client) + sign_text = await self.sign_system.start_sign(client) reply_message = await message.reply_text(sign_text, allow_sending_without_reply=True) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) - except (UserNotFoundError, CookiesNotFoundError): + self.add_delete_message_job(reply_message) + except (PlayerNotFoundError, CookiesNotFoundError): buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(message.chat_id, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) except NeedChallenge as exc: - button = await self.system.get_challenge_button( - exc.uid, user.id, exc.gt, exc.challenge, not filters.ChatType.PRIVATE.filter(message) + button = await self.sign_system.get_challenge_button( + context.bot.username, + exc.uid, + user.id, + exc.gt, + exc.challenge, + not filters.ChatType.PRIVATE.filter(message), ) reply_message = await message.reply_text( f"UID {exc.uid} 签到失败,触发验证码风控,请尝试点击下方按钮重新签到", allow_sending_without_reply=True, reply_markup=button ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(reply_message) @handler(CallbackQueryHandler, pattern=r"^sign\|", block=False) - @restricts(restricts_time_of_groups=20, without_overlapping=True) - @error_callable - async def sign_gen_link(self, update: Update, _: CallbackContext) -> None: + async def sign_gen_link(self, update: Update, context: CallbackContext) -> None: callback_query = update.callback_query user = callback_query.from_user @@ -419,15 +144,15 @@ async def get_sign_callback(callback_query_data: str) -> Tuple[int, int]: _data = callback_query_data.split("|") _user_id = int(_data[1]) _uid = int(_data[2]) - logger.debug(f"get_sign_callback 函数返回 user_id[{_user_id}] uid[{_uid}]") + logger.debug("get_sign_callback 函数返回 user_id[%s] uid[%s]", _user_id, _uid) return _user_id, _uid user_id, uid = await get_sign_callback(callback_query.data) if user.id != user_id: await callback_query.answer(text="这不是你的按钮!\n" + config.notice.user_mismatch, show_alert=True) return - _, challenge = await self.system.get_challenge(uid) + _, challenge = await self.sign_system.get_challenge(uid) if not challenge: await callback_query.answer(text="验证请求已经过期,请重新发起签到!", show_alert=True) return - await callback_query.answer(url=create_deep_linked_url(bot.app.bot.username, "sign")) + await callback_query.answer(url=create_deep_linked_url(context.bot.username, "sign")) diff --git a/plugins/genshin/userstats.py b/plugins/genshin/stats.py similarity index 69% rename from plugins/genshin/userstats.py rename to plugins/genshin/stats.py index c92a8aed..89a3ae49 100644 --- a/plugins/genshin/userstats.py +++ b/plugins/genshin/stats.py @@ -5,72 +5,66 @@ from genshin.models import GenshinUserStats from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup from telegram.constants import ChatAction -from telegram.ext import ( - CallbackContext, - CommandHandler, - MessageHandler, - filters, -) +from telegram.ext import CallbackContext, filters from telegram.helpers import create_deep_linked_url -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError, TooManyRequestPublicCookies from core.plugin import Plugin, handler -from core.template.models import RenderResult -from core.template.services import TemplateService -from core.user.error import UserNotFoundError -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import url_to_file, get_genshin_client, get_public_genshin_client +from core.services.cookies.error import TooManyRequestPublicCookies +from core.services.template.models import RenderResult +from core.services.template.services import TemplateService +from plugins.tools.genshin import GenshinHelper, PlayerNotFoundError, CookiesNotFoundError from utils.log import logger +__all__ = ("PlayerStatsPlugins",) -class UserStatsPlugins(Plugin, BasePlugin): + +class PlayerStatsPlugins(Plugin): """玩家统计查询""" - def __init__(self, template_service: TemplateService = None): - self.template_service = template_service + def __init__( + self, + template: TemplateService, + helper: GenshinHelper, + ): + self.template_service = template + self.helper = helper - @handler(CommandHandler, command="stats", block=False) - @handler(MessageHandler, filters=filters.Regex("^玩家统计查询(.*)"), block=False) - @restricts() - @error_callable + @handler.command("stats", block=False) + @handler.message(filters.Regex("^玩家统计查询(.*)"), block=False) async def command_start(self, update: Update, context: CallbackContext) -> Optional[int]: user = update.effective_user message = update.effective_message - logger.info(f"用户 {user.full_name}[{user.id}] 查询游戏用户命令请求") + logger.info("用户 %s[%s] 查询游戏用户命令请求", user.full_name, user.id) uid: Optional[int] = None try: args = context.args if args is not None and len(args) >= 1: uid = int(args[0]) except ValueError as exc: - logger.warning(f"获取 uid 发生错误! 错误信息为 {repr(exc)}") + logger.warning("获取 uid 发生错误! 错误信息为 %s", str(exc)) await message.reply_text("输入错误") return try: try: - client = await get_genshin_client(user.id) + client = await self.helper.get_genshin_client(user.id) except CookiesNotFoundError: - client, uid = await get_public_genshin_client(user.id) + client, uid = await self.helper.get_public_genshin_client(user.id) render_result = await self.render(client, uid) - except UserNotFoundError: - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] + except PlayerNotFoundError: + buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) return except GenshinException as exc: - if exc.retcode == 1034: - if uid: - await message.reply_text("出错了呜呜呜 ~ 请稍后重试") - return + if exc.retcode == 1034 and uid: + await message.reply_text("出错了呜呜呜 ~ 请稍后重试") + return raise exc except TooManyRequestPublicCookies: await message.reply_text("用户查询次数过多 请稍后重试") @@ -133,13 +127,12 @@ async def render(self, client: Client, uid: Optional[int] = None) -> RenderResul full_page=True, ) - @staticmethod - async def cache_images(data: GenshinUserStats) -> None: + async def cache_images(self, data: GenshinUserStats) -> None: """缓存所有图片到本地""" # TODO: 并发下载所有资源 # 探索地区 for item in data.explorations: item.__config__.allow_mutation = True - item.icon = await url_to_file(item.icon) - item.cover = await url_to_file(item.cover) + item.icon = await self.download_resource(item.icon) + item.cover = await self.download_resource(item.cover) diff --git a/plugins/genshin/strategy.py b/plugins/genshin/strategy.py index 6424f9de..23f73416 100644 --- a/plugins/genshin/strategy.py +++ b/plugins/genshin/strategy.py @@ -1,24 +1,16 @@ -from telegram import InlineKeyboardButton, InlineKeyboardMarkup -from telegram import Update -from telegram.constants import ChatAction -from telegram.constants import ParseMode -from telegram.ext import CommandHandler, CallbackContext -from telegram.ext import MessageHandler, filters +from telegram import InlineKeyboardButton, InlineKeyboardMarkup, Update +from telegram.constants import ChatAction, ParseMode +from telegram.ext import CallbackContext, filters -from core.baseplugin import BasePlugin -from core.game.services import GameStrategyService from core.plugin import Plugin, handler -from core.search.models import StrategyEntry -from core.search.services import SearchServices +from core.services.game.services import GameStrategyService +from core.services.search.models import StrategyEntry +from core.services.search.services import SearchServices from metadata.shortname import roleToName, roleToTag -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import url_to_file from utils.log import logger -class StrategyPlugin(Plugin, BasePlugin): +class StrategyPlugin(Plugin): """角色攻略查询""" KEYBOARD = [[InlineKeyboardButton(text="查看角色攻略列表并查询", switch_inline_query_current_chat="查看角色攻略列表并查询")]] @@ -31,21 +23,19 @@ def __init__( self.game_strategy_service = game_strategy_service self.search_service = search_service - @handler(CommandHandler, command="strategy", block=False) - @handler(MessageHandler, filters=filters.Regex("^角色攻略查询(.*)"), block=False) - @restricts() - @error_callable + @handler.command(command="strategy", block=False) + @handler.message(filters=filters.Regex("^角色攻略查询(.*)"), block=False) async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) if len(args) >= 1: character_name = args[0] else: reply_message = await message.reply_text("请回复你要查询的攻略的角色名", reply_markup=InlineKeyboardMarkup(self.KEYBOARD)) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return character_name = roleToName(character_name) url = await self.game_strategy_service.get_strategy(character_name) @@ -54,12 +44,12 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: f"没有找到 {character_name} 的攻略", reply_markup=InlineKeyboardMarkup(self.KEYBOARD) ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return - logger.info(f"用户 {user.full_name}[{user.id}] 查询角色攻略命令请求 || 参数 {character_name}") + logger.info("用户 %s[%s] 查询角色攻略命令请求 || 参数 %s", user.full_name, user.id, character_name) await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) - file_path = await url_to_file(url, return_path=True) + file_path = await self.download_resource(url, return_path=True) caption = f"From 米游社 西风驿站 查看原图" reply_photo = await message.reply_photo( photo=open(file_path, "rb"), diff --git a/plugins/genshin/uid.py b/plugins/genshin/uid.py deleted file mode 100644 index a7863f26..00000000 --- a/plugins/genshin/uid.py +++ /dev/null @@ -1,199 +0,0 @@ -from typing import Optional - -import genshin -from genshin import GenshinException, types, DataNotPublic -from telegram import Update, ReplyKeyboardRemove, ReplyKeyboardMarkup, TelegramObject -from telegram.ext import CallbackContext, filters, ConversationHandler -from telegram.helpers import escape_markdown - -from core.baseplugin import BasePlugin -from core.cookies.error import CookiesNotFoundError, TooManyRequestPublicCookies -from core.cookies.services import CookiesService, PublicCookiesService -from core.plugin import Plugin, handler, conversation -from core.user.error import UserNotFoundError -from core.user.models import User -from core.user.services import UserService -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.log import logger -from utils.models.base import RegionEnum - - -class SetUserUidCommandData(TelegramObject): - user: Optional[User] = None - region: RegionEnum = RegionEnum.HYPERION - game_uid: int = 0 - - -CHECK_SERVER, CHECK_UID, COMMAND_RESULT = range(10100, 10103) - - -class SetUserUid(Plugin.Conversation, BasePlugin.Conversation): - """UID用户绑定""" - - def __init__( - self, - user_service: UserService = None, - cookies_service: CookiesService = None, - public_cookies_service: PublicCookiesService = None, - ): - self.public_cookies_service = public_cookies_service - self.cookies_service = cookies_service - self.user_service = user_service - - @conversation.entry_point - @handler.command(command="setuid", filters=filters.ChatType.PRIVATE, block=True) - @restricts() - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - logger.info(f"用户 {user.full_name}[{user.id}] 绑定账号命令请求") - set_user_uid_command_data: SetUserUidCommandData = context.chat_data.get("set_user_uid_command_data") - if set_user_uid_command_data is None: - cookies_command_data = SetUserUidCommandData() - context.chat_data["set_user_uid_command_data"] = cookies_command_data - text = ( - f"你好 {user.mention_markdown_v2()} " - f'{escape_markdown("!请输入通行证UID(非游戏UID),BOT将会通过通行证UID查找游戏UID。请选择要绑定的服务器!或回复退出取消操作")}' - ) - reply_keyboard = [["米游社", "HoYoLab"], ["退出"]] - await message.reply_markdown_v2(text, reply_markup=ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)) - return CHECK_SERVER - - @conversation.state(state=CHECK_SERVER) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable - async def check_server(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - set_user_uid_command_data: SetUserUidCommandData = context.chat_data.get("set_user_uid_command_data") - if message.text == "退出": - await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - elif message.text == "米游社": - region = set_user_uid_command_data.region = RegionEnum.HYPERION - elif message.text == "HoYoLab": - region = set_user_uid_command_data.region = RegionEnum.HOYOLAB - else: - await message.reply_text("选择错误,请重新选择") - return CHECK_SERVER - try: - user_info = await self.user_service.get_user_by_id(user.id) - set_user_uid_command_data.user = user_info - except UserNotFoundError: - set_user_uid_command_data.user = None - user_info = None - if user_info is not None: - try: - await self.cookies_service.get_cookies(user.id, region) - except CookiesNotFoundError: - pass - else: - await message.reply_text("你已经通过 Cookie 绑定了账号,无法继续下一步") - return ConversationHandler.END - await message.reply_text("请输入你的通行证UID(非游戏UID)", reply_markup=ReplyKeyboardRemove()) - return CHECK_UID - - @conversation.state(state=CHECK_UID) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable - async def check_cookies(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - set_user_uid_command_data: SetUserUidCommandData = context.chat_data.get("set_user_uid_command_data") - region = set_user_uid_command_data.region - if message.text == "退出": - await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - try: - hoyolab_uid = int(message.text) - except ValueError: - await message.reply_text("UID 格式有误,请检查", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - try: - cookies = await self.public_cookies_service.get_cookies(user.id, region) - except TooManyRequestPublicCookies: - await message.reply_text("用户查询次数过多,请稍后重试", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - if region == RegionEnum.HYPERION: - client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE) - elif region == RegionEnum.HOYOLAB: - client = genshin.Client( - cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" - ) - else: - return ConversationHandler.END - try: - user_info = (await client.get_record_cards(hoyolab_uid))[0] - except DataNotPublic: - await message.reply_text("角色未公开", reply_markup=ReplyKeyboardRemove()) - logger.warning(f"获取账号信息发生错误 hoyolab_uid[{hoyolab_uid}] 账户信息未公开") - return ConversationHandler.END - except GenshinException as exc: - await message.reply_text("获取账号信息发生错误", reply_markup=ReplyKeyboardRemove()) - logger.error("获取账号信息发生错误") - logger.exception(exc) - return ConversationHandler.END - if user_info.game != types.Game.GENSHIN: - await message.reply_text("角色信息查询返回非原神游戏信息," "请设置展示主界面为原神", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - reply_keyboard = [["确认", "退出"]] - await message.reply_text("获取角色基础信息成功,请检查是否正确!") - logger.info(f"用户 {user.full_name}[{user.id}] 获取账号 {user_info.nickname}[{user_info.uid}] 信息成功") - text = ( - f"*角色信息*\n" - f"角色名称:{escape_markdown(user_info.nickname, version=2)}\n" - f"角色等级:{user_info.level}\n" - f"UID:`{user_info.uid}`\n" - f"服务器名称:`{user_info.server_name}`\n" - ) - set_user_uid_command_data.game_uid = user_info.uid - await message.reply_markdown_v2(text, reply_markup=ReplyKeyboardMarkup(reply_keyboard, one_time_keyboard=True)) - return COMMAND_RESULT - - @conversation.state(state=COMMAND_RESULT) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable - async def command_result(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - set_user_uid_command_data: SetUserUidCommandData = context.chat_data.get("set_user_uid_command_data") - if message.text == "退出": - await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - elif message.text == "确认": - if set_user_uid_command_data.user is None: - if set_user_uid_command_data.region == RegionEnum.HYPERION: - user_db = User( - user_id=user.id, - yuanshen_uid=set_user_uid_command_data.game_uid, - region=set_user_uid_command_data.region, - ) - elif set_user_uid_command_data.region == RegionEnum.HOYOLAB: - user_db = User( - user_id=user.id, - genshin_uid=set_user_uid_command_data.game_uid, - region=set_user_uid_command_data.region, - ) - else: - await message.reply_text("数据错误") - return ConversationHandler.END - await self.user_service.add_user(user_db) - else: - user_db = set_user_uid_command_data.user - user_db.region = set_user_uid_command_data.region - if set_user_uid_command_data.region == RegionEnum.HYPERION: - user_db.yuanshen_uid = set_user_uid_command_data.game_uid - elif set_user_uid_command_data.region == RegionEnum.HOYOLAB: - user_db.genshin_uid = set_user_uid_command_data.game_uid - else: - await message.reply_text("数据错误") - return ConversationHandler.END - await self.user_service.update_user(user_db) - logger.info(f"用户 {user.full_name}[{user.id}] 绑定UID账号成功") - await message.reply_text("保存成功", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - else: - await message.reply_text("回复错误,请重新输入") - return COMMAND_RESULT diff --git a/plugins/genshin/user.py b/plugins/genshin/user.py deleted file mode 100644 index b7159ebe..00000000 --- a/plugins/genshin/user.py +++ /dev/null @@ -1,121 +0,0 @@ -from typing import Optional - -from telegram import Update, TelegramObject, User, ReplyKeyboardRemove -from telegram.ext import CallbackContext, filters, ConversationHandler - -from core.baseplugin import BasePlugin -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError -from core.cookies.models import Cookies -from core.plugin import Plugin, handler, conversation -from core.sign import SignServices -from core.user import UserService -from core.user.error import UserNotFoundError -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.log import logger -from utils.models.base import RegionEnum - - -class DelUserCommandData(TelegramObject): - user: Optional[User] = None - region: RegionEnum = RegionEnum.HYPERION - cookies: Optional[Cookies] = None - - -CHECK_SERVER, DEL_USER = range(10800, 10802) - - -class UserPlugin(Plugin.Conversation, BasePlugin.Conversation): - def __init__( - self, - user_service: UserService = None, - cookies_service: CookiesService = None, - sign_service: SignServices = None, - ): - self.cookies_service = cookies_service - self.user_service = user_service - self.sign_service = sign_service - - @conversation.entry_point - @handler.command(command="deluser", filters=filters.ChatType.PRIVATE, block=True) - @restricts() - @error_callable - async def command_start(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - logger.info("用户 %s[%s] 删除账号命令请求", user.full_name, user.id) - del_user_command_data: DelUserCommandData = context.chat_data.get("del_user_command_data") - if del_user_command_data is None: - del_user_command_data = DelUserCommandData() - context.chat_data["del_user_command_data"] = del_user_command_data - try: - user_info = await self.user_service.get_user_by_id(user.id) - del_user_command_data.user = user_info - except UserNotFoundError: - await message.reply_text("用户未找到") - return ConversationHandler.END - cookies_status: bool = False - try: - cookies = await self.cookies_service.get_cookies(user.id, user_info.region) - del_user_command_data.cookies = cookies - cookies_status = True - except CookiesNotFoundError: - logger.info("用户 %s[%s] Cookies 不存在", user.full_name, user.id) - if user_info.region == RegionEnum.HYPERION: - uid = user_info.yuanshen_uid - region_str = "米游社" - del_user_command_data.region = RegionEnum.HYPERION - elif user_info.region == RegionEnum.HOYOLAB: - uid = user_info.genshin_uid - region_str = "HoYoLab" - del_user_command_data.region = RegionEnum.HOYOLAB - else: - await message.reply_text("数据非法") - return ConversationHandler.END - await message.reply_text("获取用户信息成功") - text = ( - f"绑定信息\n" - f"UID:{uid}\n" - f"注册:{region_str}\n" - f"是否绑定Cookie:{'√' if cookies_status else '×'}" - ) - await message.reply_html(text) - await message.reply_html("请回复确认即可解除绑定并从数据库移除,如绑定Cookies也会跟着一起从数据库删除,删除后操作无法逆转,回复 /cancel 可退出操作") - return DEL_USER - - @conversation.state(state=DEL_USER) - @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=True) - @error_callable - async def command_result(self, update: Update, context: CallbackContext) -> int: - user = update.effective_user - message = update.effective_message - if message.text == "退出": - await message.reply_text("退出任务", reply_markup=ReplyKeyboardRemove()) - return ConversationHandler.END - elif message.text == "确认": - del_user_command_data: DelUserCommandData = context.chat_data.get("del_user_command_data") - sign = await self.sign_service.get_by_user_id(user.id) - if sign: - await self.sign_service.remove(sign) - logger.success("用户 %s[%s] 从数据库删除定时签到成功", user.full_name, user.id) - try: - await self.user_service.del_user_by_id(user.id) - except UserNotFoundError: - await message.reply_text("用户未找到") - return ConversationHandler.END - else: - logger.success("用户 %s[%s] 从数据库删除账号成功", user.full_name, user.id) - cookies = del_user_command_data.cookies - if cookies: - try: - await self.cookies_service.del_cookies(user.id, del_user_command_data.region) - except CookiesNotFoundError: - logger.warning("用户 %s[%s] Cookies 不存在", user.full_name, user.id) - else: - logger.success("用户 %s[%s] 从数据库删除Cookies成功", user.full_name, user.id) - await message.reply_text("删除成功") - return ConversationHandler.END - else: - await message.reply_text("回复错误,退出当前会话") - return ConversationHandler.END diff --git a/plugins/genshin/verification.py b/plugins/genshin/verification.py deleted file mode 100644 index a8d19576..00000000 --- a/plugins/genshin/verification.py +++ /dev/null @@ -1,96 +0,0 @@ -from typing import Tuple, Optional - -from genshin import Region, GenshinException -from telegram import Update, WebAppInfo, KeyboardButton, ReplyKeyboardMarkup -from telegram.ext import CallbackContext, filters - -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin -from core.config import config -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError -from core.plugin import Plugin, handler -from core.user import UserService -from core.user.error import UserNotFoundError -from modules.apihelper.client.components.verify import Verify -from modules.apihelper.error import ResponseException -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client -from utils.log import logger - - -class VerificationSystem: - def __init__(self, redis: RedisDB = None): - self.cache = redis.client - self.qname = "plugin:verification:" - - async def get_challenge(self, uid: int) -> Tuple[Optional[str], Optional[str]]: - data = await self.cache.get(f"{self.qname}{uid}") - if not data: - return None, None - data = data.decode("utf-8").split("|") - return data[0], data[1] - - async def set_challenge(self, uid: int, gt: str, challenge: str): - await self.cache.set(f"{self.qname}{uid}", f"{gt}|{challenge}") - await self.cache.expire(f"{self.qname}{uid}", 10 * 60) - - -class VerificationPlugins(Plugin, BasePlugin): - def __init__(self, user_service: UserService = None, cookies_service: CookiesService = None, redis: RedisDB = None): - self.cookies_service = cookies_service - self.user_service = user_service - self.system = VerificationSystem(redis) - - @handler.command("verify", filters=filters.ChatType.PRIVATE, block=False) - @restricts(restricts_time=60) - @error_callable - async def verify(self, update: Update, context: CallbackContext) -> None: - user = update.effective_user - message = update.effective_message - logger.info("用户 %s[%s] 发出verify命令", user.full_name, user.id) - try: - client = await get_genshin_client(user.id) - if client.region != Region.CHINESE: - await message.reply_text("非法用户") - return - except UserNotFoundError: - await message.reply_text("用户未找到") - return - except CookiesNotFoundError: - await message.reply_text("检测到用户为UID绑定,无需认证") - return - is_high: bool = False - verification = Verify(cookies=client.cookie_manager.cookies) - if not context.args: - try: - await client.get_genshin_notes() - except GenshinException as exc: - if exc.retcode == 1034: - is_high = True - else: - raise exc - else: - await message.reply_text("账户正常,无需认证") - return - try: - data = await verification.create(is_high=is_high) - challenge = data["challenge"] - gt = data["gt"] - logger.success("用户 %s[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", user.full_name, user.id, gt, challenge) - except ResponseException as exc: - logger.warning("用户 %s[%s] 创建验证失效 API返回 [%s]%s", user.full_name, user.id, exc.code, exc.message) - await message.reply_text(f"创建验证失败 错误信息为 [{exc.code}]{exc.message} 请稍后重试") - return - await self.system.set_challenge(client.uid, gt, challenge) - url = f"{config.pass_challenge_user_web}/webapp?username={context.bot.username}&command=verify>={gt}&challenge={challenge}&uid={client.uid}" - await message.reply_text( - "请尽快在10秒内完成手动验证\n或发送 /web_cancel 取消操作", - reply_markup=ReplyKeyboardMarkup.from_button( - KeyboardButton( - text="点我手动验证", - web_app=WebAppInfo(url=url), - ) - ), - ) diff --git a/plugins/genshin/verify.py b/plugins/genshin/verify.py new file mode 100644 index 00000000..494e37e5 --- /dev/null +++ b/plugins/genshin/verify.py @@ -0,0 +1,39 @@ +from telegram import KeyboardButton, ReplyKeyboardMarkup, Update, WebAppInfo +from telegram.ext import CallbackContext, filters + +from core.config import config +from core.plugin import Plugin, handler +from plugins.tools.challenge import ChallengeSystem, ChallengeSystemException +from utils.log import logger + + +class VerificationPlugins(Plugin): + def __init__( + self, + challenge_system: ChallengeSystem, + ): + self.challenge_system = challenge_system + + @handler.command("verify", filters=filters.ChatType.PRIVATE, block=False) + async def verify(self, update: Update, context: CallbackContext) -> None: + user = update.effective_user + message = update.effective_message + logger.info("用户 %s[%s] 发出verify命令", user.full_name, user.id) + try: + uid, gt, challenge = await self.challenge_system.create_challenge(user.id, context.args is not None) + except ChallengeSystemException as exc: + await message.reply_text(exc.message) + return + url = ( + f"{config.pass_challenge_user_web}/webapp?" + f"username={context.bot.username}&command=verify>={gt}&challenge={challenge}&uid={uid}" + ) + await message.reply_text( + "请尽快在10秒内完成手动验证\n或发送 /web_cancel 取消操作", + reply_markup=ReplyKeyboardMarkup.from_button( + KeyboardButton( + text="点我手动验证", + web_app=WebAppInfo(url=url), + ) + ), + ) diff --git a/plugins/genshin/weapon.py b/plugins/genshin/weapon.py index 0629d9c0..70c333d4 100644 --- a/plugins/genshin/weapon.py +++ b/plugins/genshin/weapon.py @@ -2,24 +2,19 @@ from telegram.constants import ChatAction from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters -from core.base.assets import AssetsService, AssetsCouldNotFound -from core.baseplugin import BasePlugin +from core.dependence.assets import AssetsCouldNotFound, AssetsService from core.plugin import Plugin, handler -from core.search.models import WeaponEntry -from core.search.services import SearchServices -from core.template import TemplateService -from core.wiki.services import WikiService +from core.services.search.models import WeaponEntry +from core.services.search.services import SearchServices +from core.services.template.services import TemplateService +from core.services.wiki.services import WikiService from metadata.genshin import honey_id_to_game_id from metadata.shortname import weaponToName, weapons as _weapons_data from modules.wiki.weapon import Weapon -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts -from utils.helpers import url_to_file from utils.log import logger -class WeaponPlugin(Plugin, BasePlugin): +class WeaponPlugin(Plugin): """武器查询""" KEYBOARD = [[InlineKeyboardButton(text="查看武器列表并查询", switch_inline_query_current_chat="查看武器列表并查询")]] @@ -38,22 +33,20 @@ def __init__( @handler(CommandHandler, command="weapon", block=False) @handler(MessageHandler, filters=filters.Regex("^武器查询(.*)"), block=False) - @error_callable - @restricts() async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) if len(args) >= 1: weapon_name = args[0] else: reply_message = await message.reply_text("请回复你要查询的武器", reply_markup=InlineKeyboardMarkup(self.KEYBOARD)) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return weapon_name = weaponToName(weapon_name) - logger.info(f"用户 {user.full_name}[{user.id}] 查询武器命令请求 || 参数 weapon_name={weapon_name}") + logger.info("用户 %s[%s] 查询角色攻略命令请求 weapon_name[%s]", user.full_name, user.id, weapon_name) weapons_list = await self.wiki_service.get_weapons_list() for weapon in weapons_list: if weapon.name == weapon_name: @@ -64,8 +57,8 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: f"没有找到 {weapon_name}", reply_markup=InlineKeyboardMarkup(self.KEYBOARD) ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return await message.reply_chat_action(ChatAction.TYPING) @@ -78,7 +71,7 @@ async def input_template_data(_weapon_data: Weapon): bonus = str(round(float(bonus))) _template_data = { "weapon_name": _weapon_data.name, - "weapon_info_type_img": await url_to_file(_weapon_data.weapon_type.icon_url()), + "weapon_info_type_img": await self.download_resource(_weapon_data.weapon_type.icon_url()), "progression_secondary_stat_value": bonus, "progression_secondary_stat_name": _weapon_data.attribute.type.value, "weapon_info_source_img": ( @@ -96,7 +89,7 @@ async def input_template_data(_weapon_data: Weapon): else: _template_data = { "weapon_name": _weapon_data.name, - "weapon_info_type_img": await url_to_file(_weapon_data.weapon_type.icon_url()), + "weapon_info_type_img": await self.download_resource(_weapon_data.weapon_type.icon_url()), "progression_secondary_stat_value": " ", "progression_secondary_stat_name": "无其它属性加成", "weapon_info_source_img": ( @@ -119,8 +112,8 @@ async def input_template_data(_weapon_data: Weapon): logger.warning("%s weapon_name[%s]", exc.message, weapon_name) reply_message = await message.reply_text(f"数据库中没有找到 {weapon_name}") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id) + self.add_delete_message_job(message) + self.add_delete_message_job(reply_message) return png_data = await self.template_service.render( "genshin/weapon/weapon.html", template_data, {"width": 540, "height": 540}, ttl=31 * 24 * 60 * 60 diff --git a/plugins/genshin/wiki.py b/plugins/genshin/wiki.py deleted file mode 100644 index b2703b32..00000000 --- a/plugins/genshin/wiki.py +++ /dev/null @@ -1,21 +0,0 @@ -from telegram import Update -from telegram.ext import CommandHandler, CallbackContext - -from core.plugin import Plugin, handler -from core.wiki.services import WikiService -from utils.decorators.admins import bot_admins_rights_check - - -class Wiki(Plugin): - """有关WIKI操作""" - - def __init__(self, wiki_service: WikiService = None): - self.wiki_service = wiki_service - - @handler(CommandHandler, command="refresh_wiki", block=False) - @bot_admins_rights_check - async def refresh_wiki(self, update: Update, _: CallbackContext): - message = update.effective_message - await message.reply_text("正在刷新Wiki缓存,请稍等") - await self.wiki_service.refresh_wiki() - await message.reply_text("刷新Wiki缓存成功") diff --git a/plugins/genshin/gacha/gacha.py b/plugins/genshin/wish.py similarity index 84% rename from plugins/genshin/gacha/gacha.py rename to plugins/genshin/wish.py index a072ef67..6ef2f363 100644 --- a/plugins/genshin/gacha/gacha.py +++ b/plugins/genshin/wish.py @@ -3,17 +3,15 @@ from datetime import datetime from typing import Any, List, Optional, Tuple, Union -import ujson as json from bs4 import BeautifulSoup from telegram import Update from telegram.constants import ChatAction from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters -from core.base.assets import AssetsService -from core.base.redisdb import RedisDB -from core.baseplugin import BasePlugin +from core.dependence.assets import AssetsService +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.template import TemplateService +from core.services.template.services import TemplateService from metadata.genshin import AVATAR_DATA, WEAPON_DATA, avatar_to_game_id, weapon_to_game_id from metadata.shortname import weaponToName from modules.apihelper.client.components.gacha import Gacha as GachaClient @@ -21,11 +19,14 @@ from modules.gacha.banner import BannerType, GachaBanner from modules.gacha.player.info import PlayerGachaInfo from modules.gacha.system import BannerSystem -from utils.bot import get_args -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger +try: + import ujson as jsonlib + +except ImportError: + import json as jsonlib + class GachaNotFound(Exception): """卡池未找到""" @@ -52,14 +53,14 @@ async def get(self, user_id: int) -> PlayerGachaInfo: data = await self.client.get(f"{self.qname}{user_id}") if data is None: return PlayerGachaInfo() - return PlayerGachaInfo(**json.loads(data)) + return PlayerGachaInfo(**jsonlib.loads(data)) async def set(self, user_id: int, player_gacha_info: PlayerGachaInfo): value = player_gacha_info.json() await self.client.set(f"{self.qname}{user_id}", value) -class GachaHandle: +class WishSimulatorHandle: def __init__(self): self.hyperion = GachaClient() @@ -119,8 +120,7 @@ async def gacha_base_info(self, gacha_name: str = "角色活动", default: bool else: # pylint: disable=W0120 if default and len(gacha_list_info) > 0: return gacha_list_info[0] - else: - raise GachaNotFound(gacha_name) + raise GachaNotFound(gacha_name) @staticmethod def de_title(title: str) -> Union[Tuple[str, None], Tuple[str, Any]]: @@ -134,12 +134,12 @@ def de_title(title: str) -> Union[Tuple[str, None], Tuple[str, Any]]: return title_html.text, title_html.p -class Gacha(Plugin, BasePlugin): +class WishSimulatorPlugin(Plugin): """抽卡模拟器(非首模拟器/减寿模拟器)""" - def __init__(self, assets: AssetsService = None, template_service: TemplateService = None, redis: RedisDB = None): + def __init__(self, assets: AssetsService, template_service: TemplateService, redis: RedisDB): self.gacha_db = GachaRedis(redis) - self.handle = GachaHandle() + self.handle = WishSimulatorHandle() self.banner_system = BannerSystem() self.template_service = template_service self.banner_cache = {} @@ -157,6 +157,8 @@ async def get_banner(self, gacha_base_info: GachaInfo): async def de_item_list(self, item_list: List[int]) -> List[dict]: gacha_item: List[dict] = [] for item_id in item_list: + if item_id is None: + continue if 10000 <= item_id <= 100000: data = WEAPON_DATA.get(str(item_id)) avatar = self.assets_service.weapon(item_id) @@ -175,15 +177,31 @@ async def de_item_list(self, item_list: List[int]) -> List[dict]: gacha_item.append(data) return gacha_item - @handler(CommandHandler, command="gacha", block=False) + async def shutdown(self) -> None: + pass + # todo 目前清理消息无法执行 因为先停止Job导致无法获取全部信息 + # logger.info("正在清理消息") + # job_queue = self.application.telegram.job_queue + # jobs = job_queue.jobs() + # for job in jobs: + # if "wish_simulator" in job.name and not job.removed: + # logger.info("当前Job name %s", job.name) + # try: + # await job.run(job_queue.application) + # except CancelledError: + # continue + # except Exception as exc: + # logger.warning("执行失败 %", str(exc)) + # else: + # logger.info("Jobs为空") + # logger.success("清理卡池消息成功") + @handler(CommandHandler, command="wish", block=False) @handler(MessageHandler, filters=filters.Regex("^抽卡模拟器(.*)"), block=False) - @restricts(restricts_time=3, restricts_time_of_groups=20) - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) gacha_name = "角色活动" if len(args) >= 1: gacha_name = args[0] @@ -222,8 +240,8 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: logger.warning("角色 item_id[%s] 抽卡立绘未找到", exc.item_id) reply_message = await message.reply_text("出错了呜呜呜 ~ 卡池部分数据未找到!") if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 60) - self._add_delete_message_job(context, message.chat_id, message.message_id, 60) + self.add_delete_message_job(reply_message, name="wish_simulator") + self.add_delete_message_job(message.chat_id, name="wish_simulator") return player_gacha_banner_info = player_gacha_info.get_banner_info(banner) template_data = { @@ -235,10 +253,6 @@ async def command_start(self, update: Update, context: CallbackContext) -> None: "items": [], "wish_name": "", } - # logger.debug(f"{banner.banner_id}") - # logger.debug(f"{banner.banner_type}") - # logger.debug(f"{banner.rate_up_items5}") - # logger.debug(f"{banner.fallback_items5_pool1}") if player_gacha_banner_info.wish_item_id != 0: weapon = WEAPON_DATA.get(str(player_gacha_banner_info.wish_item_id)) if weapon is not None: @@ -257,24 +271,22 @@ def take_rang(elem: dict): reply_message = await message.reply_photo(png_data.photo) if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 300) - self._add_delete_message_job(context, message.chat_id, message.message_id, 300) + self.add_delete_message_job(reply_message, name="wish_simulator") + self.add_delete_message_job(message, name="wish_simulator") @handler(CommandHandler, command="set_wish", block=False) @handler(MessageHandler, filters=filters.Regex("^非首模拟器定轨(.*)"), block=False) - @restricts(restricts_time=3, restricts_time_of_groups=20) - @error_callable async def set_wish(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) try: gacha_base_info = await self.handle.gacha_base_info("武器活动") except GachaNotFound: reply_message = await message.reply_text("当前还没有武器正在 UP,可能是卡池不存在或者卡池已经结束。") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 10) + self.add_delete_message_job(message, delay=30) + self.add_delete_message_job(reply_message, delay=30) return banner = await self.get_banner(gacha_base_info) up_weapons = {} @@ -289,8 +301,8 @@ async def set_wish(self, update: Update, context: CallbackContext) -> None: else: reply_message = await message.reply_text(f"输入的参数不正确,请输入需要定轨的武器名称。\n{up_weapons_text}") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 10) + self.add_delete_message_job(message, delay=30) + self.add_delete_message_job(reply_message, delay=30) return weapon_name = weaponToName(weapon_name) player_gacha_info = await self.gacha_db.get(user.id) @@ -302,12 +314,11 @@ async def set_wish(self, update: Update, context: CallbackContext) -> None: f"输入的参数不正确,可能是没有名为 {weapon_name} 的武器或该武器不存在当前 UP 卡池中\n{up_weapons_text}" ) if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 10) + self.add_delete_message_job(message, delay=30) + self.add_delete_message_job(reply_message, delay=30) return await self.gacha_db.set(user.id, player_gacha_info) reply_message = await message.reply_text(f"抽卡模拟器定轨 {weapon_name} 武器成功") if filters.ChatType.GROUPS.filter(reply_message): - self._add_delete_message_job(context, message.chat_id, message.message_id, 10) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 10) - return + self.add_delete_message_job(message, delay=30) + self.add_delete_message_job(reply_message, delay=30) diff --git a/plugins/genshin/gacha/gacha_log.py b/plugins/genshin/wish_log.py similarity index 74% rename from plugins/genshin/gacha/gacha_log.py rename to plugins/genshin/wish_log.py index c69495d0..8f2049a4 100644 --- a/plugins/genshin/gacha/gacha_log.py +++ b/plugins/genshin/wish_log.py @@ -1,44 +1,35 @@ -import contextlib from io import BytesIO import genshin from aiofiles import open as async_open from genshin.models import BannerType -from telegram import Update, User, Message, Document, InlineKeyboardButton, InlineKeyboardMarkup +from telegram import Document, InlineKeyboardButton, InlineKeyboardMarkup, Message, Update, User from telegram.constants import ChatAction -from telegram.ext import CallbackContext, CommandHandler, MessageHandler, filters, ConversationHandler +from telegram.ext import CallbackContext, CommandHandler, ConversationHandler, MessageHandler, filters from telegram.helpers import create_deep_linked_url -from core.base.assets import AssetsService -from core.baseplugin import BasePlugin -from core.config import config -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError -from core.plugin import Plugin, handler, conversation -from core.template import TemplateService -from core.template.models import FileType -from core.user import UserService -from core.user.error import UserNotFoundError -from metadata.scripts.paimon_moe import update_paimon_moe_zh, GACHA_LOG_PAIMON_MOE_PATH +from core.basemodel import RegionEnum +from core.dependence.assets import AssetsService +from core.plugin import Plugin, conversation, handler +from core.services.cookies import CookiesService +from core.services.players import PlayersService +from core.services.template.models import FileType +from core.services.template.services import TemplateService +from metadata.scripts.paimon_moe import GACHA_LOG_PAIMON_MOE_PATH, update_paimon_moe_zh from modules.gacha_log.error import ( - GachaLogInvalidAuthkey, - PaimonMoeGachaLogFileError, - GachaLogFileError, - GachaLogNotFound, GachaLogAccountNotFound, - GachaLogMixedProvider, GachaLogAuthkeyTimeout, + GachaLogFileError, + GachaLogInvalidAuthkey, + GachaLogMixedProvider, + GachaLogNotFound, + PaimonMoeGachaLogFileError, ) from modules.gacha_log.helpers import from_url_get_authkey from modules.gacha_log.log import GachaLog -from utils.bot import get_args -from utils.decorators.admins import bot_admins_rights_check -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts +from plugins.tools.genshin import PlayerNotFoundError, GenshinHelper from utils.genshin import get_authkey_by_stoken -from utils.helpers import get_genshin_client from utils.log import logger -from utils.models.base import RegionEnum try: import ujson as jsonlib @@ -49,24 +40,26 @@ INPUT_URL, INPUT_FILE, CONFIRM_DELETE = range(10100, 10103) -class GachaLogPlugin(Plugin.Conversation, BasePlugin.Conversation): +class WishLogPlugin(Plugin.Conversation): """抽卡记录导入/导出/分析""" def __init__( self, - template_service: TemplateService = None, - user_service: UserService = None, - assets: AssetsService = None, - cookie_service: CookiesService = None, + template_service: TemplateService, + players_service: PlayersService, + assets: AssetsService, + cookie_service: CookiesService, + helper: GenshinHelper, ): self.template_service = template_service - self.user_service = user_service + self.players_service = players_service self.assets_service = assets self.cookie_service = cookie_service self.zh_dict = None self.gacha_log = GachaLog() + self.helper = helper - async def __async_init__(self): + async def initialize(self) -> None: await update_paimon_moe_zh(False) async with async_open(GACHA_LOG_PAIMON_MOE_PATH, "r", encoding="utf-8") as load_f: self.zh_dict = jsonlib.loads(await load_f.read()) @@ -82,7 +75,7 @@ async def _refresh_user_data( """ try: logger.debug("尝试获取已绑定的原神账号") - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) if authkey: new_num = await self.gacha_log.get_gacha_log_data(user.id, client, authkey) return "更新完成,本次没有新增数据" if new_num == 0 else f"更新完成,本次共新增{new_num}条抽卡记录" @@ -101,7 +94,7 @@ async def _refresh_user_data( return "更新数据失败,authkey 已经过期" except GachaLogMixedProvider: return "导入失败,你已经通过其他方式导入过抽卡记录了,本次无法导入" - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) return "派蒙没有找到您所绑定的账号信息,请先私聊派蒙绑定账号" @@ -116,8 +109,8 @@ async def import_from_file(self, user: User, message: Message, document: Documen else: await message.reply_text("文件格式错误,请发送符合 UIGF 标准的抽卡记录文件或者 paimon.moe、非小酋导出的 xlsx 格式的抽卡记录文件") return - if document.file_size > config.plugin.download_file_max_size * 1024 * 1024: - await message.reply_text(f"文件过大,请发送小于 {config.plugin.download_file_max_size} MB 的文件") + if document.file_size > 2 * 1024 * 1024: + await message.reply_text("文件过大,请发送小于 2 MB 的文件") return try: out = BytesIO() @@ -158,44 +151,29 @@ async def import_from_file(self, user: User, message: Message, document: Documen @conversation.entry_point @handler(CommandHandler, command="gacha_log_import", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^导入抽卡记录(.*)") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable async def command_start(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user - args = get_args(context) + args = self.get_args(context) logger.info("用户 %s[%s] 导入抽卡记录命令请求", user.full_name, user.id) authkey = from_url_get_authkey(args[0] if args else "") if not args: - if message.document: - await self.import_from_file(user, message) - return ConversationHandler.END - elif message.reply_to_message and message.reply_to_message.document: - await self.import_from_file(user, message, document=message.reply_to_message.document) - return ConversationHandler.END - try: - user_info = await self.user_service.get_user_by_id(user.id) - except UserNotFoundError: - user_info = None - if user_info and user_info.region == RegionEnum.HYPERION: - try: - cookies = await self.cookie_service.get_cookies(user_info.user_id, user_info.region) - except CookiesNotFoundError: - cookies = None - if cookies and cookies.cookies and "stoken" in cookies.cookies: + player_info = await self.players_service.get_player(user.id, region=RegionEnum.HYPERION) + if player_info is not None: + cookies = await self.cookie_service.get(user.id, account_id=player_info.account_id) + if cookies is not None and cookies.data and "stoken" in cookies.data: if stuid := next( - (value for key, value in cookies.cookies.items() if key in ["ltuid", "login_uid"]), None + (value for key, value in cookies.data.items() if key in ["ltuid", "login_uid"]), None ): - cookies.cookies["stuid"] = stuid + cookies.data["stuid"] = stuid client = genshin.Client( - cookies=cookies.cookies, + cookies=cookies.data, game=genshin.types.Game.GENSHIN, region=genshin.Region.CHINESE, lang="zh-cn", - uid=user_info.yuanshen_uid, + uid=player_info.player_id, ) - with contextlib.suppress(Exception): - authkey = await get_authkey_by_stoken(client) + authkey = await get_authkey_by_stoken(client) if not authkey: await message.reply_text( "开始导入祈愿历史记录:请通过 https://paimon.moe/wish/import 获取抽卡记录链接后发送给我" @@ -218,17 +196,12 @@ async def command_start(self, update: Update, context: CallbackContext) -> int: @conversation.state(state=INPUT_URL) @handler.message(filters=~filters.COMMAND, block=False) - @restricts() - @error_callable async def import_data_from_message(self, update: Update, _: CallbackContext) -> int: message = update.effective_message user = update.effective_user if message.document: await self.import_from_file(user, message) return ConversationHandler.END - elif not message.text: - await message.reply_text("请发送正确的抽卡记录链接") - return INPUT_URL authkey = from_url_get_authkey(message.text) reply = await message.reply_text("小派蒙正在从服务器获取数据,请稍后") await message.reply_chat_action(ChatAction.TYPING) @@ -239,27 +212,16 @@ async def import_data_from_message(self, update: Update, _: CallbackContext) -> @conversation.entry_point @handler(CommandHandler, command="gacha_log_delete", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^删除抽卡记录(.*)") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable async def command_start_delete(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 删除抽卡记录命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) context.chat_data["uid"] = client.uid - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] - if filters.ChatType.GROUPS.filter(message): - reply_message = await message.reply_text( - "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) - ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) - else: - await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) + await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号") return ConversationHandler.END _, status = await self.gacha_log.load_history_info(str(user.id), str(client.uid), only_status=True) if not status: @@ -270,8 +232,6 @@ async def command_start_delete(self, update: Update, context: CallbackContext) - @conversation.state(state=CONFIRM_DELETE) @handler.message(filters=filters.TEXT & ~filters.COMMAND, block=False) - @restricts() - @error_callable async def command_confirm_delete(self, update: Update, context: CallbackContext) -> int: message = update.effective_message user = update.effective_user @@ -282,11 +242,10 @@ async def command_confirm_delete(self, update: Update, context: CallbackContext) await message.reply_text("已取消") return ConversationHandler.END - @handler(CommandHandler, command="gacha_log_force_delete", block=False) - @bot_admins_rights_check + @handler(CommandHandler, command="gacha_log_force_delete", block=False, admin=True) async def command_gacha_log_force_delete(self, update: Update, context: CallbackContext): message = update.effective_message - args = get_args(context) + args = self.get_args(context) if not args: await message.reply_text("请指定用户ID") return @@ -294,7 +253,7 @@ async def command_gacha_log_force_delete(self, update: Update, context: Callback cid = int(args[0]) if cid < 0: raise ValueError("Invalid cid") - client = await get_genshin_client(cid, need_cookie=False) + client = await self.helper.get_genshin_client(cid, need_cookie=False) _, status = await self.gacha_log.load_history_info(str(cid), str(client.uid), only_status=True) if not status: await message.reply_text("该用户还没有导入抽卡记录") @@ -303,21 +262,19 @@ async def command_gacha_log_force_delete(self, update: Update, context: Callback await message.reply_text("抽卡记录已强制删除" if status else "抽卡记录删除失败") except GachaLogNotFound: await message.reply_text("该用户还没有导入抽卡记录") - except UserNotFoundError: + except PlayerNotFoundError: await message.reply_text("该用户暂未绑定账号") except (ValueError, IndexError): await message.reply_text("用户ID 不合法") @handler(CommandHandler, command="gacha_log_export", filters=filters.ChatType.PRIVATE, block=False) @handler(MessageHandler, filters=filters.Regex("^导出抽卡记录(.*)") & filters.ChatType.PRIVATE, block=False) - @restricts() - @error_callable async def command_start_export(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user logger.info("用户 %s[%s] 导出抽卡记录命令请求", user.full_name, user.id) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) await message.reply_chat_action(ChatAction.TYPING) path = await self.gacha_log.gacha_log_to_uigf(str(user.id), str(client.uid)) await message.reply_chat_action(ChatAction.UPLOAD_DOCUMENT) @@ -332,42 +289,31 @@ async def command_start_export(self, update: Update, context: CallbackContext) - await message.reply_text("导入失败,可能文件包含的祈愿记录所属 uid 与你当前绑定的 uid 不同") except GachaLogFileError: await message.reply_text("导入失败,数据格式错误") - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) - buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] - if filters.ChatType.GROUPS.filter(message): - reply_message = await message.reply_text( - "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) - ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) - else: - await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) + await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号") @handler(CommandHandler, command="gacha_log", block=False) @handler(MessageHandler, filters=filters.Regex("^抽卡记录?(武器|角色|常驻|)$"), block=False) - @restricts() - @error_callable async def command_start_analysis(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user pool_type = BannerType.CHARACTER1 - if args := get_args(context): + if args := self.get_args(context): if "武器" in args: pool_type = BannerType.WEAPON elif "常驻" in args: pool_type = BannerType.STANDARD logger.info("用户 %s[%s] 抽卡记录命令请求 || 参数 %s", user.full_name, user.id, pool_type.name) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.helper.get_genshin_client(user.id, need_cookie=False) await message.reply_chat_action(ChatAction.TYPING) data = await self.gacha_log.get_analysis(user.id, client, pool_type, self.assets_service) if isinstance(data, str): reply_message = await message.reply_text(data) if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 300) - self._add_delete_message_job(context, message.chat_id, message.message_id, 300) + self.add_delete_message_job(reply_message, delay=300) + self.add_delete_message_job(message, delay=300) else: await message.reply_chat_action(ChatAction.UPLOAD_PHOTO) png_data = await self.template_service.render( @@ -380,29 +326,26 @@ async def command_start_analysis(self, update: Update, context: CallbackContext) [InlineKeyboardButton("点我导入", url=create_deep_linked_url(context.bot.username, "gacha_log_import"))] ] await message.reply_text("派蒙没有找到你的抽卡记录,快来点击按钮私聊派蒙导入吧~", reply_markup=InlineKeyboardMarkup(buttons)) - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) - - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) @handler(CommandHandler, command="gacha_count", block=True) @handler(MessageHandler, filters=filters.Regex("^抽卡统计?(武器|角色|常驻|仅五星|)$"), block=True) - @restricts() - @error_callable async def command_start_count(self, update: Update, context: CallbackContext) -> None: message = update.effective_message user = update.effective_user pool_type = BannerType.CHARACTER1 all_five = False - if args := get_args(context): + if args := self.get_args(context): if "武器" in args: pool_type = BannerType.WEAPON elif "常驻" in args: @@ -411,7 +354,7 @@ async def command_start_count(self, update: Update, context: CallbackContext) -> all_five = True logger.info("用户 %s[%s] 抽卡统计命令请求 || 参数 %s || 仅五星 %s", user.full_name, user.id, pool_type.name, all_five) try: - client = await get_genshin_client(user.id, need_cookie=False) + client = await self.get_genshin_client(user.id, need_cookie=False) group = filters.ChatType.GROUPS.filter(message) await message.reply_chat_action(ChatAction.TYPING) if all_five: @@ -421,8 +364,8 @@ async def command_start_count(self, update: Update, context: CallbackContext) -> if isinstance(data, str): reply_message = await message.reply_text(data) if filters.ChatType.GROUPS.filter(message): - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 300) - self._add_delete_message_job(context, message.chat_id, message.message_id, 300) + self.add_delete_message_job(reply_message) + self.add_delete_message_job(message) else: document = False if data["hasMore"] and not group: @@ -446,15 +389,15 @@ async def command_start_count(self, update: Update, context: CallbackContext) -> [InlineKeyboardButton("点我导入", url=create_deep_linked_url(context.bot.username, "gacha_log_import"))] ] await message.reply_text("派蒙没有找到你的抽卡记录,快来私聊派蒙导入吧~", reply_markup=InlineKeyboardMarkup(buttons)) - except UserNotFoundError: + except PlayerNotFoundError: logger.info("未查询到用户 %s[%s] 所绑定的账号信息", user.full_name, user.id) buttons = [[InlineKeyboardButton("点我绑定账号", url=create_deep_linked_url(context.bot.username, "set_uid"))]] if filters.ChatType.GROUPS.filter(message): reply_message = await message.reply_text( "未查询到您所绑定的账号信息,请先私聊派蒙绑定账号", reply_markup=InlineKeyboardMarkup(buttons) ) - self._add_delete_message_job(context, reply_message.chat_id, reply_message.message_id, 30) + self.add_delete_message_job(reply_message, delay=30) - self._add_delete_message_job(context, message.chat_id, message.message_id, 30) + self.add_delete_message_job(message, delay=30) else: await message.reply_text("未查询到您所绑定的账号信息,请先绑定账号", reply_markup=InlineKeyboardMarkup(buttons)) diff --git a/plugins/system/auth.py b/plugins/group/captcha.py similarity index 84% rename from plugins/system/auth.py rename to plugins/group/captcha.py index ff75c559..b80c8e66 100644 --- a/plugins/system/auth.py +++ b/plugins/group/captcha.py @@ -1,23 +1,20 @@ import asyncio import random import time -from typing import Tuple, Union, Dict, List, Optional +from typing import Tuple, Union, Dict, Optional from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup, ChatPermissions, ChatMember, Message, User from telegram.constants import ParseMode -from telegram.error import BadRequest, RetryAfter -from telegram.ext import CallbackContext, CallbackQueryHandler, ChatMemberHandler +from telegram.error import BadRequest +from telegram.ext import CallbackContext, CallbackQueryHandler, ChatMemberHandler, filters from telegram.helpers import escape_markdown -from core.base.mtproto import MTProto -from core.base.redisdb import RedisDB -from core.bot import bot from core.config import config +from core.dependence.mtproto import MTProto +from core.dependence.redisdb import RedisDB from core.plugin import Plugin, handler -from core.quiz import QuizService +from core.services.quiz.services import QuizService from utils.chatmember import extract_status_change -from utils.decorators.error import error_callable -from utils.decorators.restricts import restricts from utils.log import logger try: @@ -47,7 +44,7 @@ ) -class GroupJoiningVerification(Plugin): +class GroupCaptcha(Plugin): """群验证模块""" def __init__(self, quiz_service: QuizService = None, mtp: MTProto = None, redis: RedisDB = None): @@ -55,12 +52,12 @@ def __init__(self, quiz_service: QuizService = None, mtp: MTProto = None, redis: self.time_out = 120 self.kick_time = 120 self.lock = asyncio.Lock() - self.chat_administrators_cache: Dict[Union[str, int], Tuple[float, List[ChatMember]]] = {} + self.chat_administrators_cache: Dict[Union[str, int], Tuple[float, Tuple[ChatMember]]] = {} self.is_refresh_quiz = False self.mtp = mtp.client self.redis = redis.client - async def __async_init__(self): + async def initialize(self): logger.info("群验证模块正在刷新问题列表") await self.refresh_quiz() logger.success("群验证模块刷新问题列表成功") @@ -71,7 +68,7 @@ async def refresh_quiz(self): await self.quiz_service.refresh_quiz() self.is_refresh_quiz = True - async def get_chat_administrators(self, context: CallbackContext, chat_id: Union[str, int]) -> List[ChatMember]: + async def get_chat_administrators(self, context: CallbackContext, chat_id: Union[str, int]) -> Tuple[ChatMember]: async with self.lock: cache_data = self.chat_administrators_cache.get(f"{chat_id}") if cache_data is not None: @@ -83,43 +80,40 @@ async def get_chat_administrators(self, context: CallbackContext, chat_id: Union return chat_administrators @staticmethod - def is_admin(chat_administrators: List[ChatMember], user_id: int) -> bool: + def is_admin(chat_administrators: Tuple[ChatMember], user_id: int) -> bool: return any(admin.user.id == user_id for admin in chat_administrators) async def kick_member_job(self, context: CallbackContext): job = context.job - logger.info(f"踢出用户 user_id[{job.user_id}] 在 chat_id[{job.chat_id}]") + logger.info("踢出用户 user_id[%s] 在 chat_id[%s]", job.user_id, job.chat_id) try: await context.bot.ban_chat_member( chat_id=job.chat_id, user_id=job.user_id, until_date=int(time.time()) + self.kick_time ) except BadRequest as exc: - logger.error(f"Auth模块在 chat_id[{job.chat_id}] user_id[{job.user_id}] 执行kick失败") - logger.exception(exc) + logger.error("GroupCaptcha插件在 chat_id[%s] user_id[%s] 执行kick失败", job.chat_id, job.user_id, exc_info=exc) @staticmethod async def clean_message_job(context: CallbackContext): job = context.job - logger.debug(f"删除消息 chat_id[{job.chat_id}] 的 message_id[{job.data}]") + logger.debug("删除消息 chat_id[%s] 的 message_id[%s]", job.chat_id, job.data) try: await context.bot.delete_message(chat_id=job.chat_id, message_id=job.data) except BadRequest as exc: - if "not found" in str(exc): - logger.warning(f"Auth模块删除消息 chat_id[{job.chat_id}] message_id[{job.data}]失败 消息不存在") - elif "Message can't be deleted" in str(exc): - logger.warning(f"Auth模块删除消息 chat_id[{job.chat_id}] message_id[{job.data}]失败 消息无法删除 可能是没有授权") + if "not found" in exc.message: + logger.warning("GroupCaptcha插件删除消息 chat_id[%s] message_id[%s]失败 消息不存在", job.chat_id, job.data) + elif "Message can't be deleted" in exc.message: + logger.warning("GroupCaptcha插件删除消息 chat_id[%s] message_id[%s]失败 消息无法删除 可能是没有授权", job.chat_id, job.data) else: - logger.error(f"Auth模块删除消息 chat_id[{job.chat_id}] message_id[{job.data}]失败") - logger.exception(exc) + logger.error("GroupCaptcha插件删除消息 chat_id[%s] message_id[%s]失败", job.chat_id, job.data, exc_info=exc) @staticmethod async def restore_member(context: CallbackContext, chat_id: int, user_id: int): - logger.debug(f"重置用户权限 user_id[{user_id}] 在 chat_id[{chat_id}]") + logger.debug("重置用户权限 user_id[%s] 在 chat_id[%s]", chat_id, user_id) try: await context.bot.restrict_chat_member(chat_id=chat_id, user_id=user_id, permissions=FullChatPermissions) except BadRequest as exc: - logger.error(f"Auth模块在 chat_id[{chat_id}] user_id[{user_id}] 执行restore失败") - logger.exception(exc) + logger.error("GroupCaptcha插件在 chat_id[%s] user_id[%s] 执行restore失败", chat_id, user_id, exc_info=exc) async def get_new_chat_members_message(self, user: User, context: CallbackContext) -> Optional[Message]: qname = f"plugin:auth:new_chat_members_message:{user.id}" @@ -134,31 +128,29 @@ async def set_new_chat_members_message(self, user: User, message: Message): await self.redis.set(qname, message.to_json(), ex=60) @handler(CallbackQueryHandler, pattern=r"^auth_admin\|", block=False) - @error_callable - @restricts(without_overlapping=True) async def admin(self, update: Update, context: CallbackContext) -> None: async def admin_callback(callback_query_data: str) -> Tuple[str, int]: _data = callback_query_data.split("|") _result = _data[1] _user_id = int(_data[2]) - logger.debug(f"admin_callback函数返回 result[{_result}] user_id[{_user_id}]") + logger.debug("admin_callback函数返回 result[%s] user_id[%s]", _result, _user_id) return _result, _user_id callback_query = update.callback_query user = callback_query.from_user message = callback_query.message chat = message.chat - logger.info(f"用户 {user.full_name}[{user.id}] 在群 {chat.title}[{chat.id}] 点击Auth管理员命令") + logger.info("用户 %s[%s] 在群 %s[%s] 点击Auth管理员命令", user.full_name, user.id, chat.title, chat.id) chat_administrators = await self.get_chat_administrators(context, chat_id=chat.id) if not self.is_admin(chat_administrators, user.id): - logger.debug(f"用户 {user.full_name}[{user.id}] 在群 {chat.title}[{chat.id}] 非群管理") + logger.debug("用户 %s[%s] 在群 %s[%s] 非群管理", user.full_name, user.id, chat.title, chat.id) await callback_query.answer(text="你不是管理!\n" + config.notice.user_mismatch, show_alert=True) return result, user_id = await admin_callback(callback_query.data) try: member_info = await context.bot.get_chat_member(chat.id, user_id) except BadRequest as error: - logger.warning(f"获取用户 {user_id} 在群 {chat.title}[{chat.id}] 信息失败 \n", error) + logger.warning("获取用户 %s 在群 %s[%s] 信息失败 \n %s", user_id, chat.title, chat.id, error.message) user_info = f"{user_id}" else: user_info = member_info.user.mention_markdown_v2() @@ -169,12 +161,12 @@ async def admin_callback(callback_query_data: str) -> Tuple[str, int]: if schedule := context.job_queue.scheduler.get_job(f"{chat.id}|{user_id}|auth_kick"): schedule.remove() await message.edit_text(f"{user_info} 被 {user.mention_markdown_v2()} 放行", parse_mode=ParseMode.MARKDOWN_V2) - logger.info(f"用户 user_id[{user_id}] 在群 {chat.title}[{chat.id}] 被管理放行") + logger.info("用户 %s[%s] 在群 %s[%s] 被管理放行", user.full_name, user.id, chat.title, chat.id) elif result == "kick": await callback_query.answer(text="驱离", show_alert=False) await context.bot.ban_chat_member(chat.id, user_id) await message.edit_text(f"{user_info} 被 {user.mention_markdown_v2()} 驱离", parse_mode=ParseMode.MARKDOWN_V2) - logger.info(f"用户 user_id[{user_id}] 在群 {chat.title}[{chat.id}] 被管理踢出") + logger.info("用户 %s[%s] 在群 %s[%s] 被管理踢出", user.full_name, user.id, chat.title, chat.id) elif result == "unban": await callback_query.answer(text="解除驱离", show_alert=False) await self.restore_member(context, chat.id, user_id) @@ -183,16 +175,14 @@ async def admin_callback(callback_query_data: str) -> Tuple[str, int]: await message.edit_text( f"{user_info} 被 {user.mention_markdown_v2()} 解除驱离", parse_mode=ParseMode.MARKDOWN_V2 ) - logger.info(f"用户 user_id[{user_id}] 在群 {chat.title}[{chat.id}] 被管理解除封禁") + logger.info("用户 user_id[%s] 在群 %s[%s] 被管理解除封禁", user_id, chat.title, chat.id) else: - logger.warning(f"auth 模块 admin 函数 发现未知命令 result[{result}]") + logger.warning("auth 模块 admin 函数 发现未知命令 result[%s]", result) await context.bot.send_message(chat.id, "派蒙这边收到了错误的消息!请检查详细日记!") if schedule := context.job_queue.scheduler.get_job(f"{chat.id}|{user_id}|auth_kick"): schedule.remove() @handler(CallbackQueryHandler, pattern=r"^auth_challenge\|", block=False) - @error_callable - @restricts(without_overlapping=True) async def query(self, update: Update, context: CallbackContext) -> None: async def query_callback(callback_query_data: str) -> Tuple[int, bool, str, str]: _data = callback_query_data.split("|") @@ -205,8 +195,11 @@ async def query_callback(callback_query_data: str) -> Tuple[int, bool, str, str] _answer_encode = _answer.text _question_encode = _question.text logger.debug( - f"query_callback函数返回 user_id[{_user_id}] result[{_result}] \n" - f"question_encode[{_question_encode}] answer_encode[{_answer_encode}]" + "query_callback函数返回 user_id[%s] result[%s] \nquestion_encode[%s] answer_encode[%s]", + _user_id, + _result, + _question_encode, + _answer_encode, ) return _user_id, _result, _question_encode, _answer_encode @@ -215,11 +208,13 @@ async def query_callback(callback_query_data: str) -> Tuple[int, bool, str, str] message = callback_query.message chat = message.chat user_id, result, question, answer = await query_callback(callback_query.data) - logger.info(f"用户 {user.full_name}[{user.id}] 在群 {chat.title}[{chat.id}] 点击Auth认证命令 ") + logger.info("用户 %s[%s] 在群 %s[%s] 点击Auth认证命令", user.full_name, user.id, chat.title, chat.id) if user.id != user_id: await callback_query.answer(text="这不是你的验证!\n" + config.notice.user_mismatch, show_alert=True) return - logger.info(f"用户 {user.full_name}[{user.id}] 在群 {chat.title}[{chat.id}] 认证结果为 {'通过' if result else '失败'}") + logger.info( + "用户 %s[%s] 在群 %s[%s] 认证结果为 %s", user.full_name, user.id, chat.title, chat.id, "通过" if result else "失败" + ) if result: buttons = [[InlineKeyboardButton("驱离", callback_data=f"auth_admin|kick|{user.id}")]] await callback_query.answer(text="验证成功", show_alert=False) @@ -231,7 +226,7 @@ async def query_callback(callback_query_data: str) -> Tuple[int, bool, str, str] f"问题:{escape_markdown(question, version=2)} \n" f"回答:{escape_markdown(answer, version=2)}" ) - logger.info(f"用户 user_id[{user_id}] 在群 {chat.title}[{chat.id}] 验证成功") + logger.info("用户 user_id[%s] 在群 %s[%s] 验证成功", user_id, chat.title, chat.id) else: buttons = [ [ @@ -249,24 +244,23 @@ async def query_callback(callback_query_data: str) -> Tuple[int, bool, str, str] f"问题:{escape_markdown(question, version=2)} \n" f"回答:{escape_markdown(answer, version=2)}" ) - logger.info(f"用户 user_id[{user_id}] 在群 {chat.title}[{chat.id}] 验证失败") + logger.info("用户 user_id[%s] 在群 %s[%s] 验证失败", user_id, chat.title, chat.id) try: await message.edit_text(text, reply_markup=InlineKeyboardMarkup(buttons), parse_mode=ParseMode.MARKDOWN_V2) except BadRequest as exc: - if "are exactly the same as " in str(exc): + if "are exactly the same as " in exc.message: logger.warning("编辑消息发生异常,可能为用户点按多次键盘导致") else: raise exc if schedule := context.job_queue.scheduler.get_job(f"{chat.id}|{user.id}|auth_kick"): schedule.remove() - @handler.message.new_chat_members(priority=1) - @error_callable + @handler.message(filters=filters.StatusUpdate.NEW_CHAT_MEMBERS, block=False) async def new_mem(self, update: Update, context: CallbackContext) -> None: message = update.effective_message chat = message.chat - if len(bot.config.verify_groups) >= 1: - for verify_group in bot.config.verify_groups: + if len(config.verify_groups) >= 1: + for verify_group in config.verify_groups: if verify_group == chat.id: break else: @@ -280,11 +274,10 @@ async def new_mem(self, update: Update, context: CallbackContext) -> None: await self.set_new_chat_members_message(user, message) @handler.chat_member(chat_member_types=ChatMemberHandler.CHAT_MEMBER, block=False) - @error_callable async def track_users(self, update: Update, context: CallbackContext) -> None: chat = update.effective_chat - if len(bot.config.verify_groups) >= 1: - for verify_group in bot.config.verify_groups: + if len(config.verify_groups) >= 1: + for verify_group in config.verify_groups: if verify_group == chat.id: break else: @@ -301,7 +294,7 @@ async def track_users(self, update: Update, context: CallbackContext) -> None: if was_member and not is_member: logger.info("用户 %s[%s] 退出群聊 %s[%s]", user.full_name, user.id, chat.title, chat.id) return - elif not was_member and is_member: + if not was_member and is_member: logger.info("用户 %s[%s] 尝试加入群 %s[%s]", user.full_name, user.id, chat.title, chat.id) if user.is_bot: return @@ -323,8 +316,7 @@ async def track_users(self, update: Update, context: CallbackContext) -> None: parse_mode=ParseMode.HTML, ) return - else: - raise exc + raise exc new_chat_members_message = await self.get_new_chat_members_message(user, context) question_id = random.choice(question_id_list) # nosec question = await self.quiz_service.get_question(question_id) diff --git a/plugins/jobs/public_cookies.py b/plugins/jobs/public_cookies.py index eacd0a96..80e4bd0c 100644 --- a/plugins/jobs/public_cookies.py +++ b/plugins/jobs/public_cookies.py @@ -1,25 +1,18 @@ -import asyncio import datetime from telegram.ext import CallbackContext -from core.cookies.services import PublicCookiesService from core.plugin import Plugin, job +from core.services.cookies.services import PublicCookiesService from utils.log import logger +__all__ = ("PublicCookiesPlugin",) -class PublicCookies(Plugin): + +class PublicCookiesPlugin(Plugin): def __init__(self, public_cookies_service: PublicCookiesService = None): self.public_cookies_service = public_cookies_service - async def __async_init__(self): - async def _refresh(): - logger.info("正在刷新公共Cookies池") - await self.public_cookies_service.refresh() - logger.success("刷新公共Cookies池成功") - - asyncio.create_task(_refresh()) - @job.run_repeating(interval=datetime.timedelta(hours=2), name="PublicCookiesRefresh") async def refresh(self, _: CallbackContext): logger.info("正在刷新公共Cookies池") diff --git a/plugins/jobs/sign.py b/plugins/jobs/sign.py index bcb06eb3..9cfd92f3 100644 --- a/plugins/jobs/sign.py +++ b/plugins/jobs/sign.py @@ -1,105 +1,25 @@ import datetime -from aiohttp import ClientConnectorError -from genshin import GenshinException, AlreadyClaimed, InvalidCookies -from httpx import TimeoutException -from telegram.constants import ParseMode -from telegram.error import BadRequest, Forbidden from telegram.ext import CallbackContext -from core.base.redisdb import RedisDB -from core.cookies import CookiesService from core.plugin import Plugin, job -from core.sign.models import SignStatusEnum -from core.sign.services import SignServices -from core.user import UserService -from plugins.genshin.sign import SignSystem, NeedChallenge -from plugins.system.errorhandler import notice_chat_id -from plugins.system.sign_status import SignStatus -from utils.helpers import get_genshin_client +from plugins.genshin.sign import SignSystem +from plugins.tools.sign import SignJobType from utils.log import logger class SignJob(Plugin): - def __init__( - self, - sign_service: SignServices = None, - user_service: UserService = None, - cookies_service: CookiesService = None, - redis: RedisDB = None, - ): - self.sign_service = sign_service - self.cookies_service = cookies_service - self.user_service = user_service - self.sign_system = SignSystem(redis) + def __init__(self, sign_system: SignSystem): + self.sign_system = sign_system @job.run_daily(time=datetime.time(hour=0, minute=1, second=0), name="SignJob") async def sign(self, context: CallbackContext): - logger.info("正在执行自动签到" if context.job.name == "SignJob" else "正在执行自动重签") - sign_list = await self.sign_service.get_all() - for sign_db in sign_list: - user_id = sign_db.user_id - if sign_db.status in [ - SignStatusEnum.INVALID_COOKIES, - SignStatusEnum.FORBIDDEN, - ]: - continue - if context.job.name == "SignJob": - if sign_db.status not in [SignStatusEnum.STATUS_SUCCESS, SignStatusEnum.ALREADY_CLAIMED]: - continue - elif context.job.name == "SignAgainJob" and sign_db.status in [ - SignStatusEnum.STATUS_SUCCESS, - SignStatusEnum.ALREADY_CLAIMED, - ]: - continue - try: - client = await get_genshin_client(user_id) - text = await self.sign_system.start_sign( - client, is_sleep=True, is_raise=True, title="自动签到" if context.job.name == "SignJob" else "自动重新签到" - ) - sign_db.status = SignStatusEnum.STATUS_SUCCESS - except InvalidCookies: - text = "自动签到执行失败,Cookie无效" - sign_db.status = SignStatusEnum.INVALID_COOKIES - except AlreadyClaimed: - text = "今天旅行者已经签到过了~" - sign_db.status = SignStatusEnum.ALREADY_CLAIMED - except GenshinException as exc: - text = f"自动签到执行失败,API返回信息为 {str(exc)}" - sign_db.status = SignStatusEnum.GENSHIN_EXCEPTION - except TimeoutException: - text = "签到失败了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ " - sign_db.status = SignStatusEnum.TIMEOUT_ERROR - except ClientConnectorError as exc: - logger.warning("aiohttp 请求错误 %s", str(exc)) - text = "签到失败了呜呜呜 ~ 链接服务器发生错误 服务器熟啦 ~ " - sign_db.status = SignStatusEnum.TIMEOUT_ERROR - except NeedChallenge: - text = "签到失败,触发验证码风控" - sign_db.status = SignStatusEnum.NEED_CHALLENGE - except Exception as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - text = "签到失败了呜呜呜 ~ 执行自动签到时发生错误" - if sign_db.chat_id < 0: - text = f'NOTICE {sign_db.user_id}\n\n{text}' - try: - await context.bot.send_message(sign_db.chat_id, text, parse_mode=ParseMode.HTML) - except BadRequest as exc: - logger.error("执行自动签到时发生错误 message[%s] user_id[%s]", exc.message, user_id) - sign_db.status = SignStatusEnum.BAD_REQUEST - except Forbidden as exc: - logger.error("执行自动签到时发生错误 message[%s] user_id[%s]", exc.message, user_id) - sign_db.status = SignStatusEnum.FORBIDDEN - except Exception as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - continue - sign_db.time_updated = datetime.datetime.now() - await self.sign_service.update(sign_db) - logger.info("执行自动签到完成" if context.job.name == "SignJob" else "执行自动重签完成") - if context.job.name == "SignJob": - context.job_queue.run_once(self.sign, when=60, name="SignAgainJob") - elif context.job.name == "SignAgainJob": - text = await SignStatus.get_sign_status(self.sign_service) - await context.bot.send_message(notice_chat_id, text, parse_mode=ParseMode.HTML) + logger.info("正在执行自动签到") + await self.sign_system.do_sign_job(context, job_type=SignJobType.START) + logger.success("执行自动签到完成") + await self.re_sign(context) + + async def re_sign(self, context: CallbackContext): + logger.info("正在执行自动重签") + await self.sign_system.do_sign_job(context, job_type=SignJobType.REDO) + logger.success("执行自动重签完成") diff --git a/plugins/system/admin.py b/plugins/system/admin.py deleted file mode 100644 index 984afdd5..00000000 --- a/plugins/system/admin.py +++ /dev/null @@ -1,72 +0,0 @@ -import contextlib - -from telegram import Update -from telegram.error import BadRequest, Forbidden -from telegram.ext import CallbackContext, CommandHandler - -from core.admin import BotAdminService -from core.plugin import handler, Plugin -from utils.decorators.admins import bot_admins_rights_check -from utils.log import logger - - -class AdminPlugin(Plugin): - """有关BOT ADMIN处理""" - - def __init__(self, bot_admin_service: BotAdminService = None): - self.bot_admin_service = bot_admin_service - - @handler(CommandHandler, command="add_admin", block=False) - @bot_admins_rights_check - async def add_admin(self, update: Update, _: CallbackContext): - message = update.effective_message - reply_to_message = message.reply_to_message - if reply_to_message is None: - await message.reply_text("请回复对应消息") - else: - admin_list = await self.bot_admin_service.get_admin_list() - if reply_to_message.from_user.id in admin_list: - await message.reply_text("该用户已经存在管理员列表") - else: - await self.bot_admin_service.add_admin(reply_to_message.from_user.id) - await message.reply_text("添加成功") - - @handler(CommandHandler, command="del_admin", block=False) - @bot_admins_rights_check - async def del_admin(self, update: Update, _: CallbackContext): - message = update.effective_message - reply_to_message = message.reply_to_message - admin_list = await self.bot_admin_service.get_admin_list() - if reply_to_message is None: - await message.reply_text("请回复对应消息") - else: - if reply_to_message.from_user.id in admin_list: - await self.bot_admin_service.delete_admin(reply_to_message.from_user.id) - await message.reply_text("删除成功") - else: - await message.reply_text("该用户不存在管理员列表") - - @handler(CommandHandler, command="leave_chat", block=False) - @bot_admins_rights_check - async def leave_chat(self, update: Update, context: CallbackContext): - message = update.effective_message - try: - args = message.text.split() - if len(args) >= 2: - chat_id = int(args[1]) - else: - await message.reply_text("输入错误") - return - except ValueError as error: - logger.error("获取 chat_id 发生错误! 错误信息为 \n", exc_info=error) - await message.reply_text("输入错误") - return - try: - with contextlib.suppress(BadRequest, Forbidden): - chat = await context.bot.get_chat(chat_id) - await message.reply_text(f"正在尝试退出群 {chat.title}[{chat.id}]") - await context.bot.leave_chat(chat_id) - except (BadRequest, Forbidden) as exc: - await message.reply_text(f"退出 chat_id[{chat_id}] 发生错误! 错误信息为 {str(exc)}") - return - await message.reply_text(f"退出 chat_id[{chat_id}] 成功!") diff --git a/plugins/system/chat_member.py b/plugins/system/chat_member.py index 702f7bd8..dc9fd4c7 100644 --- a/plugins/system/chat_member.py +++ b/plugins/system/chat_member.py @@ -1,34 +1,28 @@ -import contextlib - -from telegram import Update, Chat, User -from telegram.error import BadRequest +from telegram import Chat, Update, User +from telegram.error import NetworkError from telegram.ext import CallbackContext, ChatMemberHandler -from core.admin.services import BotAdminService -from core.config import config, JoinGroups -from core.cookies.error import CookiesNotFoundError -from core.cookies.services import CookiesService +from core.config import JoinGroups, config from core.plugin import Plugin, handler -from core.user.error import UserNotFoundError -from core.user.services import UserService +from core.services.cookies import CookiesService +from core.services.players import PlayersService +from core.services.users.services import UserAdminService from utils.chatmember import extract_status_change -from utils.decorators.error import error_callable from utils.log import logger class ChatMember(Plugin): def __init__( self, - bot_admin_service: BotAdminService = None, - user_service: UserService = None, + user_admin_service: UserAdminService = None, + players_service: PlayersService = None, cookies_service: CookiesService = None, ): self.cookies_service = cookies_service - self.user_service = user_service - self.bot_admin_service = bot_admin_service + self.players_service = players_service + self.user_admin_service = user_admin_service @handler.chat_member(chat_member_types=ChatMemberHandler.MY_CHAT_MEMBER, block=False) - @error_callable async def track_chats(self, update: Update, context: CallbackContext) -> None: result = extract_status_change(update.my_chat_member) if result is None: @@ -57,8 +51,7 @@ async def greet(self, user: User, chat: Chat, context: CallbackContext) -> None: quit_status = True if config.join_groups == JoinGroups.NO_ALLOW: try: - admin_list = await self.bot_admin_service.get_admin_list() - if user.id in admin_list: + if await self.user_admin_service.is_admin(user.id): quit_status = False else: logger.warning("不是管理员邀请!退出群聊") @@ -66,30 +59,27 @@ async def greet(self, user: User, chat: Chat, context: CallbackContext) -> None: logger.error("获取信息出现错误", exc_info=exc) elif config.join_groups == JoinGroups.ALLOW_AUTH_USER: try: - user_info = await self.user_service.get_user_by_id(user.id) - await self.cookies_service.get_cookies(user.id, user_info.region) - except (UserNotFoundError, CookiesNotFoundError): - logger.warning("用户 %s[%s] 邀请请求被拒绝", user.full_name, user.id) - except Exception as exc: + if await self.cookies_service.get(user.id) is not None: + quit_status = False + except Exception as exc: # pylint: disable=W0703 logger.error("获取信息出现错误", exc_info=exc) - else: - quit_status = False elif config.join_groups == JoinGroups.ALLOW_USER: try: - await self.user_service.get_user_by_id(user.id) - except UserNotFoundError: - logger.warning("用户 %s[%s] 邀请请求被拒绝", user.full_name, user.id) - except Exception as exc: + if await self.players_service.get(user.id) is not None: + quit_status = False + except Exception as exc: # pylint: disable=W0703 logger.error("获取信息出现错误", exc_info=exc) - else: - quit_status = False elif config.join_groups == JoinGroups.ALLOW_ALL: quit_status = False else: quit_status = True if quit_status: - with contextlib.suppress(BadRequest): + try: await context.bot.send_message(chat.id, "派蒙不想进去!不是旅行者的邀请!") + except NetworkError as exc: + logger.info("发送消息失败 %s", exc.message) + except Exception as exc: + logger.info("发送消息失败", exc_info=exc) await context.bot.leave_chat(chat.id) else: await context.bot.send_message(chat.id, "感谢邀请小派蒙到本群!请使用 /help 查看咱已经学会的功能。") diff --git a/plugins/system/errorhandler.py b/plugins/system/errorhandler.py index d4a50de9..f11053dc 100644 --- a/plugins/system/errorhandler.py +++ b/plugins/system/errorhandler.py @@ -1,18 +1,29 @@ import os import time import traceback +from typing import Optional import aiofiles -from telegram import ReplyKeyboardRemove, Update +from aiohttp import ClientError, ClientConnectorError +from genshin import DataNotPublic, GenshinException, InvalidCookies, TooManyRequests +from httpx import Timeout as HttpxTimeout, HTTPError +from telegram import ReplyKeyboardRemove, Update, InlineKeyboardMarkup, InlineKeyboardButton from telegram.constants import ParseMode -from telegram.error import BadRequest, Forbidden, NetworkError, TimedOut -from telegram.ext import CallbackContext +from telegram.error import BadRequest, Forbidden, TelegramError, TimedOut, NetworkError +from telegram.ext import CallbackContext, ApplicationHandlerStop +from telegram.helpers import create_deep_linked_url -from core.bot import bot from core.config import config from core.plugin import Plugin, error_handler -from modules.errorpush import PbClient, SentryClient, PbClientException, SentryClientException +from modules.apihelper.error import APIHelperException, APIHelperTimedOut, ResponseException, ReturnCodeError +from modules.errorpush import ( + PbClient, + PbClientException, + SentryClient, + SentryClientException, +) from utils.log import logger +from utils.patch.aiohttp import AioHttpTimeoutException try: import ujson as jsonlib @@ -20,34 +31,189 @@ except ImportError: import json as jsonlib -notice_chat_id = bot.config.error.notification_chat_id -current_dir = os.getcwd() -logs_dir = os.path.join(current_dir, "logs") -if not os.path.exists(logs_dir): - os.mkdir(logs_dir) -report_dir = os.path.join(current_dir, "report") -if not os.path.exists(report_dir): - os.mkdir(report_dir) -pb_client = PbClient(config.error.pb_url, config.error.pb_sunset, config.error.pb_max_lines) -sentry = SentryClient(config.error.sentry_dsn) - class ErrorHandler(Plugin): - @error_handler(block=False) # pylint: disable=E1123, E1120 - async def error_handler(self, update: object, context: CallbackContext) -> None: - """记录错误并发送消息通知开发人员。 logger the error and send a telegram message to notify the developer.""" + ERROR_MSG_PREFIX = "出错了呜呜呜 ~ " + SEND_MSG_ERROR_NOTICE = "发送 update_id[%s] 错误信息失败 错误信息为 [%s]" + + def __init__(self): + self.notice_chat_id = config.error.notification_chat_id + current_dir = os.getcwd() + logs_dir = os.path.join(current_dir, "logs") + if not os.path.exists(logs_dir): + os.mkdir(logs_dir) + self.report_dir = os.path.join(current_dir, "report") + if not os.path.exists(self.report_dir): + os.mkdir(self.report_dir) + self.pb_client = PbClient(config.error.pb_url, config.error.pb_sunset, config.error.pb_max_lines) + self.sentry = SentryClient(config.error.sentry_dsn) + + async def notice_user(self, update: object, context: CallbackContext, content: str): + if not isinstance(update, Update): + logger.warning("错误的消息类型 %s", repr(update)) + return None + if update.inline_query is not None: # 忽略 inline_query + return None + + if "重新绑定" in content: + buttons = InlineKeyboardMarkup( + [[InlineKeyboardButton("点我重新绑定", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] + ) + elif "通过验证" in content: + buttons = InlineKeyboardMarkup( + [ + [ + InlineKeyboardButton( + "点我通过验证", url=create_deep_linked_url(context.bot.username, "verify_verification") + ) + ] + ] + ) + else: + buttons = ReplyKeyboardRemove() - if isinstance(context.error, NetworkError): - logger.error("Bot请求异常 %s", context.error.message) + user = update.effective_user + message = update.effective_message + chat = update.effective_chat + + if chat.id == user.id: + logger.info("尝试通知用户 %s[%s] 错误信息[%s]", user.full_name, user.id, content) + else: + logger.info("尝试通知用户 %s[%s] 在 %s[%s] 的错误信息[%s]", user.full_name, user.id, chat.title, chat.id, content) + try: + if update.callback_query: + await update.callback_query.answer(content, show_alert=True) + return None + return await message.reply_text(content, reply_markup=buttons, allow_sending_without_reply=True) + except TelegramError as exc: + logger.error(self.SEND_MSG_ERROR_NOTICE, update.update_id, exc.message) + except Exception as exc: + logger.error(self.SEND_MSG_ERROR_NOTICE, update.update_id, repr(exc), exc_info=exc) + + def create_notice_task(self, update: object, context: CallbackContext, content: str): + context.application.create_task(self.notice_user(update, context, content), update) + + @error_handler() + async def process_genshin_exception(self, update: object, context: CallbackContext): + if not isinstance(context.error, GenshinException) or not isinstance(update, Update): return + exc = context.error + notice: Optional[str] = None + if isinstance(exc, TooManyRequests): + notice = self.ERROR_MSG_PREFIX + "Cookie 无效,请尝试重新绑定" + elif isinstance(exc, InvalidCookies): + if exc.retcode in (10001, -100): + notice = self.ERROR_MSG_PREFIX + "Cookie 无效,请尝试重新绑定" + elif exc.retcode == 10103: + notice = self.ERROR_MSG_PREFIX + "Cookie 有效,但没有绑定到游戏帐户,请尝试登录通行证,在账号管理里面选择账号游戏信息,将原神设置为默认角色。" + else: + logger.error("未知Cookie错误", exc_info=exc) + notice = self.ERROR_MSG_PREFIX + f"Cookie 无效 错误信息为 {exc.original} 请尝试重新绑定" + elif isinstance(exc, DataNotPublic): + notice = self.ERROR_MSG_PREFIX + "查询的用户数据未公开" + else: + if exc.retcode == -130: + notice = self.ERROR_MSG_PREFIX + "未设置默认角色,请尝试重新绑定" + elif exc.retcode == 1034: + notice = self.ERROR_MSG_PREFIX + "服务器检测到该账号可能存在异常,请求被拒绝,请尝试通过验证" + elif exc.retcode == -500001: + notice = self.ERROR_MSG_PREFIX + "网络出小差了,请稍后重试~" + elif exc.retcode == -1: + notice = self.ERROR_MSG_PREFIX + "系统发生错误,请稍后重试~" + elif exc.retcode == -10001: # 参数异常 不应该抛出异常 进入下一步处理 + pass + else: + logger.error("GenshinException", exc_info=exc) + notice = ( + self.ERROR_MSG_PREFIX + f"获取账号信息发生错误 错误信息为 {exc.original if exc.original else exc.retcode} ~ 请稍后再试" + ) + if notice: + self.create_notice_task(update, context, notice) + raise ApplicationHandlerStop + + @error_handler() + async def process_telegram_exception(self, update: object, context: CallbackContext): + if not isinstance(context.error, TelegramError) or not isinstance(update, Update): + return + notice: Optional[str] = None if isinstance(context.error, TimedOut): - logger.error("Bot请求超时 %s", context.error.message) + notice = self.ERROR_MSG_PREFIX + " 连接连接服务器异常" + elif isinstance(context.error, BadRequest): + if "Replied message not found" in context.error.message: + notice = "气死我了!怎么有人喜欢发一个命令就秒删了!" + elif "Message is not modified" in context.error.message: + logger.warning("编辑消息异常") + raise ApplicationHandlerStop + elif "Not enough rights" in context.error.message: + notice = self.ERROR_MSG_PREFIX + "权限不足,请检查对应权限是否开启" + else: + logger.error("python-telegram-bot 请求错误", exc_info=context.error) + notice = self.ERROR_MSG_PREFIX + "telegram-bot-api请求错误 ~ 请稍后再试" + elif isinstance(context.error, Forbidden): + logger.error("python-telegram-bot 返回 Forbidden") + notice = self.ERROR_MSG_PREFIX + "telegram-bot-api请求错误 ~ 请稍后再试" + if notice: + self.create_notice_task(update, context, notice) + raise ApplicationHandlerStop + + @error_handler() + async def process_telegram_update_exception(self, update: object, context: CallbackContext): + if update is None and isinstance(context.error, NetworkError): + logger.error("python-telegram-bot NetworkError : %s", context.error.message) + raise ApplicationHandlerStop + + @error_handler() + async def process_apihelper_exception(self, update: object, context: CallbackContext): + if not isinstance(context.error, APIHelperException) or not isinstance(update, Update): + return + exc = context.error + notice: Optional[str] = None + if isinstance(exc, APIHelperTimedOut): + notice = self.ERROR_MSG_PREFIX + " 连接连接服务器异常" + elif isinstance(exc, ReturnCodeError): + notice = self.ERROR_MSG_PREFIX + f"API请求错误 错误信息为 {exc.message if exc.message else exc.code} ~ 请稍后再试" + elif isinstance(exc, ResponseException): + notice = self.ERROR_MSG_PREFIX + f"API请求错误 错误信息为 {exc.message if exc.message else exc.code} ~ 请稍后再试" + if notice: + self.create_notice_task(update, context, notice) + raise ApplicationHandlerStop + + @error_handler() + async def process_httpx_exception(self, update: object, context: CallbackContext): + if not isinstance(context.error, HTTPError) or not isinstance(update, Update): + return + exc = context.error + notice: Optional[str] = None + if isinstance(exc, HttpxTimeout): + notice = self.ERROR_MSG_PREFIX + " 连接连接服务器异常" + if notice: + self.create_notice_task(update, context, notice) + raise ApplicationHandlerStop + + @error_handler() + async def process_aiohttp_exception(self, update: object, context: CallbackContext): + if not isinstance(context.error, ClientError) or not isinstance(update, Update): return + exc = context.error + notice: Optional[str] = None + if isinstance(exc, AioHttpTimeoutException): + notice = self.ERROR_MSG_PREFIX + " 连接连接服务器异常" + elif isinstance(exc, ClientConnectorError): + notice = self.ERROR_MSG_PREFIX + " 连接连接服务器异常" + if notice: + self.create_notice_task(update, context, notice) + raise ApplicationHandlerStop + + @error_handler(block=False) + async def process_z_error(self, update: object, context: CallbackContext) -> None: + # 必须 `process_` 加上 `z` 保证该函数最后一个注册 + """记录错误并发送消息通知开发人员。 + logger the error and send a telegram message to notify the developer.""" logger.error("处理函数时发生异常") logger.exception(context.error, exc_info=(type(context.error), context.error, context.error.__traceback__)) - if not notice_chat_id: + if not self.notice_chat_id: return tb_list = traceback.format_exception(None, context.error, context.error.__traceback__) @@ -65,7 +231,7 @@ async def error_handler(self, update: object, context: CallbackContext) -> None: f"{tb_string}" ) file_name = f"error_{update.update_id if isinstance(update, Update) else int(time.time())}.txt" - log_file = os.path.join(report_dir, file_name) + log_file = os.path.join(self.report_dir, file_name) try: async with aiofiles.open(log_file, mode="w+", encoding="utf-8") as f: await f.write(error_text) @@ -77,11 +243,11 @@ async def error_handler(self, update: object, context: CallbackContext) -> None: logger.error("其他机器人在运行,请停止!") return await context.bot.send_document( - chat_id=notice_chat_id, + chat_id=self.notice_chat_id, document=open(log_file, "rb"), caption=f'Error: "{context.error.__class__.__name__}"', ) - except (BadRequest, Forbidden) as exc: + except NetworkError as exc: logger.error("发送日记失败") logger.exception(exc) except FileNotFoundError: @@ -92,25 +258,27 @@ async def error_handler(self, update: object, context: CallbackContext) -> None: if effective_message is not None: chat = effective_message.chat logger.info( - f"尝试通知用户 {effective_user.full_name}[{effective_user.id}] " - f"在 {chat.full_name}[{chat.id}]" - f"的 update_id[{update.update_id}] 错误信息" + "尝试通知用户 %s[%s] 在 %s[%s] 的 update_id[%s] 错误信息", + effective_user.full_name, + effective_user.id, + chat.full_name, + chat.id, + update.update_id, ) text = "出错了呜呜呜 ~ 派蒙这边发生了点问题无法处理!" await context.bot.send_message( effective_message.chat_id, text, reply_markup=ReplyKeyboardRemove(), parse_mode=ParseMode.HTML ) - except (BadRequest, Forbidden) as exc: - logger.error(f"发送 update_id[{update.update_id}] 错误信息失败 错误信息为") - logger.exception(exc) - if pb_client.enabled: + except NetworkError as exc: + logger.error("发送 update_id[%s] 错误信息失败 错误信息为 %s", update.update_id, exc.message) + if self.pb_client.enabled: logger.info("正在上传日记到 pb") try: - pb_url = await pb_client.create_pb(error_text) + pb_url = await self.pb_client.create_pb(error_text) if pb_url: logger.success("上传日记到 pb 成功") await context.bot.send_message( - chat_id=notice_chat_id, + chat_id=self.notice_chat_id, text=f"错误信息已上传至 fars 请查看", parse_mode=ParseMode.HTML, ) @@ -119,10 +287,10 @@ async def error_handler(self, update: object, context: CallbackContext) -> None: except Exception as exc: logger.error("上传错误信息至 fars 失败") logger.exception(exc) - if sentry.enabled: + if self.sentry.enabled: logger.info("正在上传日记到 sentry") try: - sentry.report_error(update, (type(context.error), context.error, context.error.__traceback__)) + self.sentry.report_error(update, (type(context.error), context.error, context.error.__traceback__)) logger.success("上传日记到 sentry 成功") except SentryClientException as exc: logger.warning("上传错误信息至 sentry 失败", exc_info=exc) diff --git a/plugins/system/get_chat.py b/plugins/system/get_chat.py deleted file mode 100644 index ae3eeb01..00000000 --- a/plugins/system/get_chat.py +++ /dev/null @@ -1,182 +0,0 @@ -import contextlib -import html -import os.path -from datetime import datetime -from typing import Tuple - -from telegram import Update, Chat, ChatMember, ChatMemberOwner, ChatMemberAdministrator -from telegram.error import BadRequest, Forbidden -from telegram.ext import CommandHandler, CallbackContext - -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError -from core.plugin import Plugin, handler -from core.sign import SignServices -from core.user import UserService -from core.user.error import UserNotFoundError -from core.user.models import User -from modules.gacha_log.log import GachaLog -from modules.pay_log.log import PayLog -from modules.playercards.file import PlayerCardsFile -from utils.bot import get_args, get_chat as get_chat_with_cache -from utils.decorators.admins import bot_admins_rights_check -from utils.helpers import get_genshin_client -from utils.log import logger -from utils.models.base import RegionEnum - - -class GetChat(Plugin): - def __init__( - self, - user_service: UserService = None, - cookies_service: CookiesService = None, - sign_service: SignServices = None, - ): - self.cookies_service = cookies_service - self.user_service = user_service - self.sign_service = sign_service - self.gacha_log = GachaLog() - self.pay_log = PayLog() - self.player_cards_file = PlayerCardsFile() - - async def parse_group_chat(self, chat: Chat, admins: Tuple[ChatMember]) -> str: - text = f"群 ID:{chat.id}\n群名称:{chat.title}\n" - if chat.username: - text += f"群用户名:@{chat.username}\n" - sign_info = await self.sign_service.get_by_chat_id(chat.id) - if sign_info: - text += f"自动签到推送人数:{len(sign_info)}\n" - if chat.description: - text += f"群简介:{html.escape(chat.description)}\n" - if admins: - for admin in admins: - text += f'{html.escape(admin.user.full_name)} ' - if isinstance(admin, ChatMemberAdministrator): - text += "C" if admin.can_change_info else "_" - text += "D" if admin.can_delete_messages else "_" - text += "R" if admin.can_restrict_members else "_" - text += "I" if admin.can_invite_users else "_" - text += "T" if admin.can_manage_topics else "_" - text += "P" if admin.can_pin_messages else "_" - text += "V" if admin.can_manage_video_chats else "_" - text += "N" if admin.can_promote_members else "_" - text += "A" if admin.is_anonymous else "_" - elif isinstance(admin, ChatMemberOwner): - text += "创建者" - text += "\n" - return text - - @staticmethod - async def parse_private_bind(user_info: User, chat_id: int) -> Tuple[str, int]: - if user_info.region == RegionEnum.HYPERION: - text = "米游社绑定:" - uid = user_info.yuanshen_uid - else: - text = "原神绑定:" - uid = user_info.genshin_uid - temp = "Cookie 绑定" - try: - await get_genshin_client(chat_id) - except CookiesNotFoundError: - temp = "UID 绑定" - return f"{text}{temp}\n游戏 ID:{uid}", uid - - async def parse_private_sign(self, chat_id: int) -> str: - sign_info = await self.sign_service.get_by_user_id(chat_id) - if sign_info is not None: - text = ( - f"\n自动签到:已开启" - f"\n推送会话:{sign_info.chat_id}" - f"\n开启时间:{sign_info.time_created}" - f"\n更新时间:{sign_info.time_updated}" - f"\n签到状态:{sign_info.status.name}" - ) - else: - text = "\n自动签到:未开启" - return text - - async def parse_private_gacha_log(self, chat_id: int, uid: int) -> str: - gacha_log, status = await self.gacha_log.load_history_info(str(chat_id), str(uid)) - if status: - text = "\n抽卡记录:" - for key, value in gacha_log.item_list.items(): - text += f"\n - {key}:{len(value)} 条" - text += f"\n - 最后更新:{gacha_log.update_time.strftime('%Y-%m-%d %H:%M:%S')}" - else: - text = "\n抽卡记录:未导入" - return text - - async def parse_private_pay_log(self, chat_id: int, uid: int) -> str: - pay_log, status = await self.pay_log.load_history_info(str(chat_id), str(uid)) - return ( - f"\n充值记录:\n - 已导入 {len(pay_log.list)} 条\n - 最后更新:{pay_log.info.export_time}" - if status - else "\n充值记录:未导入" - ) - - @staticmethod - def get_file_modify_time(path: str) -> datetime: - return datetime.fromtimestamp(os.path.getmtime(path)) - - async def parse_private_player_cards_file(self, uid: int) -> str: - player_cards = await self.player_cards_file.load_history_info(uid) - if player_cards is None: - text = "\n角色卡片:未缓存" - else: - time = self.get_file_modify_time(self.player_cards_file.get_file_path(uid)) - text = ( - f"\n角色卡片:" - f"\n - 已缓存 {len(player_cards.get('avatarInfoList', []))} 个角色" - f"\n - 最后更新:{time.strftime('%Y-%m-%d %H:%M:%S')}" - ) - return text - - async def parse_private_chat(self, chat: Chat) -> str: - text = ( - f'MENTION\n' - f"用户 ID:{chat.id}\n" - f"用户名称:{chat.full_name}\n" - ) - if chat.username: - text += f"用户名:@{chat.username}\n" - try: - user_info = await self.user_service.get_user_by_id(chat.id) - except UserNotFoundError: - user_info = None - if user_info is not None: - temp, uid = await self.parse_private_bind(user_info, chat.id) - text += temp - text += await self.parse_private_sign(chat.id) - with contextlib.suppress(Exception): - text += await self.parse_private_gacha_log(chat.id, uid) - with contextlib.suppress(Exception): - text += await self.parse_private_pay_log(chat.id, uid) - with contextlib.suppress(Exception): - text += await self.parse_private_player_cards_file(uid) - return text - - @handler(CommandHandler, command="get_chat", block=False) - @bot_admins_rights_check - async def get_chat(self, update: Update, context: CallbackContext): - user = update.effective_user - logger.info("用户 %s[%s] get_chat 命令请求", user.full_name, user.id) - message = update.effective_message - args = get_args(context) - if not args: - await message.reply_text("参数错误,请指定群 id !") - return - try: - chat_id = int(args[0]) - except ValueError: - await message.reply_text("参数错误,请指定群 id !") - return - try: - chat = await get_chat_with_cache(args[0]) - if chat_id < 0: - admins = await chat.get_administrators() if chat_id < 0 else None - text = await self.parse_group_chat(chat, admins) - else: - text = await self.parse_private_chat(chat) - await message.reply_text(text, parse_mode="HTML") - except (BadRequest, Forbidden) as exc: - await message.reply_text(f"通过 id 获取会话信息失败,API 返回:{exc.message}") diff --git a/plugins/system/log.py b/plugins/system/log.py index 0c4f3ee5..09dcb54f 100644 --- a/plugins/system/log.py +++ b/plugins/system/log.py @@ -2,12 +2,11 @@ from telegram import Update from telegram.constants import ChatAction -from telegram.ext import CommandHandler, CallbackContext +from telegram.ext import CallbackContext from core.config import config from core.plugin import Plugin, handler from modules.errorpush import PbClient, PbClientException -from utils.decorators.admins import bot_admins_rights_check from utils.log import logger current_dir = os.getcwd() @@ -31,11 +30,10 @@ async def send_to_pb(self, file_name: str): logger.exception(exc) return pb_url - @handler(CommandHandler, command="send_log", block=False) - @bot_admins_rights_check + @handler.command(command="send_log", block=False, admin=True) async def send_log(self, update: Update, _: CallbackContext): user = update.effective_user - logger.info(f"用户 {user.full_name}[{user.id}] send_log 命令请求") + logger.info("用户 %s[%s] send_log 命令请求", user.full_name, user.id) message = update.effective_message if os.path.exists(error_log) and os.path.getsize(error_log) > 0: pb_url = await self.send_to_pb(error_log) diff --git a/plugins/system/sign_all.py b/plugins/system/sign_all.py deleted file mode 100644 index e0f89702..00000000 --- a/plugins/system/sign_all.py +++ /dev/null @@ -1,92 +0,0 @@ -import datetime - -from aiohttp import ClientConnectorError -from genshin import InvalidCookies, AlreadyClaimed, GenshinException -from telegram import Update -from telegram.constants import ParseMode -from telegram.error import BadRequest, Forbidden -from telegram.ext import CommandHandler, CallbackContext - -from core.base.redisdb import RedisDB -from core.cookies import CookiesService -from core.plugin import Plugin, handler -from core.sign import SignServices -from core.sign.models import SignStatusEnum -from core.user import UserService -from plugins.genshin.sign import SignSystem -from plugins.jobs.sign import NeedChallenge -from utils.decorators.admins import bot_admins_rights_check -from utils.helpers import get_genshin_client -from utils.log import logger - - -class SignAll(Plugin): - def __init__( - self, - sign_service: SignServices = None, - user_service: UserService = None, - cookies_service: CookiesService = None, - redis: RedisDB = None, - ): - self.sign_service = sign_service - self.cookies_service = cookies_service - self.user_service = user_service - self.sign_system = SignSystem(redis) - - @handler(CommandHandler, command="sign_all", block=False) - @bot_admins_rights_check - async def sign_all(self, update: Update, context: CallbackContext): - user = update.effective_user - logger.info(f"用户 {user.full_name}[{user.id}] sign_all 命令请求") - message = update.effective_message - reply = await message.reply_text("正在全部重新签到,请稍后...") - sign_list = await self.sign_service.get_all() - for sign_db in sign_list: - user_id = sign_db.user_id - old_status = sign_db.status - try: - client = await get_genshin_client(user_id) - text = await self.sign_system.start_sign(client, is_sleep=True, is_raise=True, title="自动重新签到") - except InvalidCookies: - text = "自动签到执行失败,Cookie无效" - sign_db.status = SignStatusEnum.INVALID_COOKIES - except AlreadyClaimed: - text = "今天旅行者已经签到过了~" - sign_db.status = SignStatusEnum.ALREADY_CLAIMED - except GenshinException as exc: - text = f"自动签到执行失败,API返回信息为 {str(exc)}" - sign_db.status = SignStatusEnum.GENSHIN_EXCEPTION - except ClientConnectorError: - text = "签到失败了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ " - sign_db.status = SignStatusEnum.TIMEOUT_ERROR - except NeedChallenge: - text = "签到失败,触发验证码风控,自动签到自动关闭" - sign_db.status = SignStatusEnum.NEED_CHALLENGE - except Exception as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - text = "签到失败了呜呜呜 ~ 执行自动签到时发生错误" - else: - sign_db.status = SignStatusEnum.STATUS_SUCCESS - if sign_db.chat_id < 0: - text = f'NOTICE {sign_db.user_id}\n\n{text}' - try: - await context.bot.send_message(sign_db.chat_id, text, parse_mode=ParseMode.HTML) - except BadRequest as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - sign_db.status = SignStatusEnum.BAD_REQUEST - except Forbidden as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - sign_db.status = SignStatusEnum.FORBIDDEN - except Exception as exc: - logger.error(f"执行自动签到时发生错误 用户UID[{user_id}]") - logger.exception(exc) - continue - else: - sign_db.status = SignStatusEnum.STATUS_SUCCESS - sign_db.time_updated = datetime.datetime.now() - if sign_db.status != old_status: - await self.sign_service.update(sign_db) - await reply.edit_text("全部账号重新签到完成") diff --git a/plugins/system/update.py b/plugins/system/update.py index 5c5f0dbe..2b76afd4 100644 --- a/plugins/system/update.py +++ b/plugins/system/update.py @@ -3,14 +3,11 @@ from sys import executable from aiofiles import open as async_open -from telegram import Update, Message +from telegram import Message, Update from telegram.error import NetworkError -from telegram.ext import CallbackContext, CommandHandler +from telegram.ext import CallbackContext -from core.bot import bot -from core.plugin import handler, Plugin -from utils.bot import get_args -from utils.decorators.admins import bot_admins_rights_check +from core.plugin import Plugin, handler from utils.helpers import execute from utils.log import logger @@ -27,33 +24,33 @@ class UpdatePlugin(Plugin): def __init__(self): - self._lock = asyncio.Lock() + self.lock = asyncio.Lock() - @staticmethod - async def __async_init__(): + async def initialize(self) -> None: if os.path.exists(UPDATE_DATA): async with async_open(UPDATE_DATA) as file: data = jsonlib.loads(await file.read()) try: - reply_text = Message.de_json(data, bot.app.bot) + reply_text = Message.de_json(data, self.application.telegram.bot) await reply_text.edit_text("重启成功") except NetworkError as exc: - logger.error("UpdatePlugin 编辑消息出现错误 %s", exc.message) + logger.error("编辑消息出现错误 %s", exc.message) + except jsonlib.JSONDecodeError: + logger.error("JSONDecodeError") except KeyError as exc: - logger.error("UpdatePlugin 编辑消息出现错误", exc_info=exc) + logger.error("编辑消息出现错误", exc_info=exc) os.remove(UPDATE_DATA) - @handler(CommandHandler, command="update", block=False) - @bot_admins_rights_check + @handler.command("update", block=False, admin=True) async def update(self, update: Update, context: CallbackContext): user = update.effective_user message = update.effective_message - args = get_args(context) - logger.info(f"用户 {user.full_name}[{user.id}] update命令请求") - if self._lock.locked(): + args = self.get_args(context) + logger.info("用户 %s[%s] update命令请求", user.full_name, user.id) + if self.lock.locked(): await message.reply_text("程序正在更新 请勿重复操作") return - async with self._lock: + async with self.lock: reply_text = await message.reply_text("正在更新") logger.info("正在更新代码") await execute("git fetch --all") diff --git a/plugins/system/webapp.py b/plugins/system/webapp.py deleted file mode 100644 index 212bbdd8..00000000 --- a/plugins/system/webapp.py +++ /dev/null @@ -1,155 +0,0 @@ -from typing import Optional - -from genshin import GenshinException, Region -from pydantic import BaseModel -from telegram import KeyboardButton, ReplyKeyboardMarkup, ReplyKeyboardRemove, Update, WebAppInfo -from telegram.ext import CallbackContext, filters - -from core.base.redisdb import RedisDB -from core.config import config -from core.cookies import CookiesService -from core.cookies.error import CookiesNotFoundError -from core.plugin import Plugin, handler -from core.user import UserService -from core.user.error import UserNotFoundError -from modules.apihelper.client.components.verify import Verify -from modules.apihelper.error import ResponseException -from plugins.genshin.verification import VerificationSystem -from utils.decorators.restricts import restricts -from utils.helpers import get_genshin_client -from utils.log import logger - - -class WebAppData(BaseModel): - path: str - data: Optional[dict] - code: int - message: str - - -class WebAppDataException(Exception): - def __init__(self, data): - self.data = data - super().__init__() - - -class WebApp(Plugin): - def __init__(self, user_service: UserService = None, cookies_service: CookiesService = None, redis: RedisDB = None): - self.cookies_service = cookies_service - self.user_service = user_service - self.verification_system = VerificationSystem(redis) - - @staticmethod - def de_web_app_data(data: str) -> WebAppData: - try: - return WebAppData.parse_raw(data) - except Exception as exc: - raise WebAppDataException(data) from exc - - @handler.message(filters=filters.StatusUpdate.WEB_APP_DATA, block=False) - @restricts() - async def app(self, update: Update, context: CallbackContext): - user = update.effective_user - message = update.effective_message - web_app_data = message.web_app_data - if web_app_data: - logger.info("用户 %s[%s] 触发 WEB_APP_DATA 请求", user.full_name, user.id) - result = self.de_web_app_data(web_app_data.data) - logger.debug( - "path[%s]\ndata[%s]\ncode[%s]\nmessage[%s]", result.path, result.data, result.code, result.message - ) - if result.code == 0: - if result.path == "verify": - validate = result.data.get("geetest_validate") - try: - client = await get_genshin_client(user.id) - if client.region != Region.CHINESE: - await message.reply_text("非法用户", reply_markup=ReplyKeyboardRemove()) - return - except UserNotFoundError: - await message.reply_text("用户未找到", reply_markup=ReplyKeyboardRemove()) - return - except CookiesNotFoundError: - await message.reply_text("检测到用户为UID绑定,无需认证", reply_markup=ReplyKeyboardRemove()) - return - verify = Verify(cookies=client.cookie_manager.cookies) - if validate: - _, challenge = await self.verification_system.get_challenge(client.uid) - if challenge: - logger.info( - "用户 %s[%s] 请求通过认证\nchallenge[%s]\nvalidate[%s]", - user.full_name, - user.id, - challenge, - validate, - ) - try: - await verify.verify(challenge=challenge, validate=validate) - logger.success("用户 %s[%s] 验证成功", user.full_name, user.id) - await message.reply_text("验证成功", reply_markup=ReplyKeyboardRemove()) - except ResponseException as exc: - logger.warning( - "用户 %s[%s] 验证失效 API返回 [%s]%s", user.full_name, user.id, exc.code, exc.message - ) - if "拼图已过期" in exc.message: - await message.reply_text( - "验证失败,拼图已过期,请稍后重试或更换使用环境进行验证", reply_markup=ReplyKeyboardRemove() - ) - else: - await message.reply_text( - f"验证失败,错误信息为 [{exc.code}]{exc.message},请稍后重试", - reply_markup=ReplyKeyboardRemove(), - ) - else: - logger.warning("用户 %s[%s] 验证失效 请求已经过期", user.full_name, user.id) - await message.reply_text("验证失效 请求已经过期 请稍后重试", reply_markup=ReplyKeyboardRemove()) - return - try: - await client.get_genshin_notes() - except GenshinException as exc: - if exc.retcode != 1034: - raise exc - else: - await message.reply_text("账户正常,无需认证") - return - try: - data = await verify.create(is_high=True) - challenge = data["challenge"] - gt = data["gt"] - logger.success("用户 %s[%s] 创建验证成功\ngt:%s\nchallenge%s", user.full_name, user.id, gt, challenge) - except ResponseException as exc: - logger.warning("用户 %s[%s] 创建验证失效 API返回 [%s]%s", user.full_name, user.id, exc.code, exc.message) - await message.reply_text( - f"创建验证失败 错误信息为 [{exc.code}]{exc.message} 请稍后重试", reply_markup=ReplyKeyboardRemove() - ) - return - await self.verification_system.set_challenge(client.uid, gt, challenge) - url = f"{config.pass_challenge_user_web}/webapp?username={context.bot.username}&command=verify>={gt}&challenge={challenge}&uid={client.uid}" - await message.reply_text( - "请尽快点击下方手动验证 或发送 /web_cancel 取消操作", - reply_markup=ReplyKeyboardMarkup.from_button( - KeyboardButton( - text="点我手动验证", - web_app=WebAppInfo(url=url), - ) - ), - ) - else: - logger.warning( - "用户 %s[%s] WEB_APP_DATA 请求错误 [%s]%s", user.full_name, user.id, result.code, result.message - ) - if result.path == "verify": - await message.reply_text( - "验证过程中出现问题 %s\n" "如果继续遇到该问题,请打开米游社→我的角色中尝试手动通过验证,或发送 /verify 进行手动验证" % result.message, - reply_markup=ReplyKeyboardRemove(), - ) - else: - await message.reply_text("WebApp返回错误 %s" % result.message, reply_markup=ReplyKeyboardRemove()) - else: - logger.warning("用户 %s[%s] WEB_APP_DATA 非法数据", user.full_name, user.id) - - @handler.command("web_cancel", block=False) - @restricts() - async def web_cancel(self, update: Update, _: CallbackContext) -> None: - message = update.effective_message - await message.reply_text("取消操作", reply_markup=ReplyKeyboardRemove()) diff --git a/plugins/tools/__init__.py b/plugins/tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/plugins/tools/challenge.py b/plugins/tools/challenge.py new file mode 100644 index 00000000..70cb283f --- /dev/null +++ b/plugins/tools/challenge.py @@ -0,0 +1,99 @@ +from typing import Tuple, Optional + +from genshin import Region, GenshinException + +from core.dependence.redisdb import RedisDB +from core.plugin import Plugin +from core.services.cookies import CookiesService +from modules.apihelper.client.components.verify import Verify +from modules.apihelper.error import ResponseException, APIHelperException +from plugins.tools.genshin import GenshinHelper, PlayerNotFoundError +from utils.log import logger + +__all__ = ("ChallengeSystemException", "ChallengeSystem") + + +class ChallengeSystemException(Exception): + def __init__(self, message: str): + self.message = message + super().__init__() + + +class ChallengeSystem(Plugin): + def __init__( + self, + cookies_service: CookiesService, + redis: RedisDB, + genshin_helper: GenshinHelper, + ) -> None: + self.cookies_service = cookies_service + self.genshin_helper = genshin_helper + self.cache = redis.client + self.qname = "plugin:challenge:" + + async def get_challenge(self, uid: int) -> Tuple[Optional[str], Optional[str]]: + data = await self.cache.get(f"{self.qname}{uid}") + if not data: + return None, None + data = data.decode("utf-8").split("|") + return data[0], data[1] + + async def set_challenge(self, uid: int, gt: str, challenge: str): + await self.cache.set(f"{self.qname}{uid}", f"{gt}|{challenge}") + await self.cache.expire(f"{self.qname}{uid}", 10 * 60) + + async def create_challenge( + self, user_id: int, need_verify: bool = True, ajax: bool = False + ) -> Tuple[Optional[int], Optional[str], Optional[str]]: + try: + client = await self.genshin_helper.get_genshin_client(user_id) + except PlayerNotFoundError: + raise ChallengeSystemException("用户未找到") + if client.region != Region.CHINESE: + raise ChallengeSystemException("非法用户") + if need_verify: + try: + await client.get_genshin_notes() + except GenshinException as exc: + if exc.retcode != 1034: + raise exc + raise ChallengeSystemException("账户正常,无需认证") + verify = Verify(cookies=client.cookie_manager.cookies) + try: + data = await verify.create() + challenge = data["challenge"] + gt = data["gt"] + except ResponseException as exc: + logger.warning("用户 %s 创建验证失效 API返回 [%s]%s", user_id, exc.code, exc.message) + raise ChallengeSystemException(f"创建验证失败 错误信息为 [{exc.code}]{exc.message} 请稍后重试") + if ajax: + try: + validate = await verify.ajax(referer="https://webstatic.mihoyo.com/", gt=gt, challenge=challenge) + if validate: + await verify.verify(challenge, validate) + return client.uid, "ajax", "ajax" + except APIHelperException as exc: + logger.warning("用户 %s ajax 验证失效 错误信息为 %s", user_id, str(exc)) + await self.set_challenge(client.uid, gt, challenge) + return client.uid, gt, challenge + + async def pass_challenge(self, user_id: int, validate: str, challenge: Optional[str] = None) -> bool: + try: + client = await self.genshin_helper.get_genshin_client(user_id) + except PlayerNotFoundError: + raise ChallengeSystemException("用户未找到") + if client.region != Region.CHINESE: + raise ChallengeSystemException("非法用户") + if challenge is None: + _, challenge = await self.get_challenge(client.uid) + if challenge is None: + raise ChallengeSystemException("验证失效 请求已经过期") + verify = Verify(cookies=client.cookie_manager.cookies) + try: + await verify.verify(challenge=challenge, validate=validate) + except ResponseException as exc: + logger.warning("用户 %s 验证失效 API返回 [%s]%s", user_id, exc.code, exc.message) + if "拼图已过期" in exc.message: + raise ChallengeSystemException("验证失败,拼图已过期,请稍后重试或更换使用环境进行验证") + raise ChallengeSystemException(f"验证失败,错误信息为 [{exc.code}]{exc.message},请稍后重试") + return True diff --git a/plugins/tools/genshin.py b/plugins/tools/genshin.py new file mode 100644 index 00000000..a453ab9d --- /dev/null +++ b/plugins/tools/genshin.py @@ -0,0 +1,297 @@ +import asyncio +import random +import re +from datetime import datetime, timedelta, time +from typing import Optional, Tuple, Union, TYPE_CHECKING + +import genshin +from genshin.errors import GenshinException +from genshin.models import BaseCharacter +from genshin.models import CalculatorCharacterDetails +from pydantic import ValidationError +from sqlalchemy.exc import SQLAlchemyError +from sqlmodel import SQLModel, Field, String, Column, Integer, BigInteger, select, DateTime, func, delete +from telegram.ext import ContextTypes + +from core.basemodel import RegionEnum +from core.config import config +from core.dependence.mysql import MySQL +from core.dependence.redisdb import RedisDB +from core.error import ServiceNotFoundError +from core.plugin import Plugin +from core.services.cookies.services import CookiesService, PublicCookiesService +from core.services.players.services import PlayersService +from core.services.users.services import UserService +from core.sqlmodel.session import AsyncSession +from utils.const import REGION_MAP +from utils.log import logger + +if TYPE_CHECKING: + from sqlalchemy import Table + from genshin import Client as GenshinClient + +__all__ = ("GenshinHelper", "PlayerNotFoundError", "CookiesNotFoundError") + + +class PlayerNotFoundError(Exception): + def __init__(self, user_id): + super().__init__(f"User not found, user_id: {user_id}") + + +class CookiesNotFoundError(Exception): + def __init__(self, user_id): + super().__init__(f"{user_id} cookies not found") + + +class CharacterDetailsSQLModel(SQLModel, table=True): + __tablename__ = "character_details" + __table_args__ = (dict(mysql_charset="utf8mb4", mysql_collate="utf8mb4_general_ci"),) + id: Optional[int] = Field(default=None, sa_column=Column(Integer, primary_key=True, autoincrement=True)) + player_id: int = Field(sa_column=Column(BigInteger(), primary_key=True)) + character_id: int = Field(sa_column=Column(BigInteger(), primary_key=True)) + data: Optional[str] = Field(sa_column=Column(String(length=4096))) + time_updated: Optional[datetime] = Field(sa_column=Column(DateTime, onupdate=func.now())) # pylint: disable=E1102 + + +class CharacterDetails(Plugin): + def __init__( + self, + mysql: MySQL, + redis: RedisDB, + ) -> None: + self.mysql = mysql + self.redis = redis.client + self.ttl = 60 * 60 * 3 + + async def initialize(self) -> None: + def fetch_and_update_objects(connection): + if not self.mysql.engine.dialect.has_table(connection, table_name="character_details"): + logger.info("正在创建角色详细信息表") + table: "Table" = SQLModel.metadata.tables["character_details"] + table.create(connection) + logger.success("创建角色详细信息表成功") + + async with self.mysql.engine.begin() as conn: + await conn.run_sync(fetch_and_update_objects) + asyncio.create_task(self.save_character_details_task(max_ttl=None)) + self.application.job_queue.run_daily(self.del_old_data_job, time(hour=3, minute=0)) + self.application.job_queue.run_repeating(self.save_character_details_job, timedelta(hours=1)) + + async def save_character_details_job(self, _: ContextTypes.DEFAULT_TYPE): + await self.save_character_details() + + async def del_old_data_job(self, _: ContextTypes.DEFAULT_TYPE): + await self.del_old_data(timedelta(days=7)) + + async def del_old_data(self, expiration_time: timedelta): + expire_time = datetime.now() - expiration_time + statement = delete(CharacterDetailsSQLModel).where(CharacterDetailsSQLModel.time_updated <= expire_time) + async with AsyncSession(self.mysql.engine) as session: + await session.execute(statement) + + async def save_character_details_task(self, max_ttl: Optional[int] = 60 * 60): + logger.info("正在从Redis中保存角色详细信息") + try: + await self.save_character_details(max_ttl) + except SQLAlchemyError as exc: + logger.error("写入到数据库失败 code[%s]", exc.code) + logger.debug("写入到数据库失败", exc_info=exc) + except Exception as exc: + logger.error("save_character_details 执行失败", exc_info=exc) + else: + logger.success("从Redis中保存角色详细信息成功") + + async def save_character_details(self, max_ttl: Optional[int] = 60 * 60): + keys = await self.redis.keys("plugins:character_details:*") + for key in keys: + key = str(key, encoding="utf-8") + ttl = await self.redis.ttl(key) + if max_ttl is None or (0 <= ttl <= max_ttl): + try: + uid, character_id = re.findall(r"\d+", key) + except ValueError: + logger.warning("非法Key %s", key) + continue + data = await self.redis.get(key) + str_data = str(data, encoding="utf-8") + sql_data = CharacterDetailsSQLModel( + player_id=uid, character_id=character_id, data=str_data, time_updated=datetime.now() + ) + async with AsyncSession(self.mysql.engine) as session: + await session.merge(sql_data) + await session.commit() + + @staticmethod + def get_qname(uid: int, character: int): + return f"plugins:character_details:{uid}:{character}" + + async def get_character_details_for_redis( + self, + uid: int, + character_id: int, + ) -> Optional["CalculatorCharacterDetails"]: + name = self.get_qname(uid, character_id) + data = await self.redis.get(name) + if data is None: + return None + json_data = str(data, encoding="utf-8") + return CalculatorCharacterDetails.parse_raw(json_data) + + async def set_character_details_for_redis(self, uid: int, character_id: int, detail: "CalculatorCharacterDetails"): + randint = random.randint(1, 30) # nosec + await self.redis.set( + self.get_qname(uid, character_id), detail.json(), ex=self.ttl + randint * 60 # 使用随机数防止缓存雪崩 + ) + + async def set_character_details_for_mysql(self, uid: int, character_id: int, detail: "CalculatorCharacterDetails"): + data = CharacterDetailsSQLModel(player_id=uid, character_id=character_id, data=detail.json()) + async with AsyncSession(self.mysql.engine) as session: + await session.merge(data) + await session.commit() + + async def get_character_details_for_mysql( + self, + uid: int, + character_id: int, + ) -> Optional["CalculatorCharacterDetails"]: + async with AsyncSession(self.mysql.engine) as session: + statement = ( + select(CharacterDetailsSQLModel) + .where(CharacterDetailsSQLModel.player_id == uid) + .where(CharacterDetailsSQLModel.character_id == character_id) + ) + results = await session.exec(statement) + data = results.first() + if data is not None: + try: + return CalculatorCharacterDetails.parse_raw(data.data) + except ValidationError as exc: + logger.error("解析数据出现异常 ValidationError", exc_info=exc) + await session.delete(data) + await session.commit() + except ValueError as exc: + logger.error("解析数据出现异常 ValueError", exc_info=exc) + await session.delete(data) + await session.commit() + return None + + async def get_character_details( + self, client: "GenshinClient", character: "Union[int,BaseCharacter]" + ) -> Optional["CalculatorCharacterDetails"]: + """缓存 character_details 并定时对其进行数据存储 当遇到 Too Many Requests 可以获取以前的数据 + :param client: genshin.py + :param character: + :return: + """ + uid = client.uid + if uid is not None: + if isinstance(character, BaseCharacter): + character_id = character.id + else: + character_id = character + detail = await self.get_character_details_for_redis(uid, character_id) + if detail is not None: + return detail + try: + detail = await client.get_character_details(character) + except GenshinException as exc: + if "Too Many Requests" in exc.msg: + return await self.get_character_details_for_mysql(uid, character_id) + await self.set_character_details_for_redis(uid, character_id, detail) + return detail + try: + return await client.get_character_details(character) + except GenshinException as exc: + if "Too Many Requests" in exc.msg: + logger.warning("Too Many Requests") + else: + raise exc + return None + + +class GenshinHelper(Plugin): + def __init__( + self, + cookies: CookiesService, + public_cookies: PublicCookiesService, + user: UserService, + redis: RedisDB, + player: PlayersService, + ) -> None: + self.cookies_service = cookies + self.public_cookies_service = public_cookies + self.user_service = user + self.redis_db = redis + self.players_service = player + + if self.redis_db and config.genshin_ttl: + self.genshin_cache = genshin.RedisCache(self.redis_db.client, ttl=config.genshin_ttl) + else: + self.genshin_cache = None + + if None in (temp := [self.user_service, self.cookies_service, self.players_service]): + raise ServiceNotFoundError(*filter(lambda x: x is None, temp)) + + @staticmethod + def region_server(uid: Union[int, str]) -> RegionEnum: + if isinstance(uid, (int, str)): + region = REGION_MAP.get(str(uid)[0]) + else: + raise TypeError("UID variable type error") + if region: + return region + raise ValueError(f"UID {uid} isn't associated with any region.") + + async def get_genshin_client( + self, user_id: int, region: Optional[RegionEnum] = None, need_cookie: bool = True + ) -> Optional[genshin.Client]: + """通过 user_id 和 region 获取私有的 `genshin.Client`""" + player = await self.players_service.get_player(user_id, region) + if player is None: + raise PlayerNotFoundError(user_id) + cookies = None + if need_cookie: + cookie_model = await self.cookies_service.get(player.user_id, player.account_id, player.region) + if cookie_model is None: + raise CookiesNotFoundError(user_id) + cookies = cookie_model.data + + uid = player.player_id + region = player.region + if region == RegionEnum.HYPERION: # 国服 + game_region = genshin.types.Region.CHINESE + elif region == RegionEnum.HOYOLAB: # 国际服 + game_region = genshin.types.Region.OVERSEAS + else: + raise TypeError("Region is not None") + + client = genshin.Client(cookies, lang="zh-cn", game=genshin.types.Game.GENSHIN, region=game_region, uid=uid) + + if self.genshin_cache is not None: + client.cache = self.genshin_cache + + return client + + async def get_public_genshin_client(self, user_id: int) -> Tuple[genshin.Client, int]: + """通过 user_id 获取公共的 `genshin.Client`""" + player = await self.players_service.get_player(user_id) + + region = player.region + cookies = await self.public_cookies_service.get_cookies(user_id, region) + + uid = player.player_id + if region is RegionEnum.HYPERION: + game_region = genshin.types.Region.CHINESE + elif region is RegionEnum.HOYOLAB: + game_region = genshin.types.Region.OVERSEAS + else: + raise TypeError("Region is not `RegionEnum.NULL`") + + client = genshin.Client( + cookies.data, region=game_region, uid=uid, game=genshin.types.Game.GENSHIN, lang="zh-cn" + ) + + if self.genshin_cache is not None: + client.cache = self.genshin_cache + + return client, uid diff --git a/plugins/tools/sign.py b/plugins/tools/sign.py new file mode 100644 index 00000000..05fa4da1 --- /dev/null +++ b/plugins/tools/sign.py @@ -0,0 +1,370 @@ +import asyncio +import datetime +import random +import time +from enum import Enum +from json import JSONDecodeError +from typing import Optional, Tuple, List + +from aiohttp import ClientConnectorError +from genshin import Game, GenshinException, AlreadyClaimed, Client, InvalidCookies +from genshin.utility import recognize_genshin_server +from httpx import AsyncClient, TimeoutException +from telegram import InlineKeyboardButton, InlineKeyboardMarkup +from telegram.constants import ParseMode +from telegram.error import Forbidden, BadRequest +from telegram.ext import CallbackContext + +from core.config import config +from core.dependence.redisdb import RedisDB +from core.plugin import Plugin +from core.services.cookies import CookiesService +from core.services.sign.models import SignStatusEnum +from core.services.sign.services import SignServices +from core.services.users.services import UserService +from modules.apihelper.client.components.verify import Verify +from plugins.tools.genshin import GenshinHelper +from utils.log import logger + + +class SignJobType(Enum): + START = 1 + REDO = 2 + + +class SignSystemException(Exception): + def __init__(self, message: str): + self.message = message + super().__init__() + + +class NeedChallenge(Exception): + def __init__(self, uid: int, gt: str = "", challenge: str = ""): + super().__init__() + self.uid = uid + self.gt = gt + self.challenge = challenge + + +class SignSystem(Plugin): + REFERER = ( + "https://webstatic.mihoyo.com/bbs/event/signin-ys/index.html?" + "bbs_auth_required=true&act_id=e202009291139501&utm_source=bbs&utm_medium=mys&utm_campaign=icon" + ) + + def __init__( + self, + redis: RedisDB, + user_service: UserService, + cookies_service: CookiesService, + sign_service: SignServices, + genshin_helper: GenshinHelper, + ): + self.cookies_service = cookies_service + self.user_service = user_service + self.sign_service = sign_service + self.genshin_helper = genshin_helper + self.cache = redis.client + self.qname = "plugin:sign:" + self.verify = Verify() + + async def get_challenge(self, uid: int) -> Tuple[Optional[str], Optional[str]]: + data = await self.cache.get(f"{self.qname}{uid}") + if not data: + return None, None + data = data.decode("utf-8").split("|") + return data[0], data[1] + + async def set_challenge(self, uid: int, gt: str, challenge: str): + await self.cache.set(f"{self.qname}{uid}", f"{gt}|{challenge}") + await self.cache.expire(f"{self.qname}{uid}", 10 * 60) + + async def get_challenge_button( + self, + bot_username: str, + uid: int, + user_id: int, + gt: Optional[str] = None, + challenge: Optional[str] = None, + callback: bool = True, + ) -> Optional[InlineKeyboardMarkup]: + if not config.pass_challenge_user_web: + return None + if challenge and gt: + await self.set_challenge(uid, gt, challenge) + if not challenge or not gt: + gt, challenge = await self.get_challenge(uid) + if not challenge or not gt: + return None + if callback: + data = f"sign|{user_id}|{uid}" + return InlineKeyboardMarkup([[InlineKeyboardButton("请尽快点我进行手动验证", callback_data=data)]]) + url = ( + f"{config.pass_challenge_user_web}?" + f"username={bot_username}&command=sign>={gt}&challenge={challenge}&uid={uid}" + ) + return InlineKeyboardMarkup([[InlineKeyboardButton("请尽快点我进行手动验证", url=url)]]) + + async def recognize(self, gt: str, challenge: str, referer: str = None) -> Optional[str]: + if not referer: + referer = self.REFERER + if not gt or not challenge: + return None + pass_challenge_params = { + "gt": gt, + "challenge": challenge, + "referer": referer, + } + if config.pass_challenge_app_key: + pass_challenge_params["appkey"] = config.pass_challenge_app_key + headers = { + "Accept": "*/*", + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/107.0.0.0 Safari/537.36", + } + try: + async with AsyncClient(headers=headers) as client: + resp = await client.post( + config.pass_challenge_api, + params=pass_challenge_params, + timeout=60, + ) + logger.debug("recognize 请求返回:%s", resp.text) + data = resp.json() + status = data.get("status") + if status != 0: + logger.error("recognize 解析错误:[%s]%s", data.get("code"), data.get("msg")) + if data.get("code", 0) != 0: + raise RuntimeError + logger.info("recognize 解析成功") + return data["data"]["validate"] + except JSONDecodeError: + logger.warning("recognize 请求 JSON 解析失败") + except TimeoutException as exc: + logger.warning("recognize 请求超时") + raise exc + except KeyError: + logger.warning("recognize 请求数据错误") + except RuntimeError: + logger.warning("recognize 请求失败") + return None + + async def start_sign( + self, + client: Client, + challenge: Optional[str] = None, + validate: Optional[str] = None, + is_sleep: bool = False, + is_raise: bool = False, + title: Optional[str] = "签到结果", + ) -> str: + if is_sleep: + if recognize_genshin_server(client.uid) in ("cn_gf01", "cn_qd01"): + await asyncio.sleep(random.randint(10, 300)) # nosec + else: + await asyncio.sleep(random.randint(0, 3)) # nosec + try: + rewards = await client.get_monthly_rewards(game=Game.GENSHIN, lang="zh-cn") + except GenshinException as error: + logger.warning("UID[%s] 获取签到信息失败,API返回信息为 %s", client.uid, str(error)) + if is_raise: + raise error + return f"获取签到信息失败,API返回信息为 {str(error)}" + try: + daily_reward_info = await client.get_reward_info(game=Game.GENSHIN, lang="zh-cn") # 获取签到信息失败 + except GenshinException as error: + logger.warning("UID[%s] 获取签到状态失败,API返回信息为 %s", client.uid, str(error)) + if is_raise: + raise error + return f"获取签到状态失败,API返回信息为 {str(error)}" + if not daily_reward_info.signed_in: + try: + if validate: + logger.info("UID[%s] 正在尝试通过验证码\nchallenge[%s]\nvalidate[%s]", client.uid, challenge, validate) + request_daily_reward = await client.request_daily_reward( + "sign", + method="POST", + game=Game.GENSHIN, + lang="zh-cn", + challenge=challenge, + validate=validate, + ) + logger.debug("request_daily_reward 返回 %s", request_daily_reward) + if request_daily_reward and request_daily_reward.get("success", 0) == 1: + # 尝试通过 ajax 请求绕过签到 + gt = request_daily_reward.get("gt", "") + challenge = request_daily_reward.get("challenge", "") + logger.warning("UID[%s] 触发验证码\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) + validate = await self.verify.ajax( + referer=self.REFERER, + gt=gt, + challenge=challenge, + ) + if validate: + logger.success("ajax 通过验证成功\nchallenge[%s]\nvalidate[%s]", challenge, validate) + request_daily_reward = await client.request_daily_reward( + "sign", + method="POST", + game=Game.GENSHIN, + lang="zh-cn", + challenge=challenge, + validate=validate, + ) + logger.debug("request_daily_reward 返回 %s", request_daily_reward) + if request_daily_reward and request_daily_reward.get("success", 0) == 1: + logger.warning("UID[%s] 触发验证码\nchallenge[%s]", client.uid, challenge) + raise NeedChallenge( + uid=client.uid, + gt=request_daily_reward.get("gt", ""), + challenge=request_daily_reward.get("challenge", ""), + ) + elif config.pass_challenge_app_key: + # 如果无法绕过 检查配置文件是否配置识别 API 尝试请求绕过 + # 注意 需要重新获取没有进行任何请求的 Challenge + logger.info("UID[%s] 正在使用 recognize 重新请求签到", client.uid) + _request_daily_reward = await client.request_daily_reward( + "sign", + method="POST", + game=Game.GENSHIN, + lang="zh-cn", + ) + logger.debug("request_daily_reward 返回\n%s", _request_daily_reward) + if _request_daily_reward and _request_daily_reward.get("success", 0) == 1: + _gt = _request_daily_reward.get("gt", "") + _challenge = _request_daily_reward.get("challenge", "") + logger.info("UID[%s] 创建验证码\ngt[%s]\nchallenge[%s]", client.uid, _gt, _challenge) + _validate = await self.recognize(_gt, _challenge) + if _validate: + logger.success("recognize 通过验证成功\nchallenge[%s]\nvalidate[%s]", _challenge, _validate) + request_daily_reward = await client.request_daily_reward( + "sign", + method="POST", + game=Game.GENSHIN, + lang="zh-cn", + challenge=_challenge, + validate=_validate, + ) + if request_daily_reward and request_daily_reward.get("success", 0) == 1: + logger.warning("UID[%s] 触发验证码\nchallenge[%s]", client.uid, _challenge) + gt = request_daily_reward.get("gt", "") + challenge = request_daily_reward.get("challenge", "") + logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) + raise NeedChallenge( + uid=client.uid, + gt=gt, + challenge=challenge, + ) + logger.success("UID[%s] 通过 recognize 签到成功", client.uid) + else: + request_daily_reward = await client.request_daily_reward( + "sign", method="POST", game=Game.GENSHIN, lang="zh-cn" + ) + gt = request_daily_reward.get("gt", "") + challenge = request_daily_reward.get("challenge", "") + logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) + raise NeedChallenge(uid=client.uid, gt=gt, challenge=challenge) + else: + request_daily_reward = await client.request_daily_reward( + "sign", method="POST", game=Game.GENSHIN, lang="zh-cn" + ) + gt = request_daily_reward.get("gt", "") + challenge = request_daily_reward.get("challenge", "") + logger.success("UID[%s] 创建验证成功\ngt[%s]\nchallenge[%s]", client.uid, gt, challenge) + raise NeedChallenge(uid=client.uid, gt=gt, challenge=challenge) + else: + logger.success("UID[%s] 签到成功", client.uid) + except TimeoutException as error: + logger.warning("UID[%s] 签到请求超时", client.uid) + if is_raise: + raise error + return "签到失败了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ " + except AlreadyClaimed as error: + logger.warning("UID[%s] 已经签到", client.uid) + if is_raise: + raise error + result = "今天旅行者已经签到过了~" + except GenshinException as error: + logger.warning("UID %s 签到失败,API返回信息为 %s", client.uid, str(error)) + if is_raise: + raise error + return f"获取签到状态失败,API返回信息为 {str(error)}" + else: + result = "OK" + else: + logger.info("UID[%s] 已经签到", client.uid) + result = "今天旅行者已经签到过了~" + logger.info("UID[%s] 签到结果 %s", client.uid, result) + reward = rewards[daily_reward_info.claimed_rewards - (1 if daily_reward_info.signed_in else 0)] + today = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()) + cn_timezone = datetime.timezone(datetime.timedelta(hours=8)) + now = datetime.datetime.now(cn_timezone) + missed_days = now.day - daily_reward_info.claimed_rewards + if not daily_reward_info.signed_in: + missed_days -= 1 + message = ( + f"#### {title} ####\n" + f"时间:{today} (UTC+8)\n" + f"UID: {client.uid}\n" + f"今日奖励: {reward.name} × {reward.amount}\n" + f"本月漏签次数:{missed_days}\n" + f"签到结果: {result}" + ) + return message + + async def do_sign_job(self, context: CallbackContext, job_type: SignJobType): + include_status: List[SignStatusEnum] = [ + SignStatusEnum.STATUS_SUCCESS, + SignStatusEnum.TIMEOUT_ERROR, + SignStatusEnum.NEED_CHALLENGE, + ] + if job_type == SignJobType.START: + title = "自动签到" + elif job_type == SignJobType.REDO: + title = "自动重新签到" + else: + raise ValueError + sign_list = await self.sign_service.get_all() + for sign_db in sign_list: + if sign_db.status not in include_status: + continue + user_id = sign_db.user_id + try: + client = await self.genshin_helper.get_genshin_client(user_id) + text = await self.start_sign(client, is_sleep=True, is_raise=True, title=title) + except InvalidCookies: + text = "自动签到执行失败,Cookie无效" + sign_db.status = SignStatusEnum.INVALID_COOKIES + except AlreadyClaimed: + text = "今天旅行者已经签到过了~" + sign_db.status = SignStatusEnum.ALREADY_CLAIMED + except GenshinException as exc: + text = f"自动签到执行失败,API返回信息为 {str(exc)}" + sign_db.status = SignStatusEnum.GENSHIN_EXCEPTION + except ClientConnectorError: + text = "签到失败了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ " + sign_db.status = SignStatusEnum.TIMEOUT_ERROR + except NeedChallenge: + text = "签到失败,触发验证码风控,自动签到自动关闭" + sign_db.status = SignStatusEnum.NEED_CHALLENGE + except Exception as exc: + logger.error("执行自动签到时发生错误 user_id[%s] Message[%s]", user_id, exc.message) + text = "签到失败了呜呜呜 ~ 执行自动签到时发生错误" + else: + sign_db.status = SignStatusEnum.STATUS_SUCCESS + if sign_db.chat_id < 0: + text = f'NOTICE {sign_db.user_id}\n\n{text}' + try: + await context.bot.send_message(sign_db.chat_id, text, parse_mode=ParseMode.HTML) + except BadRequest as exc: + logger.error("执行自动签到时发生错误 user_id[%s] Message[%s]", user_id, exc.message) + sign_db.status = SignStatusEnum.BAD_REQUEST + except Forbidden as exc: + logger.error("执行自动签到时发生错误 user_id[%s] message[%s]", user_id, exc.message) + sign_db.status = SignStatusEnum.FORBIDDEN + except Exception as exc: + logger.error("执行自动签到时发生错误 user_id[%s]", user_id, exc_info=exc) + continue + else: + sign_db.status = SignStatusEnum.STATUS_SUCCESS + sign_db.time_updated = datetime.datetime.now() + await self.sign_service.update(sign_db) diff --git a/poetry.lock b/poetry.lock index ddbecd91..070f1cb7 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. [[package]] name = "aiofiles" @@ -1621,6 +1621,97 @@ files = [ {file = "pathspec-0.11.0.tar.gz", hash = "sha256:64d338d4e0914e91c1792321e6907b5a593f1ab1851de7fc269557a21b30ebbc"}, ] +[[package]] +name = "pillow" +version = "9.4.0" +description = "Python Imaging Library (Fork)" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "Pillow-9.4.0-1-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:1b4b4e9dda4f4e4c4e6896f93e84a8f0bcca3b059de9ddf67dac3c334b1195e1"}, + {file = "Pillow-9.4.0-1-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:fb5c1ad6bad98c57482236a21bf985ab0ef42bd51f7ad4e4538e89a997624e12"}, + {file = "Pillow-9.4.0-1-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:f0caf4a5dcf610d96c3bd32932bfac8aee61c96e60481c2a0ea58da435e25acd"}, + {file = "Pillow-9.4.0-1-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:3f4cc516e0b264c8d4ccd6b6cbc69a07c6d582d8337df79be1e15a5056b258c9"}, + {file = "Pillow-9.4.0-1-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:b8c2f6eb0df979ee99433d8b3f6d193d9590f735cf12274c108bd954e30ca858"}, + {file = "Pillow-9.4.0-1-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b70756ec9417c34e097f987b4d8c510975216ad26ba6e57ccb53bc758f490dab"}, + {file = "Pillow-9.4.0-1-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:43521ce2c4b865d385e78579a082b6ad1166ebed2b1a2293c3be1d68dd7ca3b9"}, + {file = "Pillow-9.4.0-2-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:9d9a62576b68cd90f7075876f4e8444487db5eeea0e4df3ba298ee38a8d067b0"}, + {file = "Pillow-9.4.0-2-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:87708d78a14d56a990fbf4f9cb350b7d89ee8988705e58e39bdf4d82c149210f"}, + {file = "Pillow-9.4.0-2-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:8a2b5874d17e72dfb80d917213abd55d7e1ed2479f38f001f264f7ce7bae757c"}, + {file = "Pillow-9.4.0-2-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:83125753a60cfc8c412de5896d10a0a405e0bd88d0470ad82e0869ddf0cb3848"}, + {file = "Pillow-9.4.0-2-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:9e5f94742033898bfe84c93c831a6f552bb629448d4072dd312306bab3bd96f1"}, + {file = "Pillow-9.4.0-2-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:013016af6b3a12a2f40b704677f8b51f72cb007dac785a9933d5c86a72a7fe33"}, + {file = "Pillow-9.4.0-2-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:99d92d148dd03fd19d16175b6d355cc1b01faf80dae93c6c3eb4163709edc0a9"}, + {file = "Pillow-9.4.0-cp310-cp310-macosx_10_10_x86_64.whl", hash = "sha256:2968c58feca624bb6c8502f9564dd187d0e1389964898f5e9e1fbc8533169157"}, + {file = "Pillow-9.4.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c5c1362c14aee73f50143d74389b2c158707b4abce2cb055b7ad37ce60738d47"}, + {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bd752c5ff1b4a870b7661234694f24b1d2b9076b8bf337321a814c612665f343"}, + {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9a3049a10261d7f2b6514d35bbb7a4dfc3ece4c4de14ef5876c4b7a23a0e566d"}, + {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16a8df99701f9095bea8a6c4b3197da105df6f74e6176c5b410bc2df2fd29a57"}, + {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_28_aarch64.whl", hash = "sha256:94cdff45173b1919350601f82d61365e792895e3c3a3443cf99819e6fbf717a5"}, + {file = "Pillow-9.4.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ed3e4b4e1e6de75fdc16d3259098de7c6571b1a6cc863b1a49e7d3d53e036070"}, + {file = "Pillow-9.4.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:d5b2f8a31bd43e0f18172d8ac82347c8f37ef3e0b414431157718aa234991b28"}, + {file = "Pillow-9.4.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:09b89ddc95c248ee788328528e6a2996e09eaccddeeb82a5356e92645733be35"}, + {file = "Pillow-9.4.0-cp310-cp310-win32.whl", hash = "sha256:f09598b416ba39a8f489c124447b007fe865f786a89dbfa48bb5cf395693132a"}, + {file = "Pillow-9.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:f6e78171be3fb7941f9910ea15b4b14ec27725865a73c15277bc39f5ca4f8391"}, + {file = "Pillow-9.4.0-cp311-cp311-macosx_10_10_x86_64.whl", hash = "sha256:3fa1284762aacca6dc97474ee9c16f83990b8eeb6697f2ba17140d54b453e133"}, + {file = "Pillow-9.4.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:eaef5d2de3c7e9b21f1e762f289d17b726c2239a42b11e25446abf82b26ac132"}, + {file = "Pillow-9.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a4dfdae195335abb4e89cc9762b2edc524f3c6e80d647a9a81bf81e17e3fb6f0"}, + {file = "Pillow-9.4.0-cp311-cp311-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6abfb51a82e919e3933eb137e17c4ae9c0475a25508ea88993bb59faf82f3b35"}, + {file = "Pillow-9.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:451f10ef963918e65b8869e17d67db5e2f4ab40e716ee6ce7129b0cde2876eab"}, + {file = "Pillow-9.4.0-cp311-cp311-manylinux_2_28_aarch64.whl", hash = "sha256:6663977496d616b618b6cfa43ec86e479ee62b942e1da76a2c3daa1c75933ef4"}, + {file = "Pillow-9.4.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:60e7da3a3ad1812c128750fc1bc14a7ceeb8d29f77e0a2356a8fb2aa8925287d"}, + {file = "Pillow-9.4.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:19005a8e58b7c1796bc0167862b1f54a64d3b44ee5d48152b06bb861458bc0f8"}, + {file = "Pillow-9.4.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f715c32e774a60a337b2bb8ad9839b4abf75b267a0f18806f6f4f5f1688c4b5a"}, + {file = "Pillow-9.4.0-cp311-cp311-win32.whl", hash = "sha256:b222090c455d6d1a64e6b7bb5f4035c4dff479e22455c9eaa1bdd4c75b52c80c"}, + {file = "Pillow-9.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:ba6612b6548220ff5e9df85261bddc811a057b0b465a1226b39bfb8550616aee"}, + {file = "Pillow-9.4.0-cp37-cp37m-macosx_10_10_x86_64.whl", hash = "sha256:5f532a2ad4d174eb73494e7397988e22bf427f91acc8e6ebf5bb10597b49c493"}, + {file = "Pillow-9.4.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5dd5a9c3091a0f414a963d427f920368e2b6a4c2f7527fdd82cde8ef0bc7a327"}, + {file = "Pillow-9.4.0-cp37-cp37m-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ef21af928e807f10bf4141cad4746eee692a0dd3ff56cfb25fce076ec3cc8abe"}, + {file = "Pillow-9.4.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:847b114580c5cc9ebaf216dd8c8dbc6b00a3b7ab0131e173d7120e6deade1f57"}, + {file = "Pillow-9.4.0-cp37-cp37m-manylinux_2_28_aarch64.whl", hash = "sha256:653d7fb2df65efefbcbf81ef5fe5e5be931f1ee4332c2893ca638c9b11a409c4"}, + {file = "Pillow-9.4.0-cp37-cp37m-manylinux_2_28_x86_64.whl", hash = "sha256:46f39cab8bbf4a384ba7cb0bc8bae7b7062b6a11cfac1ca4bc144dea90d4a9f5"}, + {file = "Pillow-9.4.0-cp37-cp37m-win32.whl", hash = "sha256:7ac7594397698f77bce84382929747130765f66406dc2cd8b4ab4da68ade4c6e"}, + {file = "Pillow-9.4.0-cp37-cp37m-win_amd64.whl", hash = "sha256:46c259e87199041583658457372a183636ae8cd56dbf3f0755e0f376a7f9d0e6"}, + {file = "Pillow-9.4.0-cp38-cp38-macosx_10_10_x86_64.whl", hash = "sha256:0e51f608da093e5d9038c592b5b575cadc12fd748af1479b5e858045fff955a9"}, + {file = "Pillow-9.4.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:765cb54c0b8724a7c12c55146ae4647e0274a839fb6de7bcba841e04298e1011"}, + {file = "Pillow-9.4.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:519e14e2c49fcf7616d6d2cfc5c70adae95682ae20f0395e9280db85e8d6c4df"}, + {file = "Pillow-9.4.0-cp38-cp38-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d197df5489004db87d90b918033edbeee0bd6df3848a204bca3ff0a903bef837"}, + {file = "Pillow-9.4.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0845adc64fe9886db00f5ab68c4a8cd933ab749a87747555cec1c95acea64b0b"}, + {file = "Pillow-9.4.0-cp38-cp38-manylinux_2_28_aarch64.whl", hash = "sha256:e1339790c083c5a4de48f688b4841f18df839eb3c9584a770cbd818b33e26d5d"}, + {file = "Pillow-9.4.0-cp38-cp38-manylinux_2_28_x86_64.whl", hash = "sha256:a96e6e23f2b79433390273eaf8cc94fec9c6370842e577ab10dabdcc7ea0a66b"}, + {file = "Pillow-9.4.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:7cfc287da09f9d2a7ec146ee4d72d6ea1342e770d975e49a8621bf54eaa8f30f"}, + {file = "Pillow-9.4.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:d7081c084ceb58278dd3cf81f836bc818978c0ccc770cbbb202125ddabec6628"}, + {file = "Pillow-9.4.0-cp38-cp38-win32.whl", hash = "sha256:df41112ccce5d47770a0c13651479fbcd8793f34232a2dd9faeccb75eb5d0d0d"}, + {file = "Pillow-9.4.0-cp38-cp38-win_amd64.whl", hash = "sha256:7a21222644ab69ddd9967cfe6f2bb420b460dae4289c9d40ff9a4896e7c35c9a"}, + {file = "Pillow-9.4.0-cp39-cp39-macosx_10_10_x86_64.whl", hash = "sha256:0f3269304c1a7ce82f1759c12ce731ef9b6e95b6df829dccd9fe42912cc48569"}, + {file = "Pillow-9.4.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:cb362e3b0976dc994857391b776ddaa8c13c28a16f80ac6522c23d5257156bed"}, + {file = "Pillow-9.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2e0f87144fcbbe54297cae708c5e7f9da21a4646523456b00cc956bd4c65815"}, + {file = "Pillow-9.4.0-cp39-cp39-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:28676836c7796805914b76b1837a40f76827ee0d5398f72f7dcc634bae7c6264"}, + {file = "Pillow-9.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0884ba7b515163a1a05440a138adeb722b8a6ae2c2b33aea93ea3118dd3a899e"}, + {file = "Pillow-9.4.0-cp39-cp39-manylinux_2_28_aarch64.whl", hash = "sha256:53dcb50fbdc3fb2c55431a9b30caeb2f7027fcd2aeb501459464f0214200a503"}, + {file = "Pillow-9.4.0-cp39-cp39-manylinux_2_28_x86_64.whl", hash = "sha256:e8c5cf126889a4de385c02a2c3d3aba4b00f70234bfddae82a5eaa3ee6d5e3e6"}, + {file = "Pillow-9.4.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:6c6b1389ed66cdd174d040105123a5a1bc91d0aa7059c7261d20e583b6d8cbd2"}, + {file = "Pillow-9.4.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:0dd4c681b82214b36273c18ca7ee87065a50e013112eea7d78c7a1b89a739153"}, + {file = "Pillow-9.4.0-cp39-cp39-win32.whl", hash = "sha256:6d9dfb9959a3b0039ee06c1a1a90dc23bac3b430842dcb97908ddde05870601c"}, + {file = "Pillow-9.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:54614444887e0d3043557d9dbc697dbb16cfb5a35d672b7a0fcc1ed0cf1c600b"}, + {file = "Pillow-9.4.0-pp38-pypy38_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b9b752ab91e78234941e44abdecc07f1f0d8f51fb62941d32995b8161f68cfe5"}, + {file = "Pillow-9.4.0-pp38-pypy38_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3b56206244dc8711f7e8b7d6cad4663917cd5b2d950799425076681e8766286"}, + {file = "Pillow-9.4.0-pp38-pypy38_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aabdab8ec1e7ca7f1434d042bf8b1e92056245fb179790dc97ed040361f16bfd"}, + {file = "Pillow-9.4.0-pp38-pypy38_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:db74f5562c09953b2c5f8ec4b7dfd3f5421f31811e97d1dbc0a7c93d6e3a24df"}, + {file = "Pillow-9.4.0-pp38-pypy38_pp73-win_amd64.whl", hash = "sha256:e9d7747847c53a16a729b6ee5e737cf170f7a16611c143d95aa60a109a59c336"}, + {file = "Pillow-9.4.0-pp39-pypy39_pp73-macosx_10_10_x86_64.whl", hash = "sha256:b52ff4f4e002f828ea6483faf4c4e8deea8d743cf801b74910243c58acc6eda3"}, + {file = "Pillow-9.4.0-pp39-pypy39_pp73-manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:575d8912dca808edd9acd6f7795199332696d3469665ef26163cd090fa1f8bfa"}, + {file = "Pillow-9.4.0-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c3c4ed2ff6760e98d262e0cc9c9a7f7b8a9f61aa4d47c58835cdaf7b0b8811bb"}, + {file = "Pillow-9.4.0-pp39-pypy39_pp73-manylinux_2_28_x86_64.whl", hash = "sha256:e621b0246192d3b9cb1dc62c78cfa4c6f6d2ddc0ec207d43c0dedecb914f152a"}, + {file = "Pillow-9.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:8f127e7b028900421cad64f51f75c051b628db17fb00e099eb148761eed598c9"}, + {file = "Pillow-9.4.0.tar.gz", hash = "sha256:a1c2d7780448eb93fbcc3789bf3916aa5720d942e37945f4056680317f1cd23e"}, +] + +[package.extras] +docs = ["furo", "olefile", "sphinx (>=2.4)", "sphinx-copybutton", "sphinx-inline-tabs", "sphinx-issues (>=3.0.1)", "sphinx-removed-in", "sphinxext-opengraph"] +tests = ["check-manifest", "coverage", "defusedxml", "markdown2", "olefile", "packaging", "pyroma", "pytest", "pytest-cov", "pytest-timeout"] + [[package]] name = "platformdirs" version = "3.1.0" @@ -1911,7 +2002,7 @@ files = [ ] [package.dependencies] -aiolimiter = {version = ">=1.0.0,<1.1.0", optional = true, markers = "extra == \"ext\""} +aiolimiter = {version = ">=1.0.0,<1.1.0", optional = true, markers = "extra == \"ext\" or extra == \"rate-limiter\""} APScheduler = {version = ">=3.10.0,<3.11.0", optional = true, markers = "extra == \"ext\""} cachetools = {version = ">=5.3.0,<5.4.0", optional = true, markers = "extra == \"ext\""} httpx = {version = ">=0.23.3,<0.24.0", extras = ["http2"]} @@ -2257,7 +2348,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and (platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\")"} +greenlet = {version = "!=0.4.17", markers = "python_version >= \"3\" and platform_machine == \"aarch64\" or python_version >= \"3\" and platform_machine == \"ppc64le\" or python_version >= \"3\" and platform_machine == \"x86_64\" or python_version >= \"3\" and platform_machine == \"amd64\" or python_version >= \"3\" and platform_machine == \"AMD64\" or python_version >= \"3\" and platform_machine == \"win32\" or python_version >= \"3\" and platform_machine == \"WIN32\""} [package.extras] aiomysql = ["aiomysql", "greenlet (!=0.4.17)"] @@ -2889,11 +2980,11 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] -all = ["pytest", "pytest-asyncio", "flaky", "Pyrogram", "TgCrypto"] +all = ["Pyrogram", "TgCrypto", "flaky", "pytest", "pytest-asyncio"] pyro = ["Pyrogram", "TgCrypto"] -test = ["pytest", "pytest-asyncio", "flaky"] +test = ["flaky", "pytest", "pytest-asyncio"] [metadata] lock-version = "2.0" python-versions = "^3.8" -content-hash = "17a60c0935380268c882c607bbadbf7d3e931615cde1324da7c2515492039732" +content-hash = "d1846eb4c7be70ecb7c27e528352cdb6a39fd60a9012793671607a68a4871ad3" diff --git a/pyproject.toml b/pyproject.toml index 3e74a39e..84609dd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ async-lru = "^2.0.2" thefuzz = "^0.19.0" qrcode = "^7.4.2" cryptography = "^39.0.1" +pillow = "^9.4.0" [tool.poetry.extras] pyro = ["Pyrogram", "TgCrypto"] diff --git a/requirements.txt b/requirements.txt index 8fd65d32..68b77df0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,29 +1,109 @@ -httpx~=0.23.3 -ujson~=5.7.0 -git+https://github.com/thesadru/genshin.py -Jinja2~=3.1.2 -python-telegram-bot[ext, rate-limiter]~=20.1 -sqlmodel~=0.0.8 -colorlog~=6.7.0 -playwright ~= 1.27.1 -fakeredis ~= 2.9.0 -beautifulsoup4 ~= 4.11.2 -asyncmy ~= 0.2.7 -pyppeteer ~= 1.0.2 -aiofiles ~= 23.1.0 -python-dotenv ~= 1.0.0 -alembic ~= 1.10.2 -black ~= 23.1.0 -rich ~= 13.3.1 -git+https://github.com/mrwan200/EnkaNetwork.py -lxml ~= 4.9.2 -arko-wrapper ~= 0.2.8 -fastapi ~= 0.93.0 -uvicorn[standard] ~= 0.21.0 -sentry-sdk ~= 1.15.0 -GitPython ~= 3.1.30 -openpyxl ~= 3.1.1 -async-lru ~= 2.0.2 -thefuzz ~= 0.19.0 -qrcode ~= 7.4.2 -cryptography ~= 39.0.1 +aiofiles==23.1.0 ; python_version >= "3.8" and python_version < "4.0" +aiohttp==3.8.4 ; python_version >= "3.8" and python_version < "4.0" +aiolimiter==1.0.0 ; python_version >= "3.8" and python_version < "4.0" +aiosignal==1.3.1 ; python_version >= "3.8" and python_version < "4.0" +alembic==1.10.2 ; python_version >= "3.8" and python_version < "4.0" +anyio==3.6.2 ; python_version >= "3.8" and python_version < "4.0" +appdirs==1.4.4 ; python_version >= "3.8" and python_version < "4.0" +apscheduler==3.10.1 ; python_version >= "3.8" and python_version < "4.0" +arko-wrapper==0.2.8 ; python_version >= "3.8" and python_version < "4.0" +async-lru==2.0.2 ; python_version >= "3.8" and python_version < "4.0" +async-timeout==4.0.2 ; python_version >= "3.8" and python_version < "4.0" +asyncmy==0.2.7 ; python_version >= "3.8" and python_version < "4.0" +attrs==22.2.0 ; python_version >= "3.8" and python_version < "4.0" +backports-zoneinfo==0.2.1 ; python_version >= "3.8" and python_version < "3.9" +beautifulsoup4==4.11.2 ; python_version >= "3.8" and python_version < "4.0" +black==23.1.0 ; python_version >= "3.8" and python_version < "4.0" +cachetools==5.3.0 ; python_version >= "3.8" and python_version < "4.0" +certifi==2022.12.7 ; python_version >= "3.8" and python_version < "4.0" +cffi==1.15.1 ; python_version >= "3.8" and python_version < "4.0" +charset-normalizer==3.1.0 ; python_version >= "3.8" and python_version < "4.0" +click==8.1.3 ; python_version >= "3.8" and python_version < "4.0" +colorama==0.4.6 ; python_version >= "3.8" and python_version < "4.0" and sys_platform == "win32" or python_version >= "3.8" and python_version < "4.0" and platform_system == "Windows" +colorlog==6.7.0 ; python_version >= "3.8" and python_version < "4.0" +cryptography==39.0.2 ; python_version >= "3.8" and python_version < "4.0" +enkanetwork-py @ git+https://github.com/mrwan200/EnkaNetwork.py@master ; python_version >= "3.8" and python_version < "4.0" +et-xmlfile==1.1.0 ; python_version >= "3.8" and python_version < "4.0" +exceptiongroup==1.1.0 ; python_version >= "3.8" and python_version < "3.11" +fakeredis==2.10.0 ; python_version >= "3.8" and python_version < "4.0" +fastapi==0.93.0 ; python_version >= "3.8" and python_version < "4.0" +flaky==3.7.0 ; python_version >= "3.8" and python_version < "4.0" +frozenlist==1.3.3 ; python_version >= "3.8" and python_version < "4.0" +genshin @ git+https://github.com/thesadru/genshin.py@master ; python_version >= "3.8" and python_version < "4.0" +gitdb==4.0.10 ; python_version >= "3.8" and python_version < "4.0" +gitpython==3.1.31 ; python_version >= "3.8" and python_version < "4.0" +greenlet==1.1.3 ; python_version >= "3.8" and python_version < "4.0" +h11==0.14.0 ; python_version >= "3.8" and python_version < "4.0" +h2==4.1.0 ; python_version >= "3.8" and python_version < "4.0" +hpack==4.0.0 ; python_version >= "3.8" and python_version < "4.0" +httpcore==0.16.3 ; python_version >= "3.8" and python_version < "4.0" +httptools==0.5.0 ; python_version >= "3.8" and python_version < "4.0" +httpx==0.23.3 ; python_version >= "3.8" and python_version < "4.0" +httpx[http2]==0.23.3 ; python_version >= "3.8" and python_version < "4.0" +hyperframe==6.0.1 ; python_version >= "3.8" and python_version < "4.0" +idna==3.4 ; python_version >= "3.8" and python_version < "4.0" +importlib-metadata==6.0.0 ; python_version >= "3.8" and python_version < "4.0" +importlib-resources==5.12.0 ; python_version >= "3.8" and python_version < "3.9" +iniconfig==2.0.0 ; python_version >= "3.8" and python_version < "4.0" +jinja2==3.1.2 ; python_version >= "3.8" and python_version < "4.0" +lxml==4.9.2 ; python_version >= "3.8" and python_version < "4.0" +mako==1.2.4 ; python_version >= "3.8" and python_version < "4.0" +markdown-it-py==2.2.0 ; python_version >= "3.8" and python_version < "4.0" +markupsafe==2.1.2 ; python_version >= "3.8" and python_version < "4.0" +mdurl==0.1.2 ; python_version >= "3.8" and python_version < "4.0" +multidict==6.0.4 ; python_version >= "3.8" and python_version < "4.0" +mypy-extensions==1.0.0 ; python_version >= "3.8" and python_version < "4.0" +openpyxl==3.1.1 ; python_version >= "3.8" and python_version < "4.0" +packaging==23.0 ; python_version >= "3.8" and python_version < "4.0" +pathspec==0.11.0 ; python_version >= "3.8" and python_version < "4.0" +pillow==9.4.0 ; python_version >= "3.8" and python_version < "4.0" +platformdirs==3.1.0 ; python_version >= "3.8" and python_version < "4.0" +playwright==1.27.1 ; python_version >= "3.8" and python_version < "4.0" +pluggy==1.0.0 ; python_version >= "3.8" and python_version < "4.0" +pyaes==1.6.1 ; python_version >= "3.8" and python_version < "4.0" +pycparser==2.21 ; python_version >= "3.8" and python_version < "4.0" +pydantic==1.10.5 ; python_version >= "3.8" and python_version < "4.0" +pyee==8.1.0 ; python_version >= "3.8" and python_version < "4.0" +pygments==2.14.0 ; python_version >= "3.8" and python_version < "4.0" +pypng==0.20220715.0 ; python_version >= "3.8" and python_version < "4.0" +pyppeteer==1.0.2 ; python_version >= "3.8" and python_version < "4.0" +pyrogram==2.0.100 ; python_version >= "3.8" and python_version < "4.0" +pysocks==1.7.1 ; python_version >= "3.8" and python_version < "4.0" +pytest-asyncio==0.20.3 ; python_version >= "3.8" and python_version < "4.0" +pytest==7.2.2 ; python_version >= "3.8" and python_version < "4.0" +python-dotenv==1.0.0 ; python_version >= "3.8" and python_version < "4.0" +python-telegram-bot[ext,rate-limiter]==20.1 ; python_version >= "3.8" and python_version < "4.0" +pytz-deprecation-shim==0.1.0.post0 ; python_version >= "3.8" and python_version < "4.0" +pytz==2022.7.1 ; python_version >= "3.8" and python_version < "4.0" +pyyaml==6.0 ; python_version >= "3.8" and python_version < "4.0" +qrcode==7.4.2 ; python_version >= "3.8" and python_version < "4.0" +redis==4.5.1 ; python_version >= "3.8" and python_version < "4.0" +rfc3986[idna2008]==1.5.0 ; python_version >= "3.8" and python_version < "4.0" +rich==13.3.2 ; python_version >= "3.8" and python_version < "4.0" +sentry-sdk==1.16.0 ; python_version >= "3.8" and python_version < "4.0" +setuptools==67.5.1 ; python_version >= "3.8" and python_version < "4.0" +six==1.16.0 ; python_version >= "3.8" and python_version < "4.0" +smmap==5.0.0 ; python_version >= "3.8" and python_version < "4.0" +sniffio==1.3.0 ; python_version >= "3.8" and python_version < "4.0" +sortedcontainers==2.4.0 ; python_version >= "3.8" and python_version < "4.0" +soupsieve==2.4 ; python_version >= "3.8" and python_version < "4.0" +sqlalchemy2-stubs==0.0.2a32 ; python_version >= "3.8" and python_version < "4.0" +sqlalchemy==1.4.41 ; python_version >= "3.8" and python_version < "4.0" +sqlmodel==0.0.8 ; python_version >= "3.8" and python_version < "4.0" +starlette==0.25.0 ; python_version >= "3.8" and python_version < "4.0" +tgcrypto==1.2.5 ; python_version >= "3.8" and python_version < "4.0" +thefuzz==0.19.0 ; python_version >= "3.8" and python_version < "4.0" +tomli==2.0.1 ; python_version >= "3.8" and python_version < "3.11" +tornado==6.2 ; python_version >= "3.8" and python_version < "4.0" +tqdm==4.65.0 ; python_version >= "3.8" and python_version < "4.0" +typing-extensions==4.5.0 ; python_version >= "3.8" and python_version < "4.0" +tzdata==2022.7 ; python_version >= "3.8" and python_version < "4.0" +tzlocal==4.2 ; python_version >= "3.8" and python_version < "4.0" +ujson==5.7.0 ; python_version >= "3.8" and python_version < "4.0" +urllib3==1.26.14 ; python_version >= "3.8" and python_version < "4.0" +uvicorn[standard]==0.21.0 ; python_version >= "3.8" and python_version < "4.0" +uvloop==0.17.0 ; sys_platform != "win32" and sys_platform != "cygwin" and platform_python_implementation != "PyPy" and python_version >= "3.8" and python_version < "4.0" +watchfiles==0.18.1 ; python_version >= "3.8" and python_version < "4.0" +websockets==10.4 ; python_version >= "3.8" and python_version < "4.0" +yarl==1.8.2 ; python_version >= "3.8" and python_version < "4.0" +zipp==3.15.0 ; python_version >= "3.8" and python_version < "4.0" diff --git a/resources/genshin/daily_note/daily_note.html b/resources/genshin/daily_note/daily_note.html index d57fbadd..d67adc68 100644 --- a/resources/genshin/daily_note/daily_note.html +++ b/resources/genshin/daily_note/daily_note.html @@ -1,131 +1,131 @@ - - - - - - -
-
+ + + + + + +
+
- ID:{{ uid }} + ID:{{ uid }}
- {{ day }} + {{ day }}
-
-
+
+
-
- -
-
原粹树脂
-
- {% if resin_recovery_time %} - 将于{{ resin_recovery_time }} 全部恢复 - {% else %} - 树脂已完全恢复 - {% endif %} +
+ +
+
原粹树脂
+
+ {% if resin_recovery_time %} + 将于{{ resin_recovery_time }} 全部恢复 + {% else %} + 树脂已完全恢复 + {% endif %} +
-
{{ current_resin }}/{{ max_resin }}
-
-
+
+
-
- -
-
洞天宝钱
-
- {% if realm_recovery_time %} - 预计{{ realm_recovery_time }}后达到上限 - {% else %} - 存储已满 - {% endif %} +
+ +
+
洞天宝钱
+
+ {% if realm_recovery_time %} + 预计{{ realm_recovery_time }}后达到上限 + {% else %} + 存储已满 + {% endif %} +
-
{{ current_realm_currency }}/{{ max_realm_currency }}
-
-
+
+
-
- -
-
每日委托任务
-
今日委托奖励{% if claimed_commission_reward %}已{% else %}未{% endif %}领取
-
+
+ +
+
每日委托任务
+
今日委托奖励{% if claimed_commission_reward %}已{% else %}未{% endif %}领取
+
- {{ completed_commissions }}/{{ max_commissions }} + {{ completed_commissions }}/{{ max_commissions }}
-
-
+
+
-
- -
-
探索派遣
-
- {% if not expeditions %}尚未进行派遣 - {% elif remained_time %}将于{{ remained_time }}完成 - {% else %}派遣已完成{% endif %} +
+ +
+
探索派遣
+
+ {% if not expeditions %}尚未进行派遣 + {% elif remained_time %}将于{{ remained_time }}完成 + {% else %}派遣已完成{% endif %} +
-
- {{ current_expeditions }}/{{ max_expeditions }} + {{ current_expeditions }}/{{ max_expeditions }}
-
-
+
+
-
- -
-
值得铭记的强敌
-
- {% if remaining_resin_discounts<=0 %}周本已完成 - {% else %}周本树脂减半次数剩余{% endif %} +
+ +
+
值得铭记的强敌
+
+ {% if remaining_resin_discounts<=0 %}周本已完成 + {% else %}周本树脂减半次数剩余{% endif %} +
-
- {{ remaining_resin_discounts }}/{{ max_resin_discounts }} + {{ remaining_resin_discounts }}/{{ max_resin_discounts }}
-
-
+
+
-
- -
-
参量质变仪
-
- {% if transformer %} - {% if transformer_ready %}已准备完成 - {% else %}{{ transformer_recovery_time }}后可使用{% endif %} - {% else %} - 尚未获得 - {% endif %} +
+ +
+
参量质变仪
+
+ {% if transformer %} + {% if transformer_ready %}已准备完成 + {% else %}{{ transformer_recovery_time }}后可使用{% endif %} + {% else %} + 尚未获得 + {% endif %} +
-
{% if transformer %}{% if transformer_ready %}可使用{% else %}冷却中{% endif %}{% else %}尚未获得{% endif %}
-
- - +
+ + diff --git a/run.py b/run.py index c135eb1e..ce733e37 100644 --- a/run.py +++ b/run.py @@ -1,8 +1,39 @@ -from core.bot import bot +import asyncio + +from utils.const import PROJECT_ROOT + +try: + import uvloop + + asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) +except ImportError: + uvloop = None + + +def run(): + from core.application import Application + from dotenv import load_dotenv + + load_dotenv() + Application.build().launch() def main(): - bot.launch() + from core.builtins.reloader import Reloader + from core.config import config + + if config.auto_reload: # 是否启动重载器 + reload_config = config.reload + + Reloader( + run, + reload_delay=reload_config.delay, + reload_dirs=list(set(reload_config.dirs + [PROJECT_ROOT])), + reload_includes=reload_config.include, + reload_excludes=reload_config.exclude, + ).run() + else: + run() if __name__ == "__main__": diff --git a/tests/data/test_artifact.jpg b/tests/data/test_artifact.jpg deleted file mode 100644 index 30cc80d4..00000000 Binary files a/tests/data/test_artifact.jpg and /dev/null differ diff --git a/tests/integration/.env.example b/tests/integration/.env.example new file mode 100644 index 00000000..ded9802f --- /dev/null +++ b/tests/integration/.env.example @@ -0,0 +1,9 @@ +DB_PORT=3306 +DB_USERNAME=root +DB_PASSWORD=123456test +DB_DATABASE=integration_test + +REDIS_HOST=127.0.0.1 +REDIS_PORT=6379 +REDIS_DB=0 +REDIS_PASSWORD="" \ No newline at end of file diff --git a/tests/integration/__init__.py b/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py new file mode 100644 index 00000000..ee5f409b --- /dev/null +++ b/tests/integration/conftest.py @@ -0,0 +1,27 @@ +import asyncio + +import pytest +import pytest_asyncio + +from core.config import config +from core.dependence.mysql import MySQL +from core.dependence.redisdb import RedisDB + + +@pytest_asyncio.fixture(scope="session") +def event_loop(): + policy = asyncio.get_event_loop_policy() + res = policy.new_event_loop() + asyncio.set_event_loop(res) + yield res + res.close() + + +@pytest.fixture(scope="session") +def mysql(): + return MySQL.from_config(config=config) + + +@pytest.fixture(scope="session") +def redis(): + return RedisDB.from_config(config=config) diff --git a/tests/integration/test_mysql.py b/tests/integration/test_mysql.py new file mode 100644 index 00000000..aabadeb9 --- /dev/null +++ b/tests/integration/test_mysql.py @@ -0,0 +1,21 @@ +import logging + +import pytest +from sqlmodel import SQLModel + +from core.services.players.models import PlayersDataBase + +logger = logging.getLogger() +logger.info("%s", PlayersDataBase.__name__) + + +# noinspection PyShadowingNames +@pytest.mark.asyncio +async def test_mysql(mysql): + assert mysql + + +async def test_init_create_all(mysql): + async with mysql.engine.begin() as conn: + await conn.run_sync(SQLModel.metadata.drop_all) + await conn.run_sync(SQLModel.metadata.create_all) diff --git a/tests/integration/test_player_service.py b/tests/integration/test_player_service.py new file mode 100644 index 00000000..931c178e --- /dev/null +++ b/tests/integration/test_player_service.py @@ -0,0 +1,89 @@ +import logging + +import pytest_asyncio + +from core.basemodel import RegionEnum +from core.services.players import PlayersService +from core.services.players.models import PlayersDataBase +from core.services.players.repositories import PlayersRepository + +logger = logging.getLogger("TestPlayersService") + + +@pytest_asyncio.fixture(scope="class", name="players_service") +def service(mysql): + repository = PlayersRepository(mysql) + _players_service = PlayersService(repository) + return _players_service + + +class TestPlayersService: + @staticmethod + async def test_add_player(players_service: "PlayersService"): + data_base = PlayersDataBase( + user_id=1, + account_id=2, + player_id=3, + region=RegionEnum.HYPERION, + is_chosen=True, + ) + await players_service.add(data_base) + + @staticmethod + async def test_get_player_by_user_id(players_service: "PlayersService"): + result = await players_service.get(1) + assert isinstance(result, PlayersDataBase) + result = await players_service.get(1, region=RegionEnum.HYPERION) + assert isinstance(result, PlayersDataBase) + result = await players_service.get(1, region=RegionEnum.HOYOLAB) + assert not isinstance(result, PlayersDataBase) + assert result is None + + @staticmethod + async def test_remove_all_by_user_id(players_service): + await players_service.remove_all_by_user_id(1) + result = await players_service.get(1) + assert not isinstance(result, PlayersDataBase) + assert result is None + + @staticmethod + async def test_1(players_service: "PlayersService"): + """测试 绑定时 账号不存在 账号添加 多账号添加""" + results = await players_service.get_all_by_user_id(10) + assert len(results) == 0 # 账号不存在 + data_base = PlayersDataBase( + user_id=10, + account_id=2, + player_id=3, + region=RegionEnum.HYPERION, + is_chosen=1, + ) + await players_service.add(data_base) # 添加 + result = await players_service.get(10) + assert result.user_id == 10 + data_base = PlayersDataBase( + user_id=10, + account_id=3, + player_id=3, + region=RegionEnum.HYPERION, + is_chosen=True, + ) + results = await players_service.get_all_by_user_id(10) # 添加多账号,新的账号设置为主账号 + assert len(results) == 1 # 账号存在只有一个 + for result in results: + assert result.user_id == 10 + if result.is_chosen == 1: + result.is_chosen = 0 + await players_service.update(result) + await players_service.add(data_base) + results = await players_service.get_all_by_user_id(10) # check all + assert len(results) == 2 + for result in results: + assert result.user_id == 10 + if result.account_id == 3: + assert result.is_chosen == 1 + if result.account_id == 2: + assert result.is_chosen == 0 + await players_service.remove_all_by_user_id(10) + results = await players_service.get_all_by_user_id(10) + assert len(results) == 0 diff --git a/tests/integration/test_redis.py b/tests/integration/test_redis.py new file mode 100644 index 00000000..4d293fd1 --- /dev/null +++ b/tests/integration/test_redis.py @@ -0,0 +1,16 @@ +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from core.dependence.redisdb import RedisDB + + +@pytest.mark.asyncio +async def test_mysql(redis: "RedisDB"): + assert redis + assert redis.client + + +async def test_redis_ping(redis: "RedisDB"): + await redis.ping() diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_abyss_team_data.py b/tests/unit/test_abyss_team_data.py similarity index 100% rename from tests/test_abyss_team_data.py rename to tests/unit/test_abyss_team_data.py diff --git a/tests/test_hyperion.py b/tests/unit/test_hyperion.py similarity index 100% rename from tests/test_hyperion.py rename to tests/unit/test_hyperion.py diff --git a/tests/test_hyperion_bbs.py b/tests/unit/test_hyperion_bbs.py similarity index 100% rename from tests/test_hyperion_bbs.py rename to tests/unit/test_hyperion_bbs.py diff --git a/tests/test_wiki.py b/tests/unit/test_wiki.py similarity index 99% rename from tests/test_wiki.py rename to tests/unit/test_wiki.py index fba1fd94..f7496223 100644 --- a/tests/test_wiki.py +++ b/tests/unit/test_wiki.py @@ -1,6 +1,6 @@ import asyncio import logging -from random import sample, randint +from random import randint, sample from typing import Type import pytest diff --git a/utils/__init__.py b/utils/__init__.py index 22814ec7..e69de29b 100644 --- a/utils/__init__.py +++ b/utils/__init__.py @@ -1 +0,0 @@ -from utils.patch import * diff --git a/utils/baseobject.py b/utils/baseobject.py deleted file mode 100644 index 617b7e5d..00000000 --- a/utils/baseobject.py +++ /dev/null @@ -1,106 +0,0 @@ -import json -from copy import deepcopy -from typing import Dict, Union, Optional, List - -from utils.typedefs import JSONDict - - -class BaseObject: - """ - 大多数数据对象的基础类型 - """ - - def __new__(cls, *args: object, **kwargs: object) -> "BaseObject": - instance = super().__new__(cls) - return instance - - def __str__(self) -> str: - return str(self.to_dict()) - - def __getitem__(self, item: str) -> object: - try: - return getattr(self, item) - except AttributeError as exc: - raise KeyError( - f"Objects of type {self.__class__.__name__} don't have an attribute called " f"`{item}`." - ) from exc - - def __getstate__(self) -> Dict[str, Union[str, object]]: - return self._get_attrs(include_private=True, recursive=False) - - def __setstate__(self, state: dict) -> None: - for key, val in state.items(): - setattr(self, key, val) - - def __deepcopy__(self, memodict: dict = None): - if memodict is None: - memodict = {} - cls = self.__class__ - result = cls.__new__(cls) # 创建新实例 - attrs = self._get_attrs(include_private=True) # 获取其所有属性 - - for k in attrs: # 在DeepCopy对象中设置属性 - setattr(result, k, deepcopy(getattr(self, k), memodict)) - return result - - # 添加插槽可减少内存使用,并允许更快的属性访问 - __slots__ = () - - def _get_attrs( - self, - include_private: bool = False, - recursive: bool = False, - ) -> Dict[str, Union[str, object]]: - data = {} - if not recursive: - try: - # __dict__ 具有来自超类的属性,因此无需在下面的for循环中输入 - data.update(self.__dict__) - except AttributeError: - pass - # 我们希望使用self获取类的所有属性,但如果使用 self.__slots__ ,仅包括该类本身使用的属性,而不是它的超类 - # 因此,我们得到它的MRO,然后再得到它们的属性 - # 使用“[:-1]”切片排除了“object”类 - for cls in self.__class__.__mro__[:-1]: - for key in cls.__slots__: # 忽略 属性已定义 - if not include_private and key.startswith("_"): - continue - - value = getattr(self, key, None) - if value is not None: - if recursive and hasattr(value, "to_dict"): - data[key] = value.to_dict() - else: - data[key] = value - elif not recursive: - data[key] = value - - return data - - @staticmethod - def _parse_data(data: Optional[JSONDict]) -> Optional[JSONDict]: - return None if data is None else data.copy() - - @classmethod - def de_json(cls, data: Optional[JSONDict]): - data = cls._parse_data(data) - - if data is None: - return None - - if cls == BaseObject: - return cls() - return cls(**data) - - @classmethod - def de_list(cls, data: Optional[List[JSONDict]]) -> List: - if not data: - return [] - - return [cls.de_json(d) for d in data] - - def to_json(self) -> str: - return json.dumps(self.to_dict()) - - def to_dict(self) -> JSONDict: - return self._get_attrs(recursive=True) diff --git a/utils/bot.py b/utils/bot.py deleted file mode 100644 index c6f1f048..00000000 --- a/utils/bot.py +++ /dev/null @@ -1,46 +0,0 @@ -import json -from typing import List, cast, Union - -from telegram import Chat -from telegram.ext import CallbackContext - -from core.base.redisdb import RedisDB -from core.bot import bot - -redis_db = bot.services.get(RedisDB) -redis_db = cast(RedisDB, redis_db) - - -async def get_chat(chat_id: Union[str, int], ttl: int = 86400) -> Chat: - if not redis_db: - return await bot.app.bot.get_chat(chat_id) - qname = f"bot:chat:{chat_id}" - data = await redis_db.client.get(qname) - if data: - json_data = json.loads(data) - return Chat.de_json(json_data, bot.app.bot) - chat_info = await bot.app.bot.get_chat(chat_id) - await redis_db.client.set(qname, chat_info.to_json()) - await redis_db.client.expire(qname, ttl) - return chat_info - - -def get_args(context: CallbackContext) -> List[str]: - args = context.args - match = context.match - if args is None: - if match is not None: - groups = match.groups() - command = groups[0] - if command: - temp = [] - command_parts = command.split(" ") - for command_part in command_parts: - if command_part: - temp.append(command_part) - return temp - return [] - else: - if len(args) >= 1: - return args - return [] diff --git a/utils/const.py b/utils/const.py deleted file mode 100644 index 34e945b1..00000000 --- a/utils/const.py +++ /dev/null @@ -1,32 +0,0 @@ -"""常量""" -from pathlib import Path - -from httpx import URL - -__all__ = [ - "PROJECT_ROOT", - "CORE_DIR", - "PLUGIN_DIR", - "RESOURCE_DIR", - "NOT_SET", - "HONEY_HOST", - "ENKA_HOST", - "AMBR_HOST", - "CELESTIA_HOST", -] - -# 项目根目录 -PROJECT_ROOT = Path(__file__).joinpath("../..").resolve() -# Core 目录 -CORE_DIR = PROJECT_ROOT / "core" -# 插件目录 -PLUGIN_DIR = PROJECT_ROOT / "plugins" -# 资源目录 -RESOURCE_DIR = PROJECT_ROOT / "resources" - -NOT_SET = object() - -HONEY_HOST = URL("https://genshin.honeyhunterworld.com/") -ENKA_HOST = URL("https://enka.network/") -AMBR_HOST = URL("https://api.ambr.top/") -CELESTIA_HOST = URL("https://www.projectcelestia.com/") diff --git a/utils/const/__init__.py b/utils/const/__init__.py new file mode 100644 index 00000000..cedd98b1 --- /dev/null +++ b/utils/const/__init__.py @@ -0,0 +1,35 @@ +"""一些常量""" +from functools import WRAPPER_ASSIGNMENTS as _WRAPPER_ASSIGNMENTS +from typing import List + +from core.basemodel import RegionEnum +from utils.const._path import * +from utils.const._single import * +from utils.const._url import * + +NOT_SET = object() +# noinspection PyTypeChecker +WRAPPER_ASSIGNMENTS: List[str] = list(_WRAPPER_ASSIGNMENTS) + [ + "block", + "_catch_targets", + "_handler_datas", + "_conversation_handler_data", + "_error_handler_data", + "_job_data", +] + +USER_AGENT: str = ( + "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/90.0.4430.72 Safari/537.36" +) +REQUEST_HEADERS: dict = {"User-Agent": USER_AGENT} + +REGION_MAP = { + "1": RegionEnum.HYPERION, + "2": RegionEnum.HYPERION, + "5": RegionEnum.HYPERION, + "6": RegionEnum.HOYOLAB, + "7": RegionEnum.HOYOLAB, + "8": RegionEnum.HOYOLAB, + "9": RegionEnum.HOYOLAB, +} diff --git a/utils/const/_path.py b/utils/const/_path.py new file mode 100644 index 00000000..f7fc95a2 --- /dev/null +++ b/utils/const/_path.py @@ -0,0 +1,18 @@ +"""目录常量""" +from pathlib import Path + +__all__ = ["PROJECT_ROOT", "CORE_DIR", "PLUGIN_DIR", "RESOURCE_DIR", "CACHE_DIR"] + +# 项目根目录 +PROJECT_ROOT = Path(__file__).joinpath("../../..").resolve() +# Core 目录 +CORE_DIR = PROJECT_ROOT / "core" +# 插件目录 +PLUGIN_DIR = PROJECT_ROOT / "plugins" +# 资源目录 +RESOURCE_DIR = PROJECT_ROOT / "resources" +# cache 目录 +CACHE_DIR = PROJECT_ROOT / "cache" + +if not CACHE_DIR.exists(): + CACHE_DIR.mkdir(exist_ok=True, parents=True) diff --git a/utils/const/_single.py b/utils/const/_single.py new file mode 100644 index 00000000..be293100 --- /dev/null +++ b/utils/const/_single.py @@ -0,0 +1,5 @@ +from signal import SIGABRT, SIGINT, SIGTERM + +__all__ = ("HANDLED_SIGNALS",) + +HANDLED_SIGNALS = SIGINT, SIGTERM, SIGABRT diff --git a/utils/const/_url.py b/utils/const/_url.py new file mode 100644 index 00000000..ee5c279e --- /dev/null +++ b/utils/const/_url.py @@ -0,0 +1,8 @@ +from httpx import URL + +__all__ = ("HONEY_HOST", "ENKA_HOST", "AMBR_HOST", "CELESTIA_HOST") + +HONEY_HOST = URL("https://genshin.honeyhunterworld.com/") +ENKA_HOST = URL("https://enka.network/") +AMBR_HOST = URL("https://api.ambr.top/") +CELESTIA_HOST = URL("https://www.projectcelestia.com/") diff --git a/utils/decorator.py b/utils/decorator.py new file mode 100644 index 00000000..6b3b9a26 --- /dev/null +++ b/utils/decorator.py @@ -0,0 +1,11 @@ +from contextlib import contextmanager + +__all__ = ["do_nothing"] + + +@contextmanager +def do_nothing(): + try: + yield + finally: + ... diff --git a/utils/decorators/admins.py b/utils/decorators/admins.py deleted file mode 100644 index b3dfe5fa..00000000 --- a/utils/decorators/admins.py +++ /dev/null @@ -1,38 +0,0 @@ -from functools import wraps -from typing import Callable, cast - -from telegram import Update - -from core.admin import BotAdminService -from core.bot import bot -from core.error import ServiceNotFoundError - -bot_admin_service = bot.services.get(BotAdminService) - - -def bot_admins_rights_check(func: Callable) -> Callable: - """BOT ADMIN 权限检查""" - - @wraps(func) - async def decorator(*args, **kwargs): - if len(args) == 3: - # self update context - _, update, _ = args - elif len(args) == 2: - # update context - update, _ = args - else: - return await func(*args, **kwargs) - if bot_admin_service is None: - raise ServiceNotFoundError("BotAdminService") - admin_list = await bot_admin_service.get_admin_list() - update = cast(Update, update) - message = update.effective_message - user = update.effective_user - if user.id in admin_list: - return await func(*args, **kwargs) - else: - await message.reply_text("权限不足") - return None - - return decorator diff --git a/utils/decorators/error.py b/utils/decorators/error.py deleted file mode 100644 index 11019e63..00000000 --- a/utils/decorators/error.py +++ /dev/null @@ -1,176 +0,0 @@ -import json -from functools import wraps -from typing import Callable, cast, Optional - -from aiohttp import ClientConnectorError -from genshin import InvalidCookies, GenshinException, TooManyRequests, DataNotPublic -from httpx import ConnectTimeout -from telegram import Update, ReplyKeyboardRemove, InlineKeyboardButton, InlineKeyboardMarkup, Message -from telegram.error import BadRequest, TimedOut, Forbidden -from telegram.ext import CallbackContext, ConversationHandler, filters -from telegram.helpers import create_deep_linked_url - -from core.baseplugin import add_delete_message_job -from modules.apihelper.error import APIHelperException, ReturnCodeError, APIHelperTimedOut, ResponseException -from utils.error import UrlResourcesNotFoundError -from utils.log import logger - - -async def send_user_notification(update: Update, context: CallbackContext, text: str) -> Optional[Message]: - if not isinstance(update, Update): - logger.warning("错误的消息类型 %s", repr(update)) - return None - if update.inline_query is not None: # 忽略 inline_query - return None - if "重新绑定" in text: - buttons = InlineKeyboardMarkup( - [[InlineKeyboardButton("点我重新绑定", url=create_deep_linked_url(context.bot.username, "set_cookie"))]] - ) - elif "通过验证" in text: - buttons = InlineKeyboardMarkup( - [[InlineKeyboardButton("点我通过验证", url=create_deep_linked_url(context.bot.username, "verify_verification"))]] - ) - else: - buttons = ReplyKeyboardRemove() - user = update.effective_user - message = update.effective_message - chat = update.effective_chat - if message is None: - update_str = update.to_dict() if isinstance(update, Update) else str(update) - logger.warning("错误的消息类型\n %s", json.dumps(update_str, indent=2, ensure_ascii=False)) - return None - if chat.id == user.id: - logger.info("尝试通知用户 %s[%s] 错误信息[%s]", user.full_name, user.id, text) - else: - logger.info("尝试通知用户 %s[%s] 在 %s[%s] 的错误信息[%s]", user.full_name, user.id, chat.title, chat.id, text) - try: - if update.callback_query: - await update.callback_query.answer(text, show_alert=True) - return None - return await message.reply_text(text, reply_markup=buttons, allow_sending_without_reply=True) - except TimedOut: - logger.error("发送 update_id[%s] 错误信息失败 连接超时", update.update_id) - except BadRequest as exc: - logger.error("发送 update_id[%s] 错误信息失败 错误信息为 [%s]", update.update_id, exc.message) - except Forbidden as exc: - logger.error("发送 update_id[%s] 错误信息失败 错误信息为 [%s]", update.update_id, exc.message) - except Exception as exc: - logger.error("发送 update_id[%s] 错误信息失败 错误信息为 [%s]", update.update_id, repr(exc)) - logger.exception(exc) - return None - - -def telegram_warning(update: Update, text: str): - user = update.effective_user - message = update.effective_message - chat = update.effective_chat - msg = f"{text}\n user_id[{user.id}] chat_id[{chat.id}] message_id[{message.id}]" - logger.warning(msg) - - -def error_callable(func: Callable) -> Callable: - """Plugins 错误处理修饰器 - - 非常感谢 @Bibo-Joshi 提出的建议 - """ - - @wraps(func) - async def decorator(*args, **kwargs): - if len(args) == 3: - # self update context - _, update, context = args - elif len(args) == 2: - # update context - update, context = args - else: - return await func(*args, **kwargs) - update = cast(Update, update) - context = cast(CallbackContext, context) - text: str = "" - try: - return await func(*args, **kwargs) - except ClientConnectorError: - logger.error("aiohttp 模块连接服务器 ClientConnectorError") - text = "出错了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ 请稍后再试" - except ConnectTimeout: - logger.error("httpx 模块连接服务器 ConnectTimeout") - text = "出错了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ 请稍后再试" - except TimedOut: - logger.error("python-telegram-bot 模块连接服务器 TimedOut") - text = "出错了呜呜呜 ~ 服务器连接超时 服务器熟啦 ~ 请稍后再试" - except UrlResourcesNotFoundError as exc: - logger.error("URL数据资源未找到") - logger.exception(exc) - text = "出错了呜呜呜 ~ 资源未找到 ~ " - except InvalidCookies as exc: - if exc.retcode in (10001, -100): - text = "出错了呜呜呜 ~ Cookie 无效,请尝试重新绑定" - elif exc.retcode == 10103: - text = "出错了呜呜呜 ~ Cookie 有效,但没有绑定到游戏帐户,请尝试登录通行证,在账号管理里面选择账号游戏信息,将原神设置为默认角色。" - else: - logger.warning("Cookie错误") - logger.exception(exc) - text = f"出错了呜呜呜 ~ Cookie 无效 错误信息为 {exc.original} 请尝试重新绑定" - except TooManyRequests as exc: - logger.warning("查询次数太多(操作频繁) %s", exc) - text = "出错了呜呜呜 ~ 当天查询次数已经超过30次,请次日再进行查询" - except DataNotPublic: - text = "出错了呜呜呜 ~ 查询的用户数据未公开" - except GenshinException as exc: - if exc.retcode == -130: - text = "出错了呜呜呜 ~ 未设置默认角色,请尝试重新绑定" - elif exc.retcode == 1034: - text = "出错了呜呜呜 ~ 服务器检测到该账号可能存在异常,请求被拒绝,请尝试通过验证" - elif exc.retcode == -500001: - text = "出错了呜呜呜 ~ 网络出小差了,请稍后重试~" - elif exc.retcode == -1: - text = "出错了呜呜呜 ~ 系统发生错误,请稍后重试~" - elif exc.retcode == -10001: # 参数异常 应该抛出错误 - raise exc - else: - logger.error("GenshinException") - logger.exception(exc) - text = f"出错了呜呜呜 ~ 获取账号信息发生错误 错误信息为 {exc.original if exc.original else exc.retcode} ~ 请稍后再试" - except ReturnCodeError as exc: - text = f"出错了呜呜呜 ~ API请求错误 错误信息为 {exc.message} ~ 请稍后再试" - except APIHelperTimedOut: - logger.warning("APIHelperException") - text = "出错了呜呜呜 ~ API请求超时 ~ 请稍后再试" - except ResponseException as exc: - logger.error("APIHelperException [%s]%s", exc.code, exc.message) - text = f"出错了呜呜呜 ~ API请求错误 错误信息为 {exc.message if exc.message else exc.code} ~ 请稍后再试" - except APIHelperException as exc: - logger.error("APIHelperException") - logger.exception(exc) - text = "出错了呜呜呜 ~ API请求错误 ~ 请稍后再试" - except BadRequest as exc: - if "Replied message not found" in exc.message: - telegram_warning(update, exc.message) - text = "气死我了!怎么有人喜欢发一个命令就秒删了!" - elif "Message is not modified" in exc.message: - telegram_warning(update, exc.message) - elif "Not enough rights" in exc.message: - telegram_warning(update, exc.message) - text = "出错了呜呜呜 ~ 权限不足,请检查对应权限是否开启" - else: - logger.error("python-telegram-bot 请求错误") - logger.exception(exc) - text = "出错了呜呜呜 ~ telegram-bot-api请求错误 ~ 请稍后再试" - except Forbidden as exc: - logger.error("python-telegram-bot 返回 Forbidden") - logger.exception(exc) - text = "出错了呜呜呜 ~ telegram-bot-api请求错误 ~ 请稍后再试" - if text: - notice_message = await send_user_notification(update, context, text) - message = update.effective_message - if message and not update.callback_query and filters.ChatType.GROUPS.filter(message): - if notice_message: - add_delete_message_job(context, notice_message.chat_id, notice_message.message_id, 60) - add_delete_message_job(context, message.chat_id, message.message_id, 60) - else: - user = update.effective_user - chat = update.effective_chat - logger.error("发送 %s[%s] 在 %s[%s] 的通知出现问题 通知文本不存在", user.full_name, user.id, chat.full_name, chat.id) - return ConversationHandler.END - - return decorator diff --git a/utils/decorators/restricts.py b/utils/decorators/restricts.py deleted file mode 100644 index 8fba4b35..00000000 --- a/utils/decorators/restricts.py +++ /dev/null @@ -1,113 +0,0 @@ -import asyncio -import time -from functools import wraps -from typing import Callable, cast, Optional, Any - -from telegram import Update -from telegram.ext import filters, CallbackContext - -from utils.log import logger - -_lock = asyncio.Lock() - - -def restricts( - restricts_time: int = 9, - restricts_time_of_groups: Optional[int] = None, - return_data: Any = None, - without_overlapping: bool = False, -): - """用于装饰在指定函数预防洪水攻击的装饰器 - - 被修饰的函数生声明必须为 - - async def command_func(update, context) - 或 - async def command_func(self, update, context - - 如果修饰的函数属于 - ConversationHandler - 参数 - return_data - 必须传入 - ConversationHandler.END - - 我真™是服了某些闲着没事干的群友了 - - :param restricts_time: 基础限制时间 - :param restricts_time_of_groups: 对群限制的时间 - :param return_data: 返回的数据对于 ConversationHandler 需要传入 ConversationHandler.END - :param without_overlapping: 两次命令时间不覆盖,在上一条一样的命令返回之前,忽略重复调用 - """ - - def decorator(func: Callable): - @wraps(func) - async def restricts_func(*args, **kwargs): - if len(args) == 3: - # self update context - _, update, context = args - elif len(args) == 2: - # update context - update, context = args - else: - return await func(*args, **kwargs) - update = cast(Update, update) - context = cast(CallbackContext, context) - message = update.effective_message - user = update.effective_user - - _restricts_time = restricts_time - if restricts_time_of_groups is not None and filters.ChatType.GROUPS.filter(message): - _restricts_time = restricts_time_of_groups - - async with _lock: - user_lock = context.user_data.get("lock") - if user_lock is None: - user_lock = context.user_data["lock"] = asyncio.Lock() - - # 如果上一个命令还未完成,忽略后续重复调用 - if without_overlapping and user_lock.locked(): - logger.warning("用户 %s[%s] 触发 overlapping 该次命令已忽略", user.full_name, user.id) - return return_data - - async with user_lock: - command_time = context.user_data.get("command_time", 0) - count = context.user_data.get("usage_count", 0) - restrict_since = context.user_data.get("restrict_since", 0) - - # 洪水防御 - if restrict_since: - if (time.time() - restrict_since) >= 60: - del context.user_data["restrict_since"] - del context.user_data["usage_count"] - else: - return return_data - else: - if count >= 6: - context.user_data["restrict_since"] = time.time() - if update.callback_query: - await update.callback_query.answer("你已经触发洪水防御,请等待60秒", show_alert=True) - else: - await message.reply_text("你已经触发洪水防御,请等待60秒") - logger.warning("用户 %s[%s] 触发洪水限制 已被限制60秒", user.full_name, user.id) - return return_data - # 单次使用限制 - if command_time: - if (time.time() - command_time) <= _restricts_time: - context.user_data["usage_count"] = count + 1 - else: - if count >= 1: - context.user_data["usage_count"] = count - 1 - context.user_data["command_time"] = time.time() - - # 只需要给 without_overlapping 的代码加锁运行 - if without_overlapping: - return await func(*args, **kwargs) - - if count > 1: - await asyncio.sleep(count) - return await func(*args, **kwargs) - - return restricts_func - - return decorator diff --git a/utils/enums.py b/utils/enums.py new file mode 100644 index 00000000..d6fbf8eb --- /dev/null +++ b/utils/enums.py @@ -0,0 +1,13 @@ +from enum import IntEnum + +__all__ = ("Priority",) + + +class Priority(IntEnum): + """优先级""" + + Lowest = 0 + Low = 4 + Normal = 8 + High = 12 + Highest = 16 diff --git a/utils/error.py b/utils/error.py index f29729ab..3c177d8f 100644 --- a/utils/error.py +++ b/utils/error.py @@ -1,4 +1,4 @@ -"""此模块包含BOT的错误的基类""" +"""此模块包含BOT Utils的错误的基类""" class NotFoundError(Exception): diff --git a/utils/helpers.py b/utils/helpers.py index 7c8070f0..948a4bec 100644 --- a/utils/helpers.py +++ b/utils/helpers.py @@ -1,163 +1,45 @@ -from __future__ import annotations - import hashlib import os import re +from abc import ABC from asyncio import create_subprocess_shell -from asyncio.subprocess import PIPE -from inspect import iscoroutinefunction +from functools import lru_cache +from inspect import isabstract as inspect_isabstract, iscoroutinefunction from pathlib import Path -from typing import Awaitable, Callable, Match, Optional, Pattern, Tuple, TypeVar, Union, cast +from typing import Awaitable, Callable, Iterator, Match, Pattern, Type, TypeVar, Union import aiofiles -import genshin import httpx -from genshin import Client, types from httpx import UnsupportedProtocol from typing_extensions import ParamSpec -from core.base.redisdb import RedisDB -from core.bot import bot -from core.config import config -from core.cookies.services import CookiesService, PublicCookiesService -from core.error import ServiceNotFoundError -from core.user.services import UserService -from utils.error import UrlResourcesNotFoundError -from utils.log import logger -from utils.models.base import RegionEnum +from utils.const import REQUEST_HEADERS + +__all__ = ("sha1", "gen_pkg", "async_re_sub", "execute", "isabstract", "download_resource") + T = TypeVar("T") P = ParamSpec("P") -USER_AGENT: str = ( - "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/90.0.4430.72 Safari/537.36" -) -REQUEST_HEADERS: dict = {"User-Agent": USER_AGENT} -current_dir = os.getcwd() -cache_dir = os.path.join(current_dir, "cache") +cache_dir = os.path.join(os.getcwd(), "cache") if not os.path.exists(cache_dir): os.mkdir(cache_dir) -cookies_service = bot.services.get(CookiesService) -cookies_service = cast(CookiesService, cookies_service) -user_service = bot.services.get(UserService) -user_service = cast(UserService, user_service) -public_cookies_service = bot.services.get(PublicCookiesService) -public_cookies_service = cast(PublicCookiesService, public_cookies_service) -redis_db = bot.services.get(RedisDB) -redis_db = cast(RedisDB, redis_db) -genshin_cache: Optional[genshin.RedisCache] = None -if redis_db and config.genshin_ttl: - genshin_cache = genshin.RedisCache(redis_db.client, ttl=config.genshin_ttl) - -REGION_MAP = { - "1": RegionEnum.HYPERION, - "2": RegionEnum.HYPERION, - "5": RegionEnum.HYPERION, - "6": RegionEnum.HOYOLAB, - "7": RegionEnum.HOYOLAB, - "8": RegionEnum.HOYOLAB, - "9": RegionEnum.HOYOLAB, -} - +@lru_cache(64) def sha1(text: str) -> str: - _sha1 = hashlib.sha1() + _sha1 = hashlib.sha1() # nosec B303 _sha1.update(text.encode()) return _sha1.hexdigest() -async def url_to_file(url: str, return_path: bool = False) -> str: - url_sha1 = sha1(url) - url_file_name = os.path.basename(url) - _, extension = os.path.splitext(url_file_name) - temp_file_name = url_sha1 + extension - file_dir = os.path.join(cache_dir, temp_file_name) - if not os.path.exists(file_dir): - async with httpx.AsyncClient(headers=REQUEST_HEADERS) as client: - try: - data = await client.get(url) - except UnsupportedProtocol: - logger.error("连接不支持 url[%s]", url) - return "" - if data.is_error: - logger.error("请求出现错误 url[%s] status_code[%s]", url, data.status_code) - raise UrlResourcesNotFoundError(url) - if data.status_code != 200: - logger.error("url_to_file 获取url[%s] 错误 status_code[%s]", url, data.status_code) - raise UrlResourcesNotFoundError(url) - async with aiofiles.open(file_dir, mode="wb") as f: - await f.write(data.content) - logger.debug("url_to_file 获取url[%s] 并下载到 file_dir[%s]", url, file_dir) - - return file_dir if return_path else Path(file_dir).as_uri() - - -async def get_genshin_client(user_id: int, region: Optional[RegionEnum] = None, need_cookie: bool = True) -> Client: - if user_service is None: - raise ServiceNotFoundError(UserService) - if cookies_service is None: - raise ServiceNotFoundError(CookiesService) - user = await user_service.get_user_by_id(user_id) - if region is None: - region = user.region - cookies = None - if need_cookie: - cookies = await cookies_service.get_cookies(user_id, region) - cookies = cookies.cookies - if region == RegionEnum.HYPERION: - uid = user.yuanshen_uid - client = genshin.Client(cookies=cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE, uid=uid) - elif region == RegionEnum.HOYOLAB: - uid = user.genshin_uid - client = genshin.Client( - cookies=cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn", uid=uid - ) - else: - raise TypeError("region is not RegionEnum.NULL") - if genshin_cache: - client.cache = genshin_cache - return client - - -async def get_public_genshin_client(user_id: int) -> Tuple[Client, Optional[int]]: - if user_service is None: - raise ServiceNotFoundError(UserService) - if public_cookies_service is None: - raise ServiceNotFoundError(PublicCookiesService) - user = await user_service.get_user_by_id(user_id) - region = user.region - cookies = await public_cookies_service.get_cookies(user_id, region) - if region == RegionEnum.HYPERION: - uid = user.yuanshen_uid - client = genshin.Client(cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.CHINESE) - elif region == RegionEnum.HOYOLAB: - uid = user.genshin_uid - client = genshin.Client( - cookies=cookies.cookies, game=types.Game.GENSHIN, region=types.Region.OVERSEAS, lang="zh-cn" - ) - else: - raise TypeError("region is not RegionEnum.NULL") - if genshin_cache: - client.cache = genshin_cache - return client, uid - - -def region_server(uid: Union[int, str]) -> RegionEnum: - if isinstance(uid, (int, str)): - region = REGION_MAP.get(str(uid)[0]) - else: - raise TypeError("UID variable type error") - if region: - return region - else: - raise TypeError(f"UID {uid} isn't associated with any region") - - -async def execute(command, pass_error=True): +async def execute(command: Union[str, bytes], pass_error: bool = True) -> str: """Executes command and returns output, with the option of enabling stderr.""" - executor = await create_subprocess_shell(command, stdout=PIPE, stderr=PIPE, stdin=PIPE) + from asyncio import subprocess + + executor = await create_subprocess_shell( + command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, stdin=subprocess.PIPE + ) stdout, stderr = await executor.communicate() if pass_error: @@ -174,8 +56,8 @@ async def execute(command, pass_error=True): async def async_re_sub( - pattern: str | Pattern, - repl: str | Callable[[Match], str] | Callable[[Match], Awaitable[str]], + pattern: Union[str, Pattern], + repl: Union[str, Callable[[Match], Union[Awaitable[str], str]]], string: str, count: int = 0, flags: int = 0, @@ -218,3 +100,44 @@ async def async_re_sub( result += temp[: match.span(1)[0]] + (replaced or repl) temp = temp[match.span(1)[1] :] return result + temp + + +def gen_pkg(path: Path) -> Iterator[str]: + """遍历 path 生成可以用于 import_module 导入的字符串 + + 注意: 此方法会遍历当前目录下所有的、文件名为以非 '_' 开头的 '.py' 文件,并将他们导入 + """ + from utils.const import PROJECT_ROOT + + for p in path.iterdir(): + if not p.name.startswith("_"): + if p.is_dir(): + yield from gen_pkg(p) + elif p.suffix == ".py": + yield str(p.relative_to(PROJECT_ROOT).with_suffix("")).replace(os.sep, ".") + + +def isabstract(target: Type) -> bool: + return any([inspect_isabstract(target), isinstance(target, type) and ABC in target.__bases__]) + + +async def download_resource(url: str, return_path: bool = False, timeout: float = 20) -> str: + url_sha1 = sha1(url) + url_file_name = os.path.basename(url) + _, extension = os.path.splitext(url_file_name) + temp_file_name = url_sha1 + extension + file_dir = os.path.join(cache_dir, temp_file_name) + if not os.path.exists(file_dir): + async with httpx.AsyncClient(headers=REQUEST_HEADERS, timeout=timeout) as client: + try: + data = await client.get(url) + except UnsupportedProtocol as exc: + raise RuntimeError("Unsupported Protocol") from exc + if data.is_error and data.status_code == 200: + raise RuntimeError("Request Error") + if data.status_code != 200: + raise RuntimeError("Request Error, Status Code", data.status_code) + async with aiofiles.open(file_dir, mode="wb") as f: + await f.write(data.content) + + return file_dir if return_path else Path(file_dir).as_uri() diff --git a/utils/log/__init__.py b/utils/log/__init__.py index 489232c1..f7c119ae 100644 --- a/utils/log/__init__.py +++ b/utils/log/__init__.py @@ -4,15 +4,12 @@ from core.config import config from utils.log._config import LoggerConfig -from utils.log._logger import ( - LogFilter, - Logger, -) +from utils.log._logger import LogFilter, Logger if TYPE_CHECKING: from logging import LogRecord -__all__ = ["logger"] +__all__ = ("logger",) logger = Logger( LoggerConfig( @@ -30,13 +27,17 @@ @lru_cache -def _name_filter(record_name: str) -> bool: +def _whitelist_name_filter(record_name: str) -> bool: + """白名单过滤器""" return any(re.match(rf"^{name}.*?$", record_name) for name in config.logger.filtered_names + [config.logger.name]) def name_filter(record: "LogRecord") -> bool: - """默认的过滤器""" - return _name_filter(record.name) + """默认的过滤器. 白名单 + + 根据当前的 record 的 name 判断是否需要打印。如果应该打印,则返回 True;否则返回 False。 + """ + return _whitelist_name_filter(record.name) log_filter = LogFilter() diff --git a/utils/log/_config.py b/utils/log/_config.py index a887a334..dda61779 100644 --- a/utils/log/_config.py +++ b/utils/log/_config.py @@ -1,16 +1,12 @@ from multiprocessing import RLock as Lock from pathlib import Path -from typing import ( - List, - Optional, - Union, -) +from typing import List, Optional, Union from pydantic import BaseSettings from utils.const import PROJECT_ROOT -__all__ = ["LoggerConfig"] +__all__ = ("LoggerConfig",) class LoggerConfig(BaseSettings): @@ -21,7 +17,7 @@ def __new__(cls, *args, **kwargs) -> "LoggerConfig": with cls._lock: if cls._instance is None: cls.update_forward_refs() - result = super(LoggerConfig, cls).__new__(cls) + result = super(LoggerConfig, cls).__new__(cls) # pylint: disable=E1120 result.__init__(*args, **kwargs) cls._instance = result return cls._instance @@ -30,7 +26,7 @@ def __new__(cls, *args, **kwargs) -> "LoggerConfig": level: Optional[Union[str, int]] = None debug: bool = False - width: int = 180 + width: Optional[int] = None keywords: List[str] = [] time_format: str = "[%Y-%m-%d %X]" capture_warnings: bool = True diff --git a/utils/log/_file.py b/utils/log/_file.py index ee5efb1d..7b614c07 100644 --- a/utils/log/_file.py +++ b/utils/log/_file.py @@ -2,7 +2,7 @@ from datetime import date from pathlib import Path from types import TracebackType -from typing import AnyStr, IO, Iterable, Iterator, List, Optional, Type +from typing import IO, AnyStr, Iterable, Iterator, List, Optional, Type __all__ = ["FileIO"] diff --git a/utils/log/_handler.py b/utils/log/_handler.py index 9656859d..caa8a42f 100644 --- a/utils/log/_handler.py +++ b/utils/log/_handler.py @@ -3,27 +3,12 @@ import sys from datetime import datetime from pathlib import Path -from typing import ( - Any, - Callable, - Iterable, - List, - Literal, - Optional, - TYPE_CHECKING, - Union, -) +from typing import Any, Callable, Iterable, List, Literal, Optional, TYPE_CHECKING, Union from rich.console import Console -from rich.logging import ( - LogRender as DefaultLogRender, - RichHandler as DefaultRichHandler, -) +from rich.logging import LogRender as DefaultLogRender, RichHandler as DefaultRichHandler from rich.table import Table -from rich.text import ( - Text, - TextType, -) +from rich.text import Text, TextType from rich.theme import Theme from utils.log._file import FileIO @@ -31,13 +16,11 @@ from utils.log._traceback import Traceback if TYPE_CHECKING: - from rich.console import ( # pylint: disable=unused-import - ConsoleRenderable, - RenderableType, - ) from logging import LogRecord # pylint: disable=unused-import -__all__ = ["LogRender", "Handler", "FileHandler"] + from rich.console import ConsoleRenderable, RenderableType # pylint: disable=unused-import + +__all__ = ("LogRender", "Handler", "FileHandler") FormatTimeCallable = Callable[[datetime], Text] @@ -169,13 +152,16 @@ def render( time_format = None if self.formatter is None else self.formatter.datefmt log_time = datetime.fromtimestamp(record.created) + if not traceback: + traceback_content = [message_renderable] + elif message_renderable is not None: + traceback_content = [message_renderable, traceback] + else: + traceback_content = [traceback] + log_renderable = self._log_render( self.console, - ( - [message_renderable] - if not traceback - else ([message_renderable, traceback] if message_renderable is not None else [traceback]) - ), + traceback_content, log_time=log_time, time_format=time_format, level=_level, @@ -191,8 +177,10 @@ def render_message( message: Any, ) -> "ConsoleRenderable": use_markup = getattr(record, "markup", self.markup) + tag = getattr(record, "tag", None) if isinstance(message, str): message_text = Text.from_markup(message) if use_markup else Text(message) + message_text = (Text.from_markup(f"[purple][{tag}][/]") + message_text) if tag is not None else message_text highlighter = getattr(record, "highlighter", self.highlighter) else: from rich.highlighter import JSONHighlighter @@ -200,6 +188,7 @@ def render_message( highlighter = JSONHighlighter() message_text = JSON.from_data(message, indent=4).text + message_text = (Text.from_markup(f"[purple][{tag}][/]") + message_text) if tag is not None else message_text if highlighter is not None: # noinspection PyCallingNonCallable diff --git a/utils/log/_logger.py b/utils/log/_logger.py index de74ae99..66a391d8 100644 --- a/utils/log/_logger.py +++ b/utils/log/_logger.py @@ -6,31 +6,19 @@ from multiprocessing import RLock as Lock from pathlib import Path from types import TracebackType -from typing import ( - Any, - Callable, - List, - Mapping, - Optional, - TYPE_CHECKING, - Tuple, - Type, - Union, -) +from typing import TYPE_CHECKING, Any, Callable, List, Mapping, Optional, Tuple, Type, Union from typing_extensions import Self -from utils.log._handler import ( - FileHandler, - Handler, -) +from utils.log._handler import FileHandler, Handler from utils.typedefs import LogFilterType if TYPE_CHECKING: - from utils.log._config import LoggerConfig # pylint: disable=unused-import from logging import LogRecord # pylint: disable=unused-import -__all__ = ["Logger", "LogFilter"] + from utils.log._config import LoggerConfig # pylint: disable=unused-import + +__all__ = ("Logger", "LogFilter") SysExcInfoType = Union[ Tuple[Type[BaseException], BaseException, Optional[TracebackType]], @@ -42,7 +30,7 @@ NONE = object() -class Logger(logging.Logger): +class Logger(logging.Logger): # skipcq: PY-A6006 _instance: Optional["Logger"] = None def __new__(cls, *args, **kwargs) -> "Logger": @@ -68,6 +56,7 @@ def __init__(self, config: "LoggerConfig" = None) -> None: # 控制台 log 配置 Handler( width=self.config.width, + keywords=self.config.keywords, locals_max_length=self.config.traceback_locals_max_length, locals_max_string=self.config.traceback_locals_max_string, locals_max_depth=self.config.traceback_locals_max_depth, @@ -77,6 +66,7 @@ def __init__(self, config: "LoggerConfig" = None) -> None: # debug.log 配置 FileHandler( width=self.config.width, + keywords=self.config.keywords, level=10, path=log_path.joinpath("debug/debug.log"), locals_max_depth=1, @@ -88,6 +78,7 @@ def __init__(self, config: "LoggerConfig" = None) -> None: # error.log 配置 FileHandler( width=self.config.width, + keywords=self.config.keywords, level=40, path=log_path.joinpath("error/error.log"), locals_max_length=self.config.traceback_locals_max_length, @@ -103,7 +94,7 @@ def __init__(self, config: "LoggerConfig" = None) -> None: datefmt=self.config.time_format, handlers=[handler, debug_handler, error_handler], ) - if config.capture_warnings: + if self.config.capture_warnings: logging.captureWarnings(True) warnings_logger = logging.getLogger("py.warnings") warnings_logger.addHandler(handler) @@ -132,7 +123,7 @@ def success( extra=extra, ) - def exception( + def exception( # pylint: disable=W1113 self, msg: Any = NONE, *args: Any, @@ -141,7 +132,7 @@ def exception( stacklevel: int = 1, extra: Optional[Mapping[str, Any]] = None, **kwargs, - ) -> None: # pylint: disable=W1113 + ) -> None: super(Logger, self).exception( "" if msg is NONE else msg, *args, @@ -189,7 +180,7 @@ def addFilter(self, log_filter: LogFilterType) -> None: # pylint: disable=argum handler.addFilter(log_filter) -class LogFilter(logging.Filter): +class LogFilter(logging.Filter): # skipcq: PY-A6006 _filter_list: List[Callable[["LogRecord"], bool]] = [] def __init__(self, name: str = ""): diff --git a/utils/log/_style.py b/utils/log/_style.py index 45475fe3..376e4ef9 100644 --- a/utils/log/_style.py +++ b/utils/log/_style.py @@ -1,22 +1,10 @@ from typing import Dict from pygments.style import Style as PyStyle -from pygments.token import ( - Comment, - Error, - Generic, - Keyword, - Literal, - Name, - Number, - Operator, - Punctuation, - String, - Text, -) +from pygments.token import Comment, Error, Generic, Keyword, Literal, Name, Number, Operator, Punctuation, String, Text from rich.style import Style -__all__ = [ +__all__ = ( "MonokaiProStyle", "DEFAULT_STYLE", "BACKGROUND", @@ -34,7 +22,7 @@ "BLUE", "CYAN", "WHITE", -] +) BACKGROUND = "#272822" FOREGROUND = "#f8f8f2" diff --git a/utils/log/_traceback.py b/utils/log/_traceback.py index 0d2ae2b0..96b33079 100644 --- a/utils/log/_traceback.py +++ b/utils/log/_traceback.py @@ -1,54 +1,26 @@ import os import traceback as traceback_ -from types import ( - ModuleType, - TracebackType, -) -from typing import ( - Any, - Dict, - Iterable, - List, - Mapping, - Optional, - TYPE_CHECKING, - Tuple, - Type, - Union, -) +from types import ModuleType, TracebackType +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Mapping, Optional, Tuple, Type, Union from rich import pretty from rich.columns import Columns -from rich.console import ( - RenderResult, - group, -) +from rich.console import RenderResult, group from rich.highlighter import ReprHighlighter from rich.panel import Panel from rich.pretty import Pretty -from rich.syntax import ( - PygmentsSyntaxTheme, - Syntax, -) +from rich.syntax import PygmentsSyntaxTheme, Syntax from rich.table import Table -from rich.text import ( - Text, - TextType, -) -from rich.traceback import ( - Frame, - PathHighlighter, - Stack, - Trace, - Traceback as BaseTraceback, -) +from rich.text import Text, TextType +from rich.traceback import Frame, PathHighlighter, Stack, Trace, LOCALS_MAX_LENGTH, LOCALS_MAX_STRING +from rich.traceback import Traceback as BaseTraceback from utils.log._style import MonokaiProStyle if TYPE_CHECKING: from rich.console import ConsoleRenderable # pylint: disable=W0611 -__all__ = ["render_scope", "Traceback"] +__all__ = ("render_scope", "Traceback") def render_scope( @@ -124,15 +96,18 @@ def from_exception( exc_type: Type[BaseException], exc_value: BaseException, traceback: Optional[TracebackType], + *, width: Optional[int] = 100, extra_lines: int = 3, theme: Optional[str] = None, word_wrap: bool = False, - show_locals: bool = True, - indent_guides: bool = True, - locals_max_length: int = 10, - locals_max_string: int = 80, + show_locals: bool = False, + locals_max_length: int = LOCALS_MAX_LENGTH, + locals_max_string: int = LOCALS_MAX_STRING, locals_max_depth: Optional[int] = None, + locals_hide_dunder: bool = True, + locals_hide_sunder: bool = False, + indent_guides: bool = True, suppress: Iterable[Union[str, ModuleType]] = (), max_frames: int = 100, ) -> "Traceback": @@ -166,10 +141,13 @@ def extract( exc_type: Type[BaseException], exc_value: BaseException, traceback: Optional[TracebackType], + *, show_locals: bool = False, - locals_max_length: int = 10, - locals_max_string: int = 80, + locals_max_length: int = LOCALS_MAX_LENGTH, + locals_max_string: int = LOCALS_MAX_STRING, locals_max_depth: Optional[int] = None, + locals_hide_dunder: bool = True, + locals_hide_sunder: bool = False, ) -> Trace: # noinspection PyProtectedMember from rich import _IMPORT_CWD diff --git a/utils/models/base.py b/utils/models/base.py deleted file mode 100644 index e1f92733..00000000 --- a/utils/models/base.py +++ /dev/null @@ -1,115 +0,0 @@ -import imghdr -import os -from enum import Enum -from typing import ( - Optional, - Union, -) - -import ujson as json -from pydantic import BaseSettings - -from utils.baseobject import BaseObject - - -class Stat: - def __init__( - self, view_num: int = 0, reply_num: int = 0, like_num: int = 0, bookmark_num: int = 0, forward_num: int = 0 - ): - self.forward_num = forward_num # 关注数 - self.bookmark_num = bookmark_num # 收藏数 - self.like_num = like_num # 喜欢数 - self.reply_num = reply_num # 回复数 - self.view_num = view_num # 观看数 - - -class ArtworkInfo: - def __init__(self): - self.user_id: int = 0 - self.artwork_id: int = 0 # 作品ID - self.site = "" - self.title: str = "" # 标题 - self.origin_url: str = "" - self.site_name: str = "" - self.tags: list = [] - self.stat: Stat = Stat() - self.create_timestamp: int = 0 - self.info = None - - -class ArtworkImage: - def __init__(self, art_id: int, page: int = 0, is_error: bool = False, data: bytes = b""): - self.art_id = art_id - self.data = data - self.is_error = is_error - if not is_error: - self.format: str = imghdr.what(None, self.data) - self.page = page - - -class RegionEnum(Enum): - """注册服务器的列举型别 - - HYPERION名称来源于米忽悠BBS的安卓端包名结尾 - - 查了一下确实有点意思 考虑到大部分重要的功能确实是在移动端实现了 - - 干脆用这个还好听 )""" - - NULL = None - HYPERION = 1 # 米忽悠国服 hyperion - HOYOLAB = 2 # 米忽悠国际服 hoyolab - - -class GameItem(BaseObject): - def __init__( - self, - item_id: int = 0, - name: str = "", - item_type: Union[Enum, str, int] = "", - value: Union[Enum, str, int, bool, float] = 0, - ): - self.item_id = item_id - self.name = name # 名称 - self.type = item_type # 类型 - self.value = value # 数值 - - __slots__ = ("name", "type", "value", "item_id") - - -class ModuleInfo: - def __init__( - self, file_name: Optional[str] = None, plugin_name: Optional[str] = None, relative_path: Optional[str] = None - ): - self.relative_path = relative_path - self.module_name = plugin_name - self.file_name = file_name - if file_name is None: - if relative_path is None: - raise ValueError("file_name 和 relative_path 都不能为空") - self.file_name = os.path.basename(relative_path) - if plugin_name is None: - self.module_name, _ = os.path.splitext(self.file_name) - - @property - def package_path(self) -> str: - if self.relative_path is None: - return "" - if os.path.isdir(self.relative_path): - return self.relative_path.replace("/", ".") - root, _ = os.path.splitext(self.relative_path) - return root.replace("/", ".") - - def __str__(self): - return self.module_name - - -class Settings(BaseSettings): - def __new__(cls, *args, **kwargs): - cls.update_forward_refs() - return super(Settings, cls).__new__(cls) - - class Config(BaseSettings.Config): - case_sensitive = False - json_loads = json.loads - json_dumps = json.dumps diff --git a/utils/models/lock.py b/utils/models/lock.py new file mode 100644 index 00000000..e8e39fb8 --- /dev/null +++ b/utils/models/lock.py @@ -0,0 +1,50 @@ +import asyncio +from asyncio import Task +from multiprocessing import RLock as Lock +from typing import Any, Dict, TYPE_CHECKING + +if TYPE_CHECKING: + from multiprocessing.synchronize import RLock as LockType + +__all__ = ("HashLock",) + +_lock: "LockType" = Lock() +_locks: Dict[int, "LockType"] = {} +_clean_lock_task_map: Dict[int, Task] + + +async def delete_lock(target: int) -> None: + await asyncio.sleep(3) + with _lock: + del _locks[target] + del _clean_lock_task_map[target] # pylint: disable=E0602 + + +class HashLock: + """可以根据 hash 来获取锁的类""" + + target: int + + @property + def lock(self) -> "LockType": + # noinspection PyTypeChecker + with _lock: + if self.target not in _locks: + # noinspection PyTypeChecker + _locks[self.target] = Lock() + else: + _clean_lock_task_map[self.target].cancel() + _clean_lock_task_map.update({self.target: asyncio.create_task(delete_lock(self.target))}) + return _locks[self.target] + + def __init__(self, target: Any) -> None: + if not isinstance(target, int): + target = hash(target) + self.target = target + + def __enter__(self) -> None: + # noinspection PyTypeChecker + return self.lock.__enter__() + + def __exit__(self, exc_type=None, exc_val=None, exc_tb=None): + return self.lock.__exit__(exc_type, exc_val, exc_tb) diff --git a/utils/models/signal.py b/utils/models/signal.py new file mode 100644 index 00000000..d0a86af0 --- /dev/null +++ b/utils/models/signal.py @@ -0,0 +1,52 @@ +from multiprocessing import RLock as Lock +from typing import ClassVar, Generic, Optional, TYPE_CHECKING, Type, TypeVar + +from typing_extensions import Self + +if TYPE_CHECKING: + from multiprocessing.synchronize import RLock as LockType + +__all__ = ["singleton", "Singleton"] + +T = TypeVar("T") + + +class _Singleton(Generic[T]): + lock: ClassVar["LockType"] = Lock() + + __slots__ = "cls", "instance" + + cls: Type[T] + instance: Optional[T] + + def __init__(self, cls: Type[T]): + self.cls = cls + self.instance = None + + def __call__(self, *args, **kwargs) -> T: + with self.lock: + if self.instance is None or args or kwargs: + self.instance = self.cls(*args, **kwargs) + return self.instance + + +def singleton(cls: Optional[Type[T]] = None) -> Type[T]: + """单例装饰器。用于装饰 class , 使之成为单例""" + + def wrap(_cls: Type[T]) -> _Singleton[T]: + return _Singleton(_cls) + + return wrap if cls is None else wrap(cls) + + +class Singleton: + """单例""" + + _lock: ClassVar["LockType"] = Lock() + _instance: ClassVar[Optional[Self]] = None + + def __new__(cls: Type[T], *args, **kwargs) -> T: + with cls._lock: + if cls._instance is None: + cls._instance = object.__new__(cls) + return cls._instance diff --git a/utils/patch/aiohttp.py b/utils/patch/aiohttp.py index 0e5d5663..d64ce0a0 100644 --- a/utils/patch/aiohttp.py +++ b/utils/patch/aiohttp.py @@ -2,11 +2,12 @@ from typing import Optional import aiohttp # pylint: disable=W0406 +from aiohttp import ClientError from utils.patch.methods import patch, patchable -class AioHttpTimeoutException(asyncio.TimeoutError): +class AioHttpTimeoutException(ClientError): pass diff --git a/utils/patch/genshin.py b/utils/patch/genshin.py index 2d317518..301b0b1a 100644 --- a/utils/patch/genshin.py +++ b/utils/patch/genshin.py @@ -37,7 +37,7 @@ async def request_calculator( headers: typing.Optional[aiohttp.typedefs.LooseHeaders] = None, **kwargs: typing.Any, ) -> typing.Mapping[str, typing.Any]: - global UPDATE_CHARACTERS + global UPDATE_CHARACTERS # pylint: disable=W0603 params = dict(params or {}) headers = dict(headers or {}) @@ -65,7 +65,7 @@ async def request_calculator( try: await update_task UPDATE_CHARACTERS = True - except Exception as e: + except Exception as e: # pylint: disable=W0703 warnings.warn(f"Failed to update characters: {e!r}") return data diff --git a/utils/queues.py b/utils/queues.py new file mode 100644 index 00000000..d524da06 --- /dev/null +++ b/utils/queues.py @@ -0,0 +1,447 @@ +"""线程安全的队列""" + +import asyncio +import sys +from asyncio import QueueEmpty as AsyncQueueEmpty +from asyncio import QueueFull as AsyncQueueFull +from collections import deque +from heapq import heappop, heappush +from queue import Empty as SyncQueueEmpty +from queue import Full as SyncQueueFull +from threading import Condition, Lock +from typing import TYPE_CHECKING, Any, Callable, Deque, Generic, List, NoReturn, Optional, Set, TypeVar + +from utils.typedefs import AsyncQueue, SyncQueue + +if TYPE_CHECKING: + from asyncio import AbstractEventLoop as EventLoop + +__all__ = ( + "Queue", + "PriorityQueue", + "LifoQueue", +) + +T = TypeVar("T") +OptFloat = Optional[float] + + +class Queue(Generic[T]): + """线程安全的同步、异步队列""" + + _loop: "EventLoop" + + @property + def loop(self) -> "EventLoop": + """返回该队列的事件循环""" + try: + self._loop = asyncio.get_running_loop() + except RuntimeError as e: + raise RuntimeError("没有正在运行的事件循环, 请在异步函数中使用.") from e + return self._loop + + def __init__(self, maxsize: int = 0): + """初始化队列 + + Args: + maxsize (int): 队列的大小 + Returns: + 无 + """ + self._maxsize = maxsize + + self._init() + + self.unfinished_tasks = 0 + + self.sync_mutex = Lock() + self.sync_not_empty = Condition(self.sync_mutex) + self.sync_not_full = Condition(self.sync_mutex) + self.all_tasks_done = Condition(self.sync_mutex) + + self.async_mutex = asyncio.Lock() + if sys.version_info[:3] == (3, 10, 0): + # 针对 3.10 的 bug + getattr(self.async_mutex, "_get_loop", lambda: None)() + self.async_not_empty = asyncio.Condition(self.async_mutex) + self.async_not_full = asyncio.Condition(self.async_mutex) + self.finished = asyncio.Event() + self.finished.set() + + self.closing = False + self.pending = set() # type: Set[asyncio.Future[Any]] + + def checked_call_soon_threadsafe(callback: Callable[..., None], *args: Any) -> NoReturn: + try: + self.loop.call_soon_threadsafe(callback, *args) + except RuntimeError: + pass + + self._call_soon_threadsafe = checked_call_soon_threadsafe + + def checked_call_soon(callback: Callable[..., None], *args: Any) -> NoReturn: + if not self.loop.is_closed(): + self.loop.call_soon(callback, *args) + + self._call_soon = checked_call_soon + + self._sync_queue = _SyncQueueProxy(self) + self._async_queue = _AsyncQueueProxy(self) + + def close(self) -> NoReturn: + """关闭队列""" + with self.sync_mutex: + self.closing = True + for fut in self.pending: + fut.cancel() + self.finished.set() # 取消堵塞全部的 async_q.join() + self.all_tasks_done.notify_all() # 取消堵塞全部的 sync_q.join() + + async def wait_closed(self) -> NoReturn: + if not self.closing: + raise RuntimeError("队列已被关闭") + await asyncio.sleep(0) + if not self.pending: + return + await asyncio.wait(self.pending) + + @property + def closed(self) -> bool: + return self.closing and not self.pending + + @property + def maxsize(self) -> int: + return self._maxsize + + @property + def sync_q(self) -> "_SyncQueueProxy[T]": + return self._sync_queue + + @property + def async_q(self) -> "_AsyncQueueProxy[T]": + return self._async_queue + + def _init(self) -> NoReturn: + self._queue = deque() # type: Deque[T] + + def qsize(self) -> int: + return len(self._queue) + + def put(self, item: T) -> NoReturn: + self._queue.append(item) + + def get(self) -> T: + return self._queue.popleft() + + def put_internal(self, item: T) -> NoReturn: + self.put(item) + self.unfinished_tasks += 1 + self.finished.clear() + + def notify_sync_not_empty(self) -> NoReturn: + def f() -> NoReturn: + with self.sync_mutex: + self.sync_not_empty.notify() + + self.loop.run_in_executor(None, f) + + def notify_sync_not_full(self) -> NoReturn: + def f() -> NoReturn: + with self.sync_mutex: + self.sync_not_full.notify() + + fut = asyncio.ensure_future(self.loop.run_in_executor(None, f)) + fut.add_done_callback(self.pending.discard) + self.pending.add(fut) + + def notify_async_not_empty(self, *, threadsafe: bool) -> NoReturn: + async def f() -> NoReturn: + async with self.async_mutex: + self.async_not_empty.notify() + + def task_maker() -> NoReturn: + task = self.loop.create_task(f()) + task.add_done_callback(self.pending.discard) + self.pending.add(task) + + if threadsafe: + self._call_soon_threadsafe(task_maker) + else: + self._call_soon(task_maker) + + def notify_async_not_full(self, *, threadsafe: bool) -> NoReturn: + async def f() -> NoReturn: + async with self.async_mutex: + self.async_not_full.notify() + + def task_maker() -> NoReturn: + task = self.loop.create_task(f()) + task.add_done_callback(self.pending.discard) + self.pending.add(task) + + if threadsafe: + self._call_soon_threadsafe(task_maker) + else: + self._call_soon(task_maker) + + def check_closing(self) -> NoReturn: + if self.closing: + raise RuntimeError("禁止对已关闭的队列进行操作") + + +# noinspection PyProtectedMember +class _SyncQueueProxy(SyncQueue[T]): # pylint: disable=W0212 + """同步""" + + def __init__(self, parent: Queue[T]): + self._parent = parent + + @property + def maxsize(self) -> int: + return self._parent.maxsize + + @property + def closed(self) -> bool: + return self._parent.closed + + def task_done(self) -> NoReturn: + self._parent.check_closing() + with self._parent.all_tasks_done: + unfinished = self._parent.unfinished_tasks - 1 + if unfinished <= 0: + if unfinished < 0: + raise ValueError("task_done() 执行次数过多") + self._parent.all_tasks_done.notify_all() + self._parent.loop.call_soon_threadsafe(self._parent.finished.set) + self._parent.unfinished_tasks = unfinished + + def join(self) -> NoReturn: + self._parent.check_closing() + with self._parent.all_tasks_done: + while self._parent.unfinished_tasks: + self._parent.all_tasks_done.wait() + self._parent.check_closing() + + def qsize(self) -> int: + """返回队列的大致大小(不可靠)""" + return self._parent.qsize() + + @property + def unfinished_tasks(self) -> int: + """返回未完成 task 的数量""" + return self._parent.unfinished_tasks + + def empty(self) -> bool: + return not self._parent.qsize() + + def full(self) -> bool: + return 0 < self._parent.maxsize <= self._parent.qsize() + + def put(self, item: T, block: bool = True, timeout: OptFloat = None) -> NoReturn: + self._parent.check_closing() + with self._parent.sync_not_full: + if self._parent.maxsize > 0: + if not block: + if self._parent.qsize() >= self._parent.maxsize: + raise SyncQueueFull + elif timeout is None: + while self._parent.qsize() >= self._parent.maxsize: + self._parent.sync_not_full.wait() + elif timeout < 0: + raise ValueError("'timeout' 必须为一个非负数") + else: + time = self._parent.loop.time + end_time = time() + timeout + while self._parent.qsize() >= self._parent.maxsize: + remaining = end_time - time() + if remaining <= 0.0: + raise SyncQueueFull + self._parent.sync_not_full.wait(remaining) + self._parent.put_internal(item) + self._parent.sync_not_empty.notify() + self._parent.notify_async_not_empty(threadsafe=True) + + def get(self, block: bool = True, timeout: OptFloat = None) -> T: + self._parent.check_closing() + with self._parent.sync_not_empty: + if not block: + if not self._parent.qsize(): + raise SyncQueueEmpty + elif timeout is None: + while not self._parent.qsize(): + self._parent.sync_not_empty.wait() + elif timeout < 0: + raise ValueError("'timeout' 必须为一个非负数") + else: + time = self._parent.loop.time + end_time = time() + timeout + while not self._parent.qsize(): + remaining = end_time - time() + if remaining <= 0.0: + raise SyncQueueEmpty + self._parent.sync_not_empty.wait(remaining) + item = self._parent.get() + self._parent.sync_not_full.notify() + self._parent.notify_async_not_full(threadsafe=True) + return item + + def put_nowait(self, item: T) -> NoReturn: + return self.put(item, block=False) + + def get_nowait(self) -> T: + return self.get(block=False) + + +# noinspection PyProtectedMember +class _AsyncQueueProxy(AsyncQueue[T]): # pylint: disable=W0212 + """异步""" + + def __init__(self, parent: Queue[T]): + self._parent = parent + + @property + def closed(self) -> bool: + return self._parent.closed + + def qsize(self) -> int: + return self._parent.qsize() + + @property + def unfinished_tasks(self) -> int: + return self._parent.unfinished_tasks + + @property + def maxsize(self) -> int: + return self._parent.maxsize + + def empty(self) -> bool: + return self.qsize() == 0 + + def full(self) -> bool: + if self._parent.maxsize <= 0: + return False + return self.qsize() >= self._parent.maxsize + + async def put(self, item: T) -> None: + self._parent.check_closing() + async with self._parent.async_not_full: + self._parent.sync_mutex.acquire() + locked = True + try: + if self._parent.maxsize > 0: + do_wait = True + while do_wait: + do_wait = self._parent.qsize() >= self._parent.maxsize + if do_wait: + locked = False + self._parent.sync_mutex.release() + await self._parent.async_not_full.wait() + self._parent.sync_mutex.acquire() + locked = True + + self._parent.put_internal(item) + self._parent.async_not_empty.notify() + self._parent.notify_sync_not_empty() + finally: + if locked: + self._parent.sync_mutex.release() + + def put_nowait(self, item: T) -> NoReturn: + self._parent.check_closing() + with self._parent.sync_mutex and 0 < self._parent.maxsize <= self._parent.qsize(): + raise AsyncQueueFull + + self._parent.put_internal(item) + self._parent.notify_async_not_empty(threadsafe=False) + self._parent.notify_sync_not_empty() + + async def get(self) -> T: + self._parent.check_closing() + async with self._parent.async_not_empty: + self._parent.sync_mutex.acquire() + locked = True + try: + do_wait = True + while do_wait: + do_wait = self._parent.qsize() == 0 + + if do_wait: + locked = False + self._parent.sync_mutex.release() + await self._parent.async_not_empty.wait() + self._parent.sync_mutex.acquire() + locked = True + + item = self._parent.get() + self._parent.async_not_full.notify() + self._parent.notify_sync_not_full() + return item + finally: + if locked: + self._parent.sync_mutex.release() + + def get_nowait(self) -> T: + self._parent.check_closing() + with self._parent.sync_mutex: + if self._parent.qsize() == 0: + raise AsyncQueueEmpty + + item = self._parent.get() + self._parent.notify_async_not_full(threadsafe=False) + self._parent.notify_sync_not_full() + return item + + def task_done(self) -> NoReturn: + self._parent.check_closing() + with self._parent.all_tasks_done: + if self._parent.unfinished_tasks <= 0: + raise ValueError("task_done() called too many times") + self._parent.unfinished_tasks -= 1 + if self._parent.unfinished_tasks == 0: + self._parent.finished.set() + self._parent.all_tasks_done.notify_all() + + async def join(self) -> None: + while True: + with self._parent.sync_mutex: + self._parent.check_closing() + if self._parent.unfinished_tasks == 0: + break + await self._parent.finished.wait() + + +class PriorityQueue(Queue[T]): + """优先级队列""" + + def _init(self) -> NoReturn: + self._heap_queue: List[T] = [] + + def qsize(self) -> int: + return len(self._heap_queue) + + def put(self, item: T) -> NoReturn: + if not isinstance(item, tuple): + if hasattr(item, "priority"): + item = (int(item.priority), item) + else: + try: + item = (int(item), item) + except (TypeError, ValueError): + pass + heappush(self._heap_queue, item) + + def get(self) -> T: + return heappop(self._heap_queue) + + +class LifoQueue(Queue[T]): + """后进先出队列""" + + def qsize(self) -> int: + return len(self._queue) + + def put(self, item: T) -> NoReturn: + self._queue.append(item) + + def get(self) -> T: + return self._queue.pop() diff --git a/utils/typedefs.py b/utils/typedefs.py deleted file mode 100644 index f8ff02ee..00000000 --- a/utils/typedefs.py +++ /dev/null @@ -1,34 +0,0 @@ -from logging import Filter, LogRecord -from pathlib import Path -from types import TracebackType -from typing import Any, Callable, Dict, Optional, Tuple, Type, Union - -from httpx import URL -from pydantic import ConstrainedInt - -__all__ = [ - "StrOrPath", - "StrOrURL", - "StrOrInt", - "SysExcInfoType", - "ExceptionInfoType", - "JSONDict", - "JSONType", - "LogFilterType", - "NaturalNumber", -] - -StrOrPath = Union[str, Path] -StrOrURL = Union[str, URL] -StrOrInt = Union[str, int] - -SysExcInfoType = Union[Tuple[Type[BaseException], BaseException, Optional[TracebackType]], Tuple[None, None, None]] -ExceptionInfoType = Union[bool, SysExcInfoType, BaseException] -JSONDict = Dict[str, Any] -JSONType = Union[JSONDict, list] - -LogFilterType = Union[Filter, Callable[[LogRecord], int]] - - -class NaturalNumber(ConstrainedInt): - ge = 0 diff --git a/utils/typedefs/__init__.py b/utils/typedefs/__init__.py new file mode 100644 index 00000000..3c595b25 --- /dev/null +++ b/utils/typedefs/__init__.py @@ -0,0 +1,65 @@ +import sys +from typing import Any, Callable, Dict, Optional, TYPE_CHECKING, Tuple, Type, Union + +from utils.typedefs._generics import * +from utils.typedefs._queue import AsyncQueue, BaseQueue, SyncQueue + +if sys.version_info >= (3, 9): + from types import GenericAlias +else: + # noinspection PyUnresolvedReferences,PyProtectedMember + from typing import _GenericAlias as GenericAlias + +__all__ = [ + "GenericAlias", + "StrOrPath", + "StrOrURL", + "StrOrInt", + "SysExcInfoType", + "ExceptionInfoType", + "JSONDict", + "JSONType", + "LogFilterType", + "NaturalNumber", + # queue + "BaseQueue", + "SyncQueue", + "AsyncQueue", + # generics + "P", + "T", + "R", +] + +if TYPE_CHECKING: + from pathlib import Path + from httpx import URL + from logging import Filter, LogRecord + from pydantic import ConstrainedInt + from types import TracebackType + + StrOrPath = Union[str, Path] + StrOrURL = Union[str, URL] + LogFilterType = Union[Filter, Callable[[LogRecord], int]] + + SysExcInfoType = Union[Tuple[Type[BaseException], BaseException, Optional[TracebackType]], Tuple[None, None, None]] + + class NaturalNumber(ConstrainedInt): + """自然数""" + + ge = 0 + +else: + StrOrPath = Union[str, "Path"] + StrOrURL = Union[str, "URL"] + LogFilterType = Union["Filter", Callable[["LogRecord"], int]] + SysExcInfoType = Union[ + Tuple[Type[BaseException], BaseException, Optional["TracebackType"]], Tuple[None, None, None] + ] + NaturalNumber = int + +StrOrInt = Union[str, int] + +ExceptionInfoType = Union[bool, SysExcInfoType, BaseException] +JSONDict = Dict[str, Any] +JSONType = Union[JSONDict, list] diff --git a/utils/typedefs/_generics.py b/utils/typedefs/_generics.py new file mode 100644 index 00000000..537a20ba --- /dev/null +++ b/utils/typedefs/_generics.py @@ -0,0 +1,13 @@ +import sys +from typing import TypeVar + +__all__ = ("T", "R", "P") + +if sys.version_info >= (3, 10): + from typing import ParamSpec +else: + from typing_extensions import ParamSpec + +T = TypeVar("T") # normal type +R = TypeVar("R") # return type +P = ParamSpec("P") # param type diff --git a/utils/typedefs/_queue.py b/utils/typedefs/_queue.py new file mode 100644 index 00000000..a73826a2 --- /dev/null +++ b/utils/typedefs/_queue.py @@ -0,0 +1,92 @@ +# pylint: disable=W0049 +from typing import NoReturn, Optional, Protocol, TypeVar + +__all__ = ["BaseQueue", "SyncQueue", "AsyncQueue"] + +T = TypeVar("T") + + +# noinspection PyPropertyDefinition +class BaseQueue(Protocol[T]): # pylint: disable=W0049 + @property + def maxsize(self) -> int: + raise NotImplementedError + + @property + def closed(self) -> bool: + raise NotImplementedError + + def task_done(self) -> NoReturn: + raise NotImplementedError() + + def qsize(self) -> int: + raise NotImplementedError() + + @property + def unfinished_tasks(self) -> int: + raise NotImplementedError + + def empty(self) -> bool: + raise NotImplementedError() + + def full(self) -> bool: + raise NotImplementedError() + + def put_nowait(self, item: T) -> None: + raise NotImplementedError() + + def get_nowait(self) -> T: + raise NotImplementedError() + + +# noinspection PyPropertyDefinition +class SyncQueue(BaseQueue[T], Protocol[T]): # pylint: disable=W0049 + @property + def maxsize(self) -> int: + raise NotImplementedError + + @property + def closed(self) -> bool: + raise NotImplementedError + + def task_done(self) -> NoReturn: + raise NotImplementedError() + + def qsize(self) -> int: + raise NotImplementedError() + + @property + def unfinished_tasks(self) -> int: + raise NotImplementedError + + def empty(self) -> bool: + raise NotImplementedError() + + def full(self) -> bool: + raise NotImplementedError() + + def put_nowait(self, item: T) -> None: + raise NotImplementedError() + + def get_nowait(self) -> T: + raise NotImplementedError() + + def put(self, item: T, block: bool = True, timeout: Optional[float] = None) -> None: + raise NotImplementedError() + + def get(self, block: bool = True, timeout: Optional[float] = None) -> T: + raise NotImplementedError() + + def join(self) -> None: + raise NotImplementedError() + + +class AsyncQueue(BaseQueue[T], Protocol[T]): # pylint: disable=W0049 + async def put(self, item: T) -> None: + pass + + async def get(self) -> T: + pass + + async def join(self) -> None: + pass