Skip to content

Commit 545f9e2

Browse files
feat/fix: improve typehinting of wait_fors (#1694)
* feat/fix: improve typehinting of wait_fors * ci: correct from checks. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 59471c4 commit 545f9e2

File tree

1 file changed

+86
-6
lines changed

1 file changed

+86
-6
lines changed

interactions/client/client.py

Lines changed: 86 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,8 @@
2626
Union,
2727
Awaitable,
2828
Tuple,
29+
TypeVar,
30+
overload,
2931
)
3032

3133
from aiohttp import BasicAuth
@@ -40,6 +42,7 @@
4042
from interactions.client import errors
4143
from interactions.client.const import (
4244
GLOBAL_SCOPE,
45+
Missing,
4346
MISSING,
4447
Absent,
4548
EMBED_MAX_DESC_LENGTH,
@@ -122,6 +125,8 @@
122125
if TYPE_CHECKING:
123126
from interactions.models import Snowflake_Type, TYPE_ALL_CHANNEL
124127

128+
EventT = TypeVar("EventT", bound=BaseEvent)
129+
125130
__all__ = ("Client",)
126131

127132
# see https://discord.com/developers/docs/topics/gateway#list-of-intents
@@ -1061,12 +1066,36 @@ async def wait_until_ready(self) -> None:
10611066
"""Waits for the client to become ready."""
10621067
await self._ready.wait()
10631068

1069+
@overload
1070+
def wait_for(
1071+
self,
1072+
event: type[EventT],
1073+
checks: Absent[Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]]] = MISSING,
1074+
timeout: Optional[float] = None,
1075+
) -> "Awaitable[EventT]": ...
1076+
1077+
@overload
10641078
def wait_for(
10651079
self,
1066-
event: Union[str, "BaseEvent"],
1067-
checks: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] = MISSING,
1080+
event: str,
1081+
checks: Callable[[EventT], bool] | Callable[[EventT], Awaitable[bool]],
10681082
timeout: Optional[float] = None,
1069-
) -> Any:
1083+
) -> "Awaitable[EventT]": ...
1084+
1085+
@overload
1086+
def wait_for(
1087+
self,
1088+
event: str,
1089+
checks: Missing = MISSING,
1090+
timeout: Optional[float] = None,
1091+
) -> Awaitable[Any]: ...
1092+
1093+
def wait_for(
1094+
self,
1095+
event: Union[str, "type[BaseEvent]"],
1096+
checks: Absent[Callable[[BaseEvent], bool] | Callable[[BaseEvent], Awaitable[bool]]] = MISSING,
1097+
timeout: Optional[float] = None,
1098+
) -> Awaitable[Any]:
10701099
"""
10711100
Waits for a WebSocket event to be dispatched.
10721101
@@ -1112,17 +1141,68 @@ async def wait_for_modal(
11121141
"""
11131142
author = to_snowflake(author) if author else None
11141143

1115-
def predicate(event) -> bool:
1144+
def predicate(event: events.ModalCompletion) -> bool:
11161145
if modal.custom_id != event.ctx.custom_id:
11171146
return False
11181147
return author == to_snowflake(event.ctx.author) if author else True
11191148

11201149
resp = await self.wait_for("modal_completion", predicate, timeout)
11211150
return resp.ctx
11221151

1152+
@overload
1153+
async def wait_for_component(
1154+
self,
1155+
messages: Union[Message, int, list],
1156+
components: Union[
1157+
List[List[Union["BaseComponent", dict]]],
1158+
List[Union["BaseComponent", dict]],
1159+
"BaseComponent",
1160+
dict,
1161+
],
1162+
check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None,
1163+
timeout: Optional[float] = None,
1164+
) -> "events.Component": ...
1165+
1166+
@overload
1167+
async def wait_for_component(
1168+
self,
1169+
*,
1170+
components: Union[
1171+
List[List[Union["BaseComponent", dict]]],
1172+
List[Union["BaseComponent", dict]],
1173+
"BaseComponent",
1174+
dict,
1175+
],
1176+
check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None,
1177+
timeout: Optional[float] = None,
1178+
) -> "events.Component": ...
1179+
1180+
@overload
1181+
async def wait_for_component(
1182+
self,
1183+
messages: None,
1184+
components: Union[
1185+
List[List[Union["BaseComponent", dict]]],
1186+
List[Union["BaseComponent", dict]],
1187+
"BaseComponent",
1188+
dict,
1189+
],
1190+
check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None,
1191+
timeout: Optional[float] = None,
1192+
) -> "events.Component": ...
1193+
1194+
@overload
1195+
async def wait_for_component(
1196+
self,
1197+
messages: Union[Message, int, list],
1198+
components: None = None,
1199+
check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None,
1200+
timeout: Optional[float] = None,
1201+
) -> "events.Component": ...
1202+
11231203
async def wait_for_component(
11241204
self,
1125-
messages: Union[Message, int, list] = None,
1205+
messages: Optional[Union[Message, int, list]] = None,
11261206
components: Optional[
11271207
Union[
11281208
List[List[Union["BaseComponent", dict]]],
@@ -1131,7 +1211,7 @@ async def wait_for_component(
11311211
dict,
11321212
]
11331213
] = None,
1134-
check: Absent[Optional[Union[Callable[..., bool], Callable[..., Awaitable[bool]]]]] | None = None,
1214+
check: Optional[Callable[[events.Component], bool] | Callable[[events.Component], Awaitable[bool]]] = None,
11351215
timeout: Optional[float] = None,
11361216
) -> "events.Component":
11371217
"""

0 commit comments

Comments
 (0)