Skip to content

Commit de8c91d

Browse files
authored
feat: make context have generic client types (#1699)
1 parent 545f9e2 commit de8c91d

File tree

7 files changed

+59
-46
lines changed

7 files changed

+59
-46
lines changed

interactions/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@
3939
smart_cache,
4040
T,
4141
T_co,
42+
ClientT,
4243
utils,
4344
)
4445
from .client import const
@@ -430,6 +431,7 @@
430431
"ChannelType",
431432
"check",
432433
"Client",
434+
"ClientT",
433435
"ClientUser",
434436
"Color",
435437
"COLOR_TYPES",

interactions/client/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
Absent,
3535
T,
3636
T_co,
37+
ClientT,
3738
)
3839
from .client import Client
3940
from .auto_shard_client import AutoShardedClient
@@ -77,6 +78,7 @@
7778
"Absent",
7879
"T",
7980
"T_co",
81+
"ClientT",
8082
"Client",
8183
"AutoShardedClient",
8284
"smart_cache",

interactions/client/client.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import traceback
1313
from collections.abc import Iterable
1414
from datetime import datetime
15+
from typing_extensions import Self
1516
from typing import (
1617
TYPE_CHECKING,
1718
Any,
@@ -367,17 +368,17 @@ def __init__(
367368
"""The HTTP client to use when interacting with discord endpoints"""
368369

369370
# context factories
370-
self.interaction_context: Type[BaseContext] = interaction_context
371+
self.interaction_context: Type[BaseContext[Self]] = interaction_context
371372
"""The object to instantiate for Interaction Context"""
372-
self.component_context: Type[BaseContext] = component_context
373+
self.component_context: Type[BaseContext[Self]] = component_context
373374
"""The object to instantiate for Component Context"""
374-
self.autocomplete_context: Type[BaseContext] = autocomplete_context
375+
self.autocomplete_context: Type[BaseContext[Self]] = autocomplete_context
375376
"""The object to instantiate for Autocomplete Context"""
376-
self.modal_context: Type[BaseContext] = modal_context
377+
self.modal_context: Type[BaseContext[Self]] = modal_context
377378
"""The object to instantiate for Modal Context"""
378-
self.slash_context: Type[BaseContext] = slash_context
379+
self.slash_context: Type[BaseContext[Self]] = slash_context
379380
"""The object to instantiate for Slash Context"""
380-
self.context_menu_context: Type[BaseContext] = context_menu_context
381+
self.context_menu_context: Type[BaseContext[Self]] = context_menu_context
381382
"""The object to instantiate for Context Menu Context"""
382383

383384
self.token: str | None = token
@@ -1826,7 +1827,7 @@ def update_command_cache(self, scope: "Snowflake_Type", command_name: str, comma
18261827
command.cmd_id[scope] = command_id
18271828
self._interaction_lookup[command.resolved_name] = command
18281829

1829-
async def get_context(self, data: dict) -> InteractionContext:
1830+
async def get_context(self, data: dict) -> InteractionContext[Self]:
18301831
match data["type"]:
18311832
case InteractionType.MESSAGE_COMPONENT:
18321833
cls = self.component_context.from_dict(self, data)

interactions/client/const.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@
4343
import sys
4444
from collections import defaultdict
4545
from importlib.metadata import version as _v, PackageNotFoundError
46-
from typing import TypeVar, Union, Callable, Coroutine, ClassVar
46+
import typing_extensions
47+
from typing import TypeVar, Union, Callable, Coroutine, ClassVar, TYPE_CHECKING
4748

4849
__all__ = (
4950
"__version__",
@@ -79,6 +80,7 @@
7980
"Absent",
8081
"T",
8182
"T_co",
83+
"ClientT",
8284
"LIB_PATH",
8385
"RECOVERABLE_WEBSOCKET_CLOSE_CODES",
8486
"NON_RESUMABLE_WEBSOCKET_CLOSE_CODES",
@@ -239,6 +241,13 @@ def has_client_feature(feature: str) -> bool:
239241
Absent = Union[T, Missing]
240242
AsyncCallable = Callable[..., Coroutine]
241243

244+
if TYPE_CHECKING:
245+
from interactions import Client
246+
247+
ClientT = typing_extensions.TypeVar("ClientT", bound=Client, default=Client)
248+
else:
249+
ClientT = TypeVar("ClientT")
250+
242251
LIB_PATH = os.sep.join(__file__.split(os.sep)[:-2])
243252
"""The path to the library folder."""
244253

interactions/ext/hybrid_commands/context.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
Permissions,
1010
Message,
1111
SlashContext,
12-
Client,
1312
Typing,
1413
Embed,
1514
BaseComponent,
@@ -25,6 +24,7 @@
2524
TYPE_MESSAGEABLE_CHANNEL,
2625
Poll,
2726
)
27+
from interactions.client.const import ClientT
2828
from interactions.models.discord.enums import ContextType
2929
from interactions.client.mixins.send import SendMixin
3030
from interactions.client.errors import HTTPException
@@ -38,7 +38,7 @@
3838

3939

4040
class DeferTyping:
41-
def __init__(self, ctx: "HybridContext", ephermal: bool) -> None:
41+
def __init__(self, ctx: "SlashContext[ClientT]", ephermal: bool) -> None:
4242
self.ctx = ctx
4343
self.ephermal = ephermal
4444

@@ -49,7 +49,7 @@ async def __aexit__(self, *_) -> None:
4949
pass
5050

5151

52-
class HybridContext(BaseContext, SendMixin):
52+
class HybridContext(BaseContext[ClientT], SendMixin):
5353
prefix: str
5454
"The prefix used to invoke this command."
5555

@@ -77,10 +77,10 @@ class HybridContext(BaseContext, SendMixin):
7777

7878
__attachment_index__: int
7979

80-
_slash_ctx: SlashContext | None
81-
_prefixed_ctx: prefixed.PrefixedContext | None
80+
_slash_ctx: SlashContext[ClientT] | None
81+
_prefixed_ctx: prefixed.PrefixedContext[ClientT] | None
8282

83-
def __init__(self, client: Client):
83+
def __init__(self, client: ClientT):
8484
super().__init__(client)
8585
self.prefix = ""
8686
self.app_permissions = Permissions(0)
@@ -97,12 +97,12 @@ def __init__(self, client: Client):
9797
self._prefixed_ctx = None
9898

9999
@classmethod
100-
def from_dict(cls, client: Client, payload: dict) -> None:
100+
def from_dict(cls, client: ClientT, payload: dict) -> None:
101101
# this doesn't mean anything, so just implement it to make abc happy
102102
raise NotImplementedError
103103

104104
@classmethod
105-
def from_slash_context(cls, ctx: SlashContext) -> Self:
105+
def from_slash_context(cls, ctx: SlashContext[ClientT]) -> Self:
106106
self = cls(ctx.client)
107107
self.guild_id = ctx.guild_id
108108
self.channel_id = ctx.channel_id
@@ -121,7 +121,7 @@ def from_slash_context(cls, ctx: SlashContext) -> Self:
121121
return self
122122

123123
@classmethod
124-
def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self:
124+
def from_prefixed_context(cls, ctx: prefixed.PrefixedContext[ClientT]) -> Self:
125125
# this is a "best guess" on what the permissions are
126126
# this may or may not be totally accurate
127127
if hasattr(ctx.channel, "permissions_for"):
@@ -163,7 +163,7 @@ def from_prefixed_context(cls, ctx: prefixed.PrefixedContext) -> Self:
163163
return self
164164

165165
@property
166-
def inner_context(self) -> SlashContext | prefixed.PrefixedContext:
166+
def inner_context(self) -> SlashContext[ClientT] | prefixed.PrefixedContext[ClientT]:
167167
"""The inner context that this hybrid context is wrapping."""
168168
return self._slash_ctx or self._prefixed_ctx # type: ignore
169169

interactions/ext/prefixed_commands/context.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from typing_extensions import Self
44

5-
from interactions.client.client import Client
5+
from interactions.client.const import ClientT
66
from interactions.client.mixins.send import SendMixin
77
from interactions.models.discord.channel import TYPE_MESSAGEABLE_CHANNEL
88
from interactions.models.discord.embed import Embed
@@ -17,7 +17,7 @@
1717
__all__ = ("PrefixedContext",)
1818

1919

20-
class PrefixedContext(BaseContext, SendMixin):
20+
class PrefixedContext(BaseContext[ClientT], SendMixin):
2121
_message: Message
2222

2323
prefix: str
@@ -33,12 +33,12 @@ class PrefixedContext(BaseContext, SendMixin):
3333
"This is always empty, and is only here for compatibility with other types of commands."
3434

3535
@classmethod
36-
def from_dict(cls, client: "Client", payload: dict) -> Self:
36+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
3737
# this doesn't mean anything, so just implement it to make abc happy
3838
raise NotImplementedError
3939

4040
@classmethod
41-
def from_message(cls, client: "Client", message: Message) -> Self:
41+
def from_message(cls, client: "ClientT", message: Message) -> Self:
4242
instance = cls(client=client)
4343

4444
# hack to work around BaseContext property

interactions/models/internal/context.py

Lines changed: 23 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from aiohttp import FormData
1010

1111
from interactions.client import const
12-
from interactions.client.const import get_logger, MISSING
12+
from interactions.client.const import get_logger, MISSING, ClientT
1313
from interactions.models.discord.components import BaseComponent
1414
from interactions.models.discord.file import UPLOADABLE_TYPE
1515
from interactions.models.discord.poll import Poll
@@ -149,17 +149,14 @@ def from_dict(cls, client: "interactions.Client", data: dict, guild_id: None | S
149149
return instance
150150

151151

152-
class BaseContext(metaclass=abc.ABCMeta):
152+
class BaseContext(typing.Generic[ClientT], metaclass=abc.ABCMeta):
153153
"""
154154
Base context class for all contexts.
155155
156156
Define your own context class by inheriting from this class. For compatibility with the library, you must define a `from_dict` classmethod that takes a dict and returns an instance of your context class.
157157
158158
"""
159159

160-
client: "interactions.Client"
161-
"""The client that created this context."""
162-
163160
command: BaseCommand
164161
"""The command this context invokes."""
165162

@@ -173,8 +170,10 @@ class BaseContext(metaclass=abc.ABCMeta):
173170
guild_id: typing.Optional[Snowflake]
174171
"""The id of the guild this context was invoked in, if any."""
175172

176-
def __init__(self, client: "interactions.Client") -> None:
177-
self.client = client
173+
def __init__(self, client: ClientT) -> None:
174+
self.client: ClientT = client
175+
"""The client that created this context."""
176+
178177
self.author_id = MISSING
179178
self.channel_id = MISSING
180179
self.message_id = MISSING
@@ -218,12 +217,12 @@ def voice_state(self) -> typing.Optional["interactions.VoiceState"]:
218217
return self.client.cache.get_bot_voice_state(self.guild_id)
219218

220219
@property
221-
def bot(self) -> "interactions.Client":
220+
def bot(self) -> "ClientT":
222221
return self.client
223222

224223
@classmethod
225224
@abc.abstractmethod
226-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
225+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
227226
"""
228227
Create a context instance from a dict.
229228
@@ -238,7 +237,7 @@ def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
238237
raise NotImplementedError
239238

240239

241-
class BaseInteractionContext(BaseContext):
240+
class BaseInteractionContext(BaseContext[ClientT]):
242241
token: str
243242
"""The interaction token."""
244243
id: Snowflake
@@ -281,14 +280,14 @@ class BaseInteractionContext(BaseContext):
281280
kwargs: dict[str, typing.Any]
282281
"""The keyword arguments passed to the interaction."""
283282

284-
def __init__(self, client: "interactions.Client") -> None:
283+
def __init__(self, client: "ClientT") -> None:
285284
super().__init__(client)
286285
self.deferred = False
287286
self.responded = False
288287
self.ephemeral = False
289288

290289
@classmethod
291-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
290+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
292291
instance = cls(client=client)
293292
instance.token = payload["token"]
294293
instance.id = Snowflake(payload["id"])
@@ -418,7 +417,7 @@ def gather_options(_options: list[dict[str, typing.Any]]) -> dict[str, typing.An
418417
self.args = list(self.kwargs.values())
419418

420419

421-
class InteractionContext(BaseInteractionContext, SendMixin):
420+
class InteractionContext(BaseInteractionContext[ClientT], SendMixin):
422421
async def defer(self, *, ephemeral: bool = False, suppress_error: bool = False) -> None:
423422
"""
424423
Defer the interaction.
@@ -657,26 +656,26 @@ async def edit(
657656
return self.client.cache.place_message_data(message_data)
658657

659658

660-
class SlashContext(InteractionContext, ModalMixin):
659+
class SlashContext(InteractionContext[ClientT], ModalMixin):
661660
@classmethod
662-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
661+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
663662
return super().from_dict(client, payload)
664663

665664

666-
class ContextMenuContext(InteractionContext, ModalMixin):
665+
class ContextMenuContext(InteractionContext[ClientT], ModalMixin):
667666
target_id: Snowflake
668667
"""The id of the target of the context menu."""
669668
editing_origin: bool
670669
"""Whether you have deferred the interaction and are editing the original response."""
671670
target_type: None | CommandType
672671
"""The type of the target of the context menu."""
673672

674-
def __init__(self, client: "interactions.Client") -> None:
673+
def __init__(self, client: "ClientT") -> None:
675674
super().__init__(client)
676675
self.editing_origin = False
677676

678677
@classmethod
679-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
678+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
680679
instance = super().from_dict(client, payload)
681680
instance.target_id = Snowflake(payload["data"]["target_id"])
682681
instance.target_type = CommandType(payload["data"]["type"])
@@ -739,7 +738,7 @@ def target(self) -> None | Message | User | Member:
739738
return self.resolved.get(self.target_id)
740739

741740

742-
class ComponentContext(InteractionContext, ModalMixin):
741+
class ComponentContext(InteractionContext[ClientT], ModalMixin):
743742
values: list[str]
744743
"""The values of the SelectMenu component, if any."""
745744
custom_id: str
@@ -750,7 +749,7 @@ class ComponentContext(InteractionContext, ModalMixin):
750749
"""Whether you have deferred the interaction and are editing the original response."""
751750

752751
@classmethod
753-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
752+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
754753
instance = super().from_dict(client, payload)
755754
instance.values = payload["data"].get("values", [])
756755
instance.custom_id = payload["data"]["custom_id"]
@@ -920,7 +919,7 @@ def component(self) -> typing.Optional[BaseComponent]:
920919
return component
921920

922921

923-
class ModalContext(InteractionContext):
922+
class ModalContext(InteractionContext[ClientT]):
924923
responses: dict[str, str]
925924
"""The responses of the modal. The key is the `custom_id` of the component."""
926925
custom_id: str
@@ -929,7 +928,7 @@ class ModalContext(InteractionContext):
929928
"""Whether to edit the original message instead of sending a new one."""
930929

931930
@classmethod
932-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
931+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
933932
instance = super().from_dict(client, payload)
934933
instance.responses = {
935934
comp["components"][0]["custom_id"]: comp["components"][0]["value"] for comp in payload["data"]["components"]
@@ -990,12 +989,12 @@ async def _defer(self, *, ephemeral: bool = False, edit_origin: bool = False) ->
990989
self.ephemeral = ephemeral
991990

992991

993-
class AutocompleteContext(BaseInteractionContext):
992+
class AutocompleteContext(BaseInteractionContext[ClientT]):
994993
focussed_option: SlashCommandOption # todo: option parsing
995994
"""The option the user is currently filling in."""
996995

997996
@classmethod
998-
def from_dict(cls, client: "interactions.Client", payload: dict) -> Self:
997+
def from_dict(cls, client: "ClientT", payload: dict) -> Self:
999998
return super().from_dict(client, payload)
1000999

10011000
@property

0 commit comments

Comments
 (0)