Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add a thin veneer of compatibility with mypy #166

Merged
merged 12 commits into from
Jan 2, 2022
16 changes: 16 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -122,3 +122,19 @@ reportUnknownVariableType = "warning" # Lotta false-positives, might f
[tool.pytest.ini_options]
testpaths = ["tests"]
required_plugins = ["pytest-asyncio"]

[tool.mypy]
# some good strict settings
strict = true
warn_unreachable = true

# more narrow type ignores
show_error_codes = true

# these are used by pyright
warn_unused_ignores = false
warn_redundant_casts = false

# compatibility with pyright
allow_redefinition = true
disable_error_code = ["return"]
8 changes: 6 additions & 2 deletions tanjun/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -2468,7 +2468,9 @@ def with_slash_command(self, command: BaseSlashCommandT, /) -> BaseSlashCommandT

@typing.overload
@abc.abstractmethod
def with_slash_command(self, *, copy: bool = False) -> collections.Callable[[BaseSlashCommandT], BaseSlashCommandT]:
def with_slash_command(
self, /, *, copy: bool = False
) -> collections.Callable[[BaseSlashCommandT], BaseSlashCommandT]:
...

@abc.abstractmethod
Expand Down Expand Up @@ -2535,7 +2537,9 @@ def with_message_command(self, command: MessageCommandT, /) -> MessageCommandT:

@typing.overload
@abc.abstractmethod
def with_message_command(self, *, copy: bool = False) -> collections.Callable[[MessageCommandT], MessageCommandT]:
def with_message_command(
self, /, *, copy: bool = False
) -> collections.Callable[[MessageCommandT], MessageCommandT]:
...

@abc.abstractmethod
Expand Down
17 changes: 8 additions & 9 deletions tanjun/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -812,7 +812,7 @@ def message_accepts(self) -> MessageAcceptsEnum:
@property
def is_human_only(self) -> bool:
"""Whether this client is only executing for non-bot/webhook users messages."""
return _check_human in self._checks
return typing.cast("checks.InjectableCheck", _check_human) in self._checks

@property
def cache(self) -> typing.Optional[hikari.api.Cache]:
Expand Down Expand Up @@ -1651,9 +1651,9 @@ async def close(self, *, deregister_listeners: bool = True) -> None:

self._try_unsubscribe(self._events, hikari.InteractionCreateEvent, self.on_interaction_create_event)

for event_type, listeners in self._listeners.items():
for event_type_, listeners in self._listeners.items():
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
for listener in listeners:
self._try_unsubscribe(self._events, event_type, listener.__call__)
self._try_unsubscribe(self._events, event_type_, listener.__call__)

if deregister_listeners and self._server:
self._server.set_listener(hikari.CommandInteraction, None)
Expand Down Expand Up @@ -1708,9 +1708,9 @@ async def open(self, *, register_listeners: bool = True) -> None:

self._events.subscribe(hikari.InteractionCreateEvent, self.on_interaction_create_event)

for event_type, listeners in self._listeners.items():
for event_type_, listeners in self._listeners.items():
for listener in listeners:
self._events.subscribe(event_type, listener.__call__)
self._events.subscribe(event_type_, listener.__call__)

if register_listeners and self._server:
self._server.set_listener(hikari.CommandInteraction, self.on_interaction_create_request)
Expand All @@ -1736,13 +1736,12 @@ async def fetch_rest_application_id(self) -> hikari.Snowflake:
return application.id

if self._rest.token_type == hikari.TokenType.BOT:
application = await self._rest.fetch_application()
self._cached_application_id = hikari.Snowflake(await self._rest.fetch_application())

else:
application = (await self._rest.fetch_authorization()).application
self._cached_application_id = hikari.Snowflake((await self._rest.fetch_authorization()).application)

self._cached_application_id = application.id
return application.id
return self._cached_application_id

def set_hooks(self: _ClientT, hooks: typing.Optional[tanjun_abc.AnyHooks], /) -> _ClientT:
"""Set the general command execution hooks for this client.
Expand Down
32 changes: 14 additions & 18 deletions tanjun/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1113,7 +1113,7 @@ def __init__(
)

self._builder = _CommandBuilder(name, description, sort_options).set_default_permission(default_permission)
self._callback = injecting.CallbackDescriptor(callback)
self._callback = injecting.CallbackDescriptor[None](callback)
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
self._client: typing.Optional[abc.Client] = None
self._tracked_options: dict[str, _TrackedOption] = {}
self._wrapped_command = _wrapped_command
Expand Down Expand Up @@ -1315,8 +1315,8 @@ def add_str_option(
actual_choices = {}
warned = False
for choice in choices:
if isinstance(choice, tuple):
if not warned:
if isinstance(choice, tuple): # type: ignore[unreachable] # the point of this is for deprecation
A5rocks marked this conversation as resolved.
Show resolved Hide resolved
if not warned: # type: ignore[unreachable] # mypy sees `warned = True` and messes up.
warnings.warn(
"Passing a sequence of tuples for 'choices' is deprecated since 2.1.2a1, "
"please pass a mapping instead.",
Expand Down Expand Up @@ -1875,22 +1875,18 @@ async def _process_args(self, ctx: abc.SlashContext, /) -> collections.Mapping[s
keyword_args[option.name] = option.resolve_to_role()

elif option.type is hikari.OptionType.MENTIONABLE:
if option.type is hikari.OptionType.ROLE:
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved
keyword_args[option.name] = option.resolve_to_role()

else:
member: typing.Optional[hikari.InteractionMember] = None
if tracked_option.is_only_member and not (member := option.resolve_to_member()):
raise errors.ConversionError(
f"Couldn't find member for provided user: {option.value}", tracked_option.name
)
member = None
if tracked_option.is_only_member and not (member := option.resolve_to_member()):
raise errors.ConversionError(
f"Couldn't find member for provided user: {option.value}", tracked_option.name
)

keyword_args[option.name] = member or option.resolve_to_mentionable()
keyword_args[option.name] = member or option.resolve_to_mentionable()

else:
value = option.value
# To be type safe we obfuscate the fact that discord's double type will provide am int or float
# depending on the value Disocrd input by always casting to float.
# To be type safe we obfuscate the fact that discord's double type will provide an int or float
# depending on the value Discord inputs by always casting to float.
if tracked_option.type is hikari.OptionType.FLOAT and tracked_option.is_always_float:
value = float(value)

Expand Down Expand Up @@ -2044,7 +2040,7 @@ def __init__(
_wrapped_command: typing.Optional[abc.ExecutableCommand[typing.Any]] = None,
) -> None:
super().__init__(checks=checks, hooks=hooks, metadata=metadata)
self._callback = injecting.CallbackDescriptor(callback)
self._callback = injecting.CallbackDescriptor[None](callback)
self._names = list(dict.fromkeys((name, *names)))
self._parent: typing.Optional[abc.MessageCommandGroup[typing.Any]] = None
self._parser = parser
Expand Down Expand Up @@ -2310,8 +2306,8 @@ def find_command(self, content: str, /) -> collections.Iterable[tuple[str, abc.M
return

for command in self._commands:
if (name := utilities.match_prefix_names(content, command.names)) is not None:
yield name, command
if (name_ := utilities.match_prefix_names(content, command.names)) is not None:
yield name_, command

async def execute(
self,
Expand Down
19 changes: 10 additions & 9 deletions tanjun/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,7 +540,7 @@ def with_command(self, command: CommandT, /) -> CommandT:
...

@typing.overload
def with_command(self, *, copy: bool = False) -> collections.Callable[[CommandT], CommandT]:
def with_command(self, /, *, copy: bool = False) -> collections.Callable[[CommandT], CommandT]:
...

def with_command(
Expand Down Expand Up @@ -593,7 +593,7 @@ def with_slash_command(self, command: tanjun_abc.BaseSlashCommandT, /) -> tanjun

@typing.overload
def with_slash_command(
self, *, copy: bool = False
self, /, *, copy: bool = False
) -> collections.Callable[[tanjun_abc.BaseSlashCommandT], tanjun_abc.BaseSlashCommandT]:
...

Expand Down Expand Up @@ -662,7 +662,7 @@ def with_message_command(self, command: tanjun_abc.MessageCommandT, /) -> tanjun

@typing.overload
def with_message_command(
self, *, copy: bool = False
self, /, *, copy: bool = False
) -> collections.Callable[[tanjun_abc.MessageCommandT], tanjun_abc.MessageCommandT]:
...

Expand Down Expand Up @@ -832,7 +832,7 @@ def bind_client(self: _ComponentT, client: tanjun_abc.Client, /) -> _ComponentT:

return self

def unbind_client(self, client: tanjun_abc.Client, /) -> None:
def unbind_client(self: _ComponentT, client: tanjun_abc.Client, /) -> _ComponentT:
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved
# <<inherited docstring from tanjun.abc.Component>>.
if not self._client or self._client != client:
raise RuntimeError("Component isn't bound to this client")
Expand All @@ -853,6 +853,8 @@ def unbind_client(self, client: tanjun_abc.Client, /) -> None:

self._client = None

return self

async def _check_context(self, ctx: tanjun_abc.Context, /) -> bool:
return await utilities.gather_checks(ctx, self._checks)

Expand Down Expand Up @@ -896,10 +898,8 @@ def check_message_name(
return

for command in self._message_commands:
if (name := utilities.match_prefix_names(content, command.names)) is not None:
yield name, command
# Don't want to match a command multiple times
continue
if (name_ := utilities.match_prefix_names(content, command.names)) is not None:
yield name_, command
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved

def check_slash_name(self, name: str, /) -> collections.Iterator[tanjun_abc.BaseSlashCommand]:
# <<inherited docstring from tanjun.abc.Component>>.
Expand All @@ -923,7 +923,8 @@ async def _execute_interaction(

except errors.CommandError as exc:
await ctx.respond(exc.message)
return asyncio.get_running_loop().create_future().set_result(None)
asyncio.get_running_loop().create_future().set_result(None)
return None
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved

if self._slash_hooks:
if hooks is None:
Expand Down
4 changes: 2 additions & 2 deletions tanjun/conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,8 @@ async def __call__(
dm_cache: _DmCacheT = injecting.inject(type=_DmCacheT),
) -> hikari.PartialChannel:
channel_id = parse_channel_id(argument, message="No valid channel mention or ID found")
if ctx.cache and (channel := ctx.cache.get_guild_channel(channel_id)):
return channel
if ctx.cache and (channel_ := ctx.cache.get_guild_channel(channel_id)):
return channel_

no_guild_channel = False
if cache:
Expand Down
2 changes: 1 addition & 1 deletion tanjun/dependencies/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,5 +94,5 @@ def set_standard_dependencies(client: injecting.InjectorClient, /) -> None:
The injector client to set the standard dependencies on.
"""
client.set_type_dependency(AbstractOwners, Owners()).set_type_dependency(
LazyConstant[hikari.OwnUser], LazyConstant(fetch_my_user)
LazyConstant[hikari.OwnUser], LazyConstant[hikari.OwnUser](fetch_my_user)
)
1 change: 1 addition & 0 deletions tanjun/dependencies/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,7 @@ def make_lc_resolver(type_: type[_T], /) -> collections.Callable[..., collection
"""

async def resolve(
# LazyConstant gets type arguments at runtime
constant: LazyConstant[_T] = injecting.inject(type=LazyConstant[type_]),
ctx: injecting.AbstractInjectionContext = injecting.inject(type=injecting.AbstractInjectionContext),
) -> _T:
Expand Down
18 changes: 9 additions & 9 deletions tanjun/dependencies/limiters.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,14 +215,14 @@ async def _get_ctx_target(ctx: tanjun_abc.Context, type_: BucketResource, /) ->
if ctx.guild_id is None:
return ctx.channel_id

if channel := ctx.get_channel():
return channel.parent_id or ctx.guild_id
if cached_channel := ctx.get_channel():
return cached_channel.parent_id or ctx.guild_id

# TODO: upgrade this to the standard interface
assert isinstance(ctx, injecting.AbstractInjectionContext)
channel_cache = ctx.get_type_dependency(async_cache.SfCache[hikari.GuildChannel])
if channel_cache and (channel := await channel_cache.get(ctx.channel_id, default=None)):
return channel.parent_id or ctx.guild_id
if channel_cache and (channel_ := await channel_cache.get(ctx.channel_id, default=None)):
return channel_.parent_id or ctx.guild_id

channel = await ctx.fetch_channel()
assert isinstance(channel, hikari.TextableGuildChannel)
Expand Down Expand Up @@ -309,7 +309,7 @@ def increment(self: _CooldownT) -> _CooldownT:
def must_wait_for(self) -> typing.Optional[float]:
# A limit of -1 is special cased to mean no limit, so we don't need to wait.
if self.limit == -1:
return
return None

if self.counter >= self.limit and (time_left := self.resets_at - time.monotonic()) > 0:
return time_left
Expand Down Expand Up @@ -639,18 +639,18 @@ def set_bucket(
if limit is less 0 or negative.
"""
if isinstance(reset_after, datetime.timedelta):
reset_after = reset_after.total_seconds()
reset_after_seconds = reset_after.total_seconds()
else:
reset_after = float(reset_after)
reset_after_seconds = float(reset_after)

if reset_after <= 0:
if reset_after_seconds <= 0:
raise ValueError("reset_after must be greater than 0 seconds")

if limit <= 0:
raise ValueError("limit must be greater than 0")

bucket = self._buckets[bucket_id] = _to_bucket(
BucketResource(resource), lambda: _Cooldown(limit=limit, reset_after=reset_after)
BucketResource(resource), lambda: _Cooldown(limit=limit, reset_after=reset_after_seconds)
)
if bucket_id == "default":
self._default_bucket_template = bucket.copy()
Expand Down
4 changes: 2 additions & 2 deletions tanjun/injecting.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,11 @@ def __init__(self, callback: CallbackSig[_T], /) -> None:
self._is_async: typing.Optional[bool] = None
self._descriptors, self._needs_injector = self._parse_descriptors(callback)

# This is delegated to the callback in-order to delegate set/list behaviour for this class to the callback.
# This is delegated to the callback to delegate set/list behaviour for this class to the callback.
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved
def __eq__(self, other: typing.Any) -> bool:
return bool(self._callback == other)

# This is delegated to the callback in-order to delegate set/list behaviour for this class to the callback.
# This is delegated to the callback to delegate set/list behaviour for this class to the callback.
def __hash__(self) -> int:
return hash(self._callback)

Expand Down
8 changes: 6 additions & 2 deletions tanjun/parsing.py
Original file line number Diff line number Diff line change
Expand Up @@ -995,18 +995,22 @@ def set_parameters(

return self

def bind_client(self, client: tanjun_abc.Client, /) -> None:
def bind_client(self: _ShlexParserT, client: tanjun_abc.Client, /) -> _ShlexParserT:
FasterSpeeding marked this conversation as resolved.
Show resolved Hide resolved
# <<inherited docstring from AbstractParser>>.
self._client = client
for parameter in itertools.chain(self._options, self._arguments):
parameter.bind_client(client)

def bind_component(self, component: tanjun_abc.Component, /) -> None:
return self

def bind_component(self: _ShlexParserT, component: tanjun_abc.Component, /) -> _ShlexParserT:
# <<inherited docstring from AbstractParser>>.
self._component = component
for parameter in itertools.chain(self._options, self._arguments):
parameter.bind_component(component)

return self

def parse(
self, ctx: tanjun_abc.MessageContext, /
) -> collections.Coroutine[typing.Any, typing.Any, dict[str, typing.Any]]:
Expand Down
4 changes: 2 additions & 2 deletions tanjun/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,8 +227,8 @@ async def _fetch_channel(
return channel

channel_id = hikari.Snowflake(channel)
if client.cache and (found_channel := client.cache.get_guild_channel(channel_id)):
return found_channel
if client.cache and (found_channel_ := client.cache.get_guild_channel(channel_id)):
return found_channel_

if channel_cache := client.get_type_dependency(_ChannelCacheT):
try:
Expand Down