|
30 | 30 | import functools |
31 | 31 | import inspect |
32 | 32 | from collections import OrderedDict |
33 | | -from typing import Any, Callable, Dict, List, Optional, Union |
| 33 | +from typing import Any, Callable, Dict, List, Optional, Union, TYPE_CHECKING |
34 | 34 |
|
35 | 35 | from ..enums import SlashCommandOptionType, ChannelType |
36 | 36 | from ..member import Member |
|
60 | 60 | "MessageCommand", |
61 | 61 | ) |
62 | 62 |
|
| 63 | +if TYPE_CHECKING: |
| 64 | + from ..interactions import Interaction |
| 65 | + |
63 | 66 | def wrap_callback(coro): |
64 | 67 | @functools.wraps(coro) |
65 | 68 | async def wrapped(*args, **kwargs): |
@@ -351,7 +354,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: |
351 | 354 | self.cog = None |
352 | 355 |
|
353 | 356 | params = self._get_signature_parameters() |
354 | | - self.options = self._parse_options(params) |
| 357 | + self.options: List[Option] = self._parse_options(params) |
355 | 358 |
|
356 | 359 | try: |
357 | 360 | checks = func.__commands_checks__ |
@@ -487,6 +490,17 @@ async def _invoke(self, ctx: ApplicationContext) -> None: |
487 | 490 | else: |
488 | 491 | await self.callback(ctx, **kwargs) |
489 | 492 |
|
| 493 | + async def invoke_autocomplete_callback(self, interaction: Interaction): |
| 494 | + for op in interaction.data.get("options", []): |
| 495 | + if op.get("focused", False): |
| 496 | + option = find(lambda o: o.name == op["name"], self.options) |
| 497 | + result = await option.autocomplete(interaction, op.get("value", None)) |
| 498 | + choices = [ |
| 499 | + o if isinstance(o, OptionChoice) else OptionChoice(o) |
| 500 | + for o in result |
| 501 | + ] |
| 502 | + await interaction.response.send_autocomplete_result(choices=choices) |
| 503 | + |
490 | 504 | def qualified_name(self): |
491 | 505 | return self.name |
492 | 506 |
|
@@ -581,13 +595,21 @@ def __init__( |
581 | 595 | if not (isinstance(self.max_value, minmax_types) or self.min_value is None): |
582 | 596 | raise TypeError(f"Expected {minmax_typehint} for max_value, got \"{type(self.max_value).__name__}\"") |
583 | 597 |
|
| 598 | + self.autocomplete = kwargs.pop("autocomplete", None) |
| 599 | + if ( |
| 600 | + self.autocomplete and |
| 601 | + not asyncio.iscoroutinefunction(self.autocomplete) |
| 602 | + ): |
| 603 | + raise TypeError("Autocomplete callback must be a coroutine.") |
| 604 | + |
584 | 605 | def to_dict(self) -> Dict: |
585 | 606 | as_dict = { |
586 | 607 | "name": self.name, |
587 | 608 | "description": self.description, |
588 | 609 | "type": self.input_type.value, |
589 | 610 | "required": self.required, |
590 | 611 | "choices": [c.to_dict() for c in self.choices], |
| 612 | + "autocomplete": bool(self.autocomplete) |
591 | 613 | } |
592 | 614 | if self.channel_types: |
593 | 615 | as_dict["channel_types"] = [t.value for t in self.channel_types] |
@@ -722,6 +744,13 @@ async def _invoke(self, ctx: ApplicationContext) -> None: |
722 | 744 | ctx.interaction.data = option |
723 | 745 | await command.invoke(ctx) |
724 | 746 |
|
| 747 | + async def invoke_autocomplete_callback(self, interaction: Interaction) -> None: |
| 748 | + option = interaction.data["options"][0] |
| 749 | + command = find(lambda x: x.name == option["name"], self.subcommands) |
| 750 | + interaction.data = option |
| 751 | + await command.invoke_autocomplete_callback(interaction) |
| 752 | + |
| 753 | + |
725 | 754 | class ContextMenuCommand(ApplicationCommand): |
726 | 755 | r"""A class that implements the protocol for context menu commands. |
727 | 756 |
|
|
0 commit comments