diff --git a/twitchio/__init__.py b/twitchio/__init__.py index ecd9d00e..420dbfe7 100644 --- a/twitchio/__init__.py +++ b/twitchio/__init__.py @@ -35,4 +35,5 @@ from .exceptions import * from .models import * from .payloads import * +from .user import * from .utils import Color as Color, Colour as Colour diff --git a/twitchio/ext/commands/bot.py b/twitchio/ext/commands/bot.py index 846a6982..1a449651 100644 --- a/twitchio/ext/commands/bot.py +++ b/twitchio/ext/commands/bot.py @@ -30,6 +30,7 @@ from twitchio.client import Client from .context import Context +from .converters import _BaseConverter from .core import Command, CommandErrorPayload, Group, Mixin from .exceptions import * @@ -65,6 +66,7 @@ def __init__( self._get_prefix: Prefix_T = prefix self._components: dict[str, Component] = {} + self._base_converter: _BaseConverter = _BaseConverter(self) @property def bot_id(self) -> str: @@ -135,3 +137,5 @@ async def event_command_error(self, payload: CommandErrorPayload) -> None: async def before_invoke(self, ctx: Context) -> None: ... async def after_invoke(self, ctx: Context) -> None: ... + + async def check(self, ctx: Context) -> None: ... diff --git a/twitchio/ext/commands/context.py b/twitchio/ext/commands/context.py index dae1f0dc..6b6b2187 100644 --- a/twitchio/ext/commands/context.py +++ b/twitchio/ext/commands/context.py @@ -59,6 +59,9 @@ def __init__(self, message: ChatMessage, bot: Bot) -> None: self._view: StringView = StringView(self._raw_content) + self._args: list[Any] = [] + self._kwargs: dict[str, Any] = {} + @property def message(self) -> ChatMessage: return self._message @@ -111,6 +114,14 @@ def error_dispatched(self) -> bool: def error_dispatched(self, value: bool, /) -> None: self._error_dispatched = value + @property + def args(self) -> list[Any]: + return self._args + + @property + def kwargs(self) -> dict[str, Any]: + return self._kwargs + def is_valid(self) -> bool: return self._prefix is not None diff --git a/twitchio/ext/commands/converters.py b/twitchio/ext/commands/converters.py new file mode 100644 index 00000000..f0a94316 --- /dev/null +++ b/twitchio/ext/commands/converters.py @@ -0,0 +1,78 @@ +""" +MIT License + +Copyright (c) 2017 - Present PythonistaGuild + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from twitchio.user import User + +from .exceptions import * + + +if TYPE_CHECKING: + from .bot import Bot + from .context import Context + +__all__ = ("_BaseConverter",) + + +class _BaseConverter: + def __init__(self, client: Bot) -> None: + self.__client: Bot = client + + self._MAPPING: dict[Any, Any] = {User: self._user} + self._DEFAULTS: dict[type, type] = {str: str, int: int, float: float} + + async def _user(self, context: Context, arg: str) -> User: + arg = arg.lower() + users: list[User] + msg: str = 'Failed to convert "{}" to User. A User with the ID or login could not be found.' + + if arg.startswith("@"): + arg = arg.removeprefix("@") + users = await self.__client.fetch_users(logins=[arg]) + + if not users: + raise BadArgument(msg.format(arg), value=arg) + + if arg.isdigit(): + users = await self.__client.fetch_users(logins=[arg], ids=[arg]) + else: + users = await self.__client.fetch_users(logins=[arg]) + + potential: list[User] = [] + + for user in users: + # ID's should be taken into consideration first... + if user.id == arg: + return user + + elif user.name == arg: + potential.append(user) + + if potential: + return potential[0] + + raise BadArgument(msg.format(arg), value=arg) diff --git a/twitchio/ext/commands/core.py b/twitchio/ext/commands/core.py index b451438a..ca0fd9e5 100644 --- a/twitchio/ext/commands/core.py +++ b/twitchio/ext/commands/core.py @@ -25,8 +25,12 @@ from __future__ import annotations import asyncio +import inspect from collections.abc import Callable, Coroutine -from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Unpack +from types import UnionType +from typing import TYPE_CHECKING, Any, Concatenate, Generic, ParamSpec, TypeAlias, TypeVar, Union, Unpack + +from twitchio.utils import MISSING from .exceptions import * from .types_ import CommandOptions, Component_T @@ -112,10 +116,134 @@ def extras(self) -> dict[Any, Any]: def has_error(self) -> bool: return self._error is not None + async def _do_conversion(self, context: Context, param: inspect.Parameter, *, annotation: Any, raw: str | None) -> Any: + name: str = param.name + + if isinstance(annotation, UnionType) or getattr(annotation, "__origin__", None) is Union: + converters = list(annotation.__args__) + converters.remove(type(None)) + + result: Any = MISSING + + for c in converters: + try: + result = await self._do_conversion(context, param=param, annotation=c, raw=raw) + except Exception: + continue + + if result is MISSING: + raise BadArgument( + f'Failed to convert argument "{name}" with any converter from Union: {converters}. No default value was provided.', + name=name, + value=raw, + ) + + return result + + base = context.bot._base_converter._DEFAULTS.get(annotation, None if annotation != param.empty else str) + if base: + try: + result = base(raw) + except Exception as e: + raise BadArgument(f'Failed to convert "{name}" to {base}', name=name, value=raw) from e + + return result + + converter = context.bot._base_converter._MAPPING.get(annotation, annotation) + + try: + result = converter(context, raw) + except Exception as e: + raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw) from e + + if not asyncio.iscoroutine(result): + return result + + try: + result = await result + except Exception as e: + raise BadArgument(f'Failed to convert "{name}" to {type(converter)}', name=name, value=raw) from e + + return result + + async def _parse_arguments(self, context: Context) -> ...: + context._view.skip_ws() + signature: inspect.Signature = inspect.signature(self._callback) + + # We expect context always and self with commands in components... + skip: int = 1 if not self._injected else 2 + params: list[inspect.Parameter] = list(signature.parameters.values())[skip:] + + args: list[Any] = [] + kwargs = {} + + for param in params: + if param.kind == param.KEYWORD_ONLY: + raw = context._view.read_rest() + + if not raw: + if param.default == param.empty: + raise MissingRequiredArgument(param=param) + + kwargs[param.name] = param.default + continue + + result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation) + kwargs[param.name] = result + break + + elif param.kind == param.VAR_POSITIONAL: + packed: list[Any] = [] + + while True: + context._view.skip_ws() + raw = context._view.get_quoted_word() + if not raw: + break + + result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation) + packed.append(result) + + args.extend(packed) + break + + elif param.kind == param.POSITIONAL_OR_KEYWORD: + raw = context._view.get_quoted_word() + context._view.skip_ws() + + if not raw: + if param.default == param.empty: + raise MissingRequiredArgument(param=param) + + args.append(param.default) + continue + + result = await self._do_conversion(context, param=param, raw=raw, annotation=param.annotation) + args.append(result) + + return args, kwargs + + async def _do_checks(self, context: Context) -> ...: + # Bot + # Component + # Command + ... + async def _invoke(self, context: Context) -> None: - # TODO: Argument parsing... - # TODO: Checks... Including cooldowns... - callback = self._callback(self._injected, context) if self._injected else self._callback(context) # type: ignore + try: + args, kwargs = await self._parse_arguments(context) + except (ConversionError, MissingRequiredArgument): + raise + except Exception as e: + raise ConversionError("An unknown error occurred converting arguments.") from e + + context._args = args + context._kwargs = kwargs + + args: list[Any] = [context, *args] + args.insert(0, self._injected) if self._injected else None + + callback = self._callback(*args, **kwargs) # type: ignore try: await callback @@ -127,6 +255,9 @@ async def invoke(self, context: Context) -> None: await self._invoke(context) except CommandError as e: await self._dispatch_error(context, e) + except Exception as e: + error = CommandInvokeError(str(e), original=e) + await self._dispatch_error(context, error) async def _dispatch_error(self, context: Context, exception: CommandError) -> None: payload = CommandErrorPayload(context=context, exception=exception) diff --git a/twitchio/ext/commands/exceptions.py b/twitchio/ext/commands/exceptions.py index 35a67f89..f770adaf 100644 --- a/twitchio/ext/commands/exceptions.py +++ b/twitchio/ext/commands/exceptions.py @@ -22,6 +22,8 @@ SOFTWARE. """ +import inspect + from twitchio.exceptions import TwitchioException @@ -35,6 +37,10 @@ "PrefixError", "InputError", "ArgumentError", + "CheckFailure", + "ConversionError", + "BadArgument", + "MissingRequiredArgument", ) @@ -84,3 +90,22 @@ class ExpectedClosingQuoteError(ArgumentError): def __init__(self, close_quote: str) -> None: self.close_quote: str = close_quote super().__init__(f"Expected closing {close_quote}.") + + +class CheckFailure(CommandError): ... + + +class ConversionError(ArgumentError): ... + + +class BadArgument(ConversionError): + def __init__(self, msg: str, *, name: str | None = None, value: str | None) -> None: + self.name: str | None = name + self.value: str | None = value + super().__init__(msg) + + +class MissingRequiredArgument(ArgumentError): + def __init__(self, param: inspect.Parameter) -> None: + self.param: inspect.Parameter = param + super().__init__(f'"{param.name}" is a required argument which is missing.')