Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 93 additions & 81 deletions discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import asyncio
import collections
import collections.abc
import copy
import inspect
import logging
Expand All @@ -41,6 +42,7 @@
Generator,
List,
Literal,
Mapping,
Optional,
Type,
TypeVar,
Expand Down Expand Up @@ -109,7 +111,7 @@ def pending_application_commands(self):
def commands(self) -> List[Union[ApplicationCommand, Any]]:
commands = self.application_commands
if self._bot._supports_prefixed_commands and hasattr(self._bot, "prefixed_commands"):
commands += self._bot.prefixed_commands
commands += getattr(self._bot, "prefixed_commands")
return commands

@property
Expand Down Expand Up @@ -217,7 +219,7 @@ def get_application_command(
async def get_desynced_commands(
self,
guild_id: Optional[int] = None,
prefetched: Optional[List[ApplicationCommand]] = None
prefetched: Optional[List[interactions.ApplicationCommand]] = None
) -> List[Dict[str, Any]]:
"""|coro|

Expand Down Expand Up @@ -248,7 +250,7 @@ async def get_desynced_commands(

# We can suggest the user to upsert, edit, delete, or bulk upsert the commands

def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool:
if isinstance(cmd, SlashCommandGroup):
if len(cmd.subcommands) != len(match.get("options", [])):
return True
Expand Down Expand Up @@ -300,24 +302,25 @@ def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
# TODO: Remove for perms v2
continue
return True
return False
return False

return_value = []
cmds = self.pending_application_commands.copy()

if guild_id is None:
if prefetched is not None:
registered_commands = prefetched
else:
registered_commands = await self._bot.http.get_global_commands(self.user.id)
pending = [cmd for cmd in cmds if cmd.guild_ids is None]
else:
if prefetched is not None:
registered_commands = prefetched
else:
registered_commands = await self._bot.http.get_guild_commands(self.user.id, guild_id)
pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids]

registered_commands: List[interactions.ApplicationCommand] = []
if prefetched is not None:
registered_commands = prefetched
elif self._bot.user:
if guild_id is None:
registered_commands = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)

registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands}
# First let's check if the commands we have locally are the same as the ones on discord
for cmd in pending:
Expand Down Expand Up @@ -358,7 +361,7 @@ async def register_command(
self,
command: ApplicationCommand,
force: bool = True,
guild_ids: List[int] = None,
guild_ids: Optional[List[int]] = None,
) -> None:
"""|coro|

Expand All @@ -382,7 +385,7 @@ async def register_command(
The command that was registered
"""
# TODO: Write this
raise RuntimeError("This function has not been implemented yet")
raise NotImplementedError

async def register_commands(
self,
Expand Down Expand Up @@ -439,7 +442,7 @@ async def register_commands(
}

def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
return registration_methods[method](self._bot.user.id, *args, **kwargs)
return registration_methods[method](self._bot.user and self._bot.user.id, *args, **kwargs)

else:
pending = list(
Expand All @@ -456,27 +459,30 @@ def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwar
}

def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
return registration_methods[method](self._bot.user.id, guild_id, *args, **kwargs)
return registration_methods[method](self._bot.user and self._bot.user.id, guild_id, *args, **kwargs)

def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
if kwargs.pop("_log", True):
if method == "bulk":
_log.debug(f"Bulk updating commands {[c['name'] for c in args[0]]} for guild {guild_id}")
# TODO: Find where "cmd" is defined
elif method == "upsert":
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}")
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}") # type: ignore
elif method == "edit":
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}")
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}") # type: ignore
elif method == "delete":
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}")
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}") # type: ignore
return _register(method, *args, **kwargs)

pending_actions = []

if not force:
if guild_id is None:
prefetched_commands = await self.http.get_global_commands(self.user.id)
else:
prefetched_commands = await self.http.get_guild_commands(self.user.id, guild_id)
prefetched_commands: List[interactions.ApplicationCommand] = []
if self._bot.user:
if guild_id is None:
prefetched_commands = await self._bot.http.get_global_commands(self._bot.user.id)
else:
prefetched_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
desynced = await self.get_desynced_commands(guild_id=guild_id, prefetched=prefetched_commands)

for cmd in desynced:
Expand Down Expand Up @@ -549,10 +555,11 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg

# TODO: Our lists dont work sometimes, see if that can be fixed so we can avoid this second API call
if method != "bulk":
if guild_id is None:
registered = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
if self._bot.user:
if guild_id is None:
registered = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
else:
data = [cmd.to_dict() for cmd in pending]
registered = await register("bulk", data)
Expand All @@ -561,10 +568,10 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg
cmd = get(
self.pending_application_commands,
name=i["name"],
type=i["type"],
type=i.get("type"),
)
if not cmd:
raise ValueError(f"Registered command {i['name']}, type {i['type']} not found in pending commands")
raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands")
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

Expand Down Expand Up @@ -622,7 +629,7 @@ async def sync_commands(
Whether to delete existing commands that are not in the list of commands to register. Defaults to True.
"""

check_guilds = list(set((check_guilds or []) + (self.debug_guilds or [])))
check_guilds = list(set((check_guilds or []) + (self._bot.debug_guilds or [])))

if commands is None:
commands = self.pending_application_commands
Expand All @@ -636,48 +643,51 @@ async def sync_commands(
global_commands, method=method, force=force, delete_existing=delete_existing
)

registered_guild_commands = {}
registered_guild_commands: Dict[int, List[interactions.ApplicationCommand]] = {}

if register_guild_commands:
cmd_guild_ids = []
cmd_guild_ids: List[int] = []
for cmd in commands:
if cmd.guild_ids is not None:
cmd_guild_ids.extend(cmd.guild_ids)
if check_guilds is not None:
cmd_guild_ids.extend(check_guilds)
for guild_id in set(cmd_guild_ids):
guild_commands = [cmd for cmd in commands if cmd.guild_ids is not None and guild_id in cmd.guild_ids]
registered_guild_commands[guild_id] = await self.register_commands(
app_cmds = await self.register_commands(
guild_commands, guild_id=guild_id, method=method, force=force, delete_existing=delete_existing
)
registered_guild_commands[guild_id] = app_cmds

for i in registered_commands:
cmd = get(
self.pending_application_commands,
name=i["name"],
guild_ids=None,
type=i["type"],
type=i.get("type"),
)
if cmd:
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

for guild_id, commands in registered_guild_commands.items():
for i in commands:
cmd = find(
lambda cmd: cmd.name == i["name"]
and cmd.type == i["type"]
and cmd.guild_ids is not None
and int(i["guild_id"]) in cmd.guild_ids,
self.pending_application_commands,
)
if not cmd:
# command has not been added yet
continue
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd
if register_guild_commands and registered_guild_commands:
for guild_id, guild_cmds in registered_guild_commands.items():
for i in guild_cmds:
cmd = find(
lambda cmd: cmd.name == i["name"]
and cmd.type == i.get("type")
and cmd.guild_ids is not None
# TODO: fix this type error (guild_id is not defined in ApplicationCommand Typed Dict)
and int(i["guild_id"]) in cmd.guild_ids, # type: ignore
self.pending_application_commands,
)
if not cmd:
# command has not been added yet
continue
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

async def process_application_commands(self, interaction: Interaction, auto_sync: bool = None) -> None:
async def process_application_commands(self, interaction: Interaction, auto_sync: Optional[bool] = None) -> None:
"""|coro|

This function processes the commands that have been registered
Expand All @@ -698,33 +708,37 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
-----------
interaction: :class:`discord.Interaction`
The interaction to process
auto_sync: :class:`bool`
auto_sync: Optional[:class:`bool`]
Whether to automatically sync and unregister the command if it is not found in the internal cache. This will
invoke the :meth:`~.Bot.sync_commands` method on the context of the command, either globally or per-guild,
based on the type of the command, respectively. Defaults to :attr:`.Bot.auto_sync_commands`.
"""
if auto_sync is None:
auto_sync = self._bot.auto_sync_commands
# TODO: find out why the isinstance check below doesn't stop the type errors below
if interaction.type not in (
InteractionType.application_command,
InteractionType.auto_complete,
):
) and isinstance(interaction.data, interactions.ComponentInteractionData):
return

command: Optional[ApplicationCommand] = None
try:
command = self._application_commands[interaction.data["id"]]
if interaction.data:
command = self._application_commands[interaction.data["id"]] # type: ignore
except KeyError:
for cmd in self.application_commands + self.pending_application_commands:
guild_id = interaction.data.get("guild_id")
if guild_id:
guild_id = int(guild_id)
if cmd.name == interaction.data["name"] and (
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
):
command = cmd
break
if interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id:
guild_id = int(guild_id)
if cmd.name == interaction.data["name"] and ( # type: ignore
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
):
command = cmd
break
else:
if auto_sync:
if auto_sync and interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id is None:
await self.sync_commands()
Expand All @@ -734,26 +748,28 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
return self._bot.dispatch("unknown_application_command", interaction)

if interaction.type is InteractionType.auto_complete:
return self.dispatch("application_command_auto_complete", interaction, command)
return self._bot.dispatch("application_command_auto_complete", interaction, command)

ctx = await self.get_application_context(interaction)
ctx.command = command
if command:
ctx.command = command
await self.invoke_application_command(ctx)

async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None:
async def callback() -> None:
ctx = await self.get_autocomplete_context(interaction)
ctx.command = command
return await command.invoke_autocomplete_callback(ctx)
if isinstance(command, SlashCommand):
async def callback() -> None:
ctx = await self.get_autocomplete_context(interaction)
ctx.command = command
return await command.invoke_autocomplete_callback(ctx)

autocomplete_task = self.loop.create_task(callback())
try:
await self.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
except asyncio.TimeoutError:
return
else:
if not autocomplete_task.done():
autocomplete_task.cancel()
autocomplete_task = self._bot.loop.create_task(callback())
try:
await self._bot.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
except asyncio.TimeoutError:
return
else:
if not autocomplete_task.done():
autocomplete_task.cancel()

def slash_command(self, **kwargs):
"""A shortcut decorator that invokes :func:`command` and adds it to
Expand Down Expand Up @@ -924,7 +940,7 @@ def walk_application_commands(self) -> Generator[ApplicationCommand, None, None]
yield from command.walk_commands()
yield command

async def get_application_context(self, interaction: Interaction, cls=None) -> ApplicationContext:
async def get_application_context(self, interaction: Interaction, cls: Any = ApplicationContext) -> ApplicationContext:
r"""|coro|

Returns the invocation context from the interaction.
Expand All @@ -948,11 +964,9 @@ class be provided, it must be similar enough to
The invocation context. The type of this can change via the
``cls`` parameter.
"""
if cls is None:
cls = ApplicationContext
return cls(self, interaction)

async def get_autocomplete_context(self, interaction: Interaction, cls=None) -> AutocompleteContext:
async def get_autocomplete_context(self, interaction: Interaction, cls: Any = AutocompleteContext) -> AutocompleteContext:
r"""|coro|

Returns the autocomplete context from the interaction.
Expand All @@ -976,8 +990,6 @@ class be provided, it must be similar enough to
The autocomplete context. The type of this can change via the
``cls`` parameter.
"""
if cls is None:
cls = AutocompleteContext
return cls(self, interaction)

async def invoke_application_command(self, ctx: ApplicationContext) -> None:
Expand Down
3 changes: 3 additions & 0 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def qualified_name(self) -> str:
else:
return self.name

def to_dict(self) -> Dict[str, Any]:
raise NotImplementedError

def __str__(self) -> str:
return self.qualified_name

Expand Down