Skip to content

Commit

Permalink
Merge branch 'master' into master
Browse files Browse the repository at this point in the history
  • Loading branch information
martinbndr authored Sep 4, 2022
2 parents f10c3ca + e3e29ac commit 19e4e90
Show file tree
Hide file tree
Showing 3 changed files with 97 additions and 82 deletions.
174 changes: 93 additions & 81 deletions discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

import asyncio
import collections
import collections.abc
import copy
import inspect
import logging
Expand All @@ -41,6 +42,7 @@
Generator,
List,
Literal,
Mapping,
Optional,
Type,
TypeVar,
Expand Down Expand Up @@ -109,7 +111,7 @@ def pending_application_commands(self):
def commands(self) -> List[Union[ApplicationCommand, Any]]:
commands = self.application_commands
if self._bot._supports_prefixed_commands and hasattr(self._bot, "prefixed_commands"):
commands += self._bot.prefixed_commands
commands += getattr(self._bot, "prefixed_commands")
return commands

@property
Expand Down Expand Up @@ -217,7 +219,7 @@ def get_application_command(
async def get_desynced_commands(
self,
guild_id: Optional[int] = None,
prefetched: Optional[List[ApplicationCommand]] = None
prefetched: Optional[List[interactions.ApplicationCommand]] = None
) -> List[Dict[str, Any]]:
"""|coro|
Expand Down Expand Up @@ -248,7 +250,7 @@ async def get_desynced_commands(

# We can suggest the user to upsert, edit, delete, or bulk upsert the commands

def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
def _check_command(cmd: ApplicationCommand, match: Mapping[str, Any]) -> bool:
if isinstance(cmd, SlashCommandGroup):
if len(cmd.subcommands) != len(match.get("options", [])):
return True
Expand Down Expand Up @@ -300,24 +302,25 @@ def _check_command(cmd: ApplicationCommand, match: Dict) -> bool:
# TODO: Remove for perms v2
continue
return True
return False
return False

return_value = []
cmds = self.pending_application_commands.copy()

if guild_id is None:
if prefetched is not None:
registered_commands = prefetched
else:
registered_commands = await self._bot.http.get_global_commands(self.user.id)
pending = [cmd for cmd in cmds if cmd.guild_ids is None]
else:
if prefetched is not None:
registered_commands = prefetched
else:
registered_commands = await self._bot.http.get_guild_commands(self.user.id, guild_id)
pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids]

registered_commands: List[interactions.ApplicationCommand] = []
if prefetched is not None:
registered_commands = prefetched
elif self._bot.user:
if guild_id is None:
registered_commands = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)

registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands}
# First let's check if the commands we have locally are the same as the ones on discord
for cmd in pending:
Expand Down Expand Up @@ -358,7 +361,7 @@ async def register_command(
self,
command: ApplicationCommand,
force: bool = True,
guild_ids: List[int] = None,
guild_ids: Optional[List[int]] = None,
) -> None:
"""|coro|
Expand All @@ -382,7 +385,7 @@ async def register_command(
The command that was registered
"""
# TODO: Write this
raise RuntimeError("This function has not been implemented yet")
raise NotImplementedError

async def register_commands(
self,
Expand Down Expand Up @@ -439,7 +442,7 @@ async def register_commands(
}

def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
return registration_methods[method](self._bot.user.id, *args, **kwargs)
return registration_methods[method](self._bot.user and self._bot.user.id, *args, **kwargs)

else:
pending = list(
Expand All @@ -456,27 +459,30 @@ def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwar
}

def _register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
return registration_methods[method](self._bot.user.id, guild_id, *args, **kwargs)
return registration_methods[method](self._bot.user and self._bot.user.id, guild_id, *args, **kwargs)

def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwargs):
if kwargs.pop("_log", True):
if method == "bulk":
_log.debug(f"Bulk updating commands {[c['name'] for c in args[0]]} for guild {guild_id}")
# TODO: Find where "cmd" is defined
elif method == "upsert":
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}")
_log.debug(f"Creating command {cmd['name']} for guild {guild_id}") # type: ignore
elif method == "edit":
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}")
_log.debug(f"Editing command {cmd['name']} for guild {guild_id}") # type: ignore
elif method == "delete":
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}")
_log.debug(f"Deleting command {cmd['name']} for guild {guild_id}") # type: ignore
return _register(method, *args, **kwargs)

pending_actions = []

if not force:
if guild_id is None:
prefetched_commands = await self.http.get_global_commands(self.user.id)
else:
prefetched_commands = await self.http.get_guild_commands(self.user.id, guild_id)
prefetched_commands: List[interactions.ApplicationCommand] = []
if self._bot.user:
if guild_id is None:
prefetched_commands = await self._bot.http.get_global_commands(self._bot.user.id)
else:
prefetched_commands = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
desynced = await self.get_desynced_commands(guild_id=guild_id, prefetched=prefetched_commands)

for cmd in desynced:
Expand Down Expand Up @@ -549,10 +555,11 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg

# TODO: Our lists dont work sometimes, see if that can be fixed so we can avoid this second API call
if method != "bulk":
if guild_id is None:
registered = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
if self._bot.user:
if guild_id is None:
registered = await self._bot.http.get_global_commands(self._bot.user.id)
else:
registered = await self._bot.http.get_guild_commands(self._bot.user.id, guild_id)
else:
data = [cmd.to_dict() for cmd in pending]
registered = await register("bulk", data)
Expand All @@ -561,10 +568,10 @@ def register(method: Literal["bulk", "upsert", "delete", "edit"], *args, **kwarg
cmd = get(
self.pending_application_commands,
name=i["name"],
type=i["type"],
type=i.get("type"),
)
if not cmd:
raise ValueError(f"Registered command {i['name']}, type {i['type']} not found in pending commands")
raise ValueError(f"Registered command {i['name']}, type {i.get('type')} not found in pending commands")
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

Expand Down Expand Up @@ -622,7 +629,7 @@ async def sync_commands(
Whether to delete existing commands that are not in the list of commands to register. Defaults to True.
"""

check_guilds = list(set((check_guilds or []) + (self.debug_guilds or [])))
check_guilds = list(set((check_guilds or []) + (self._bot.debug_guilds or [])))

if commands is None:
commands = self.pending_application_commands
Expand All @@ -636,48 +643,51 @@ async def sync_commands(
global_commands, method=method, force=force, delete_existing=delete_existing
)

registered_guild_commands = {}
registered_guild_commands: Dict[int, List[interactions.ApplicationCommand]] = {}

if register_guild_commands:
cmd_guild_ids = []
cmd_guild_ids: List[int] = []
for cmd in commands:
if cmd.guild_ids is not None:
cmd_guild_ids.extend(cmd.guild_ids)
if check_guilds is not None:
cmd_guild_ids.extend(check_guilds)
for guild_id in set(cmd_guild_ids):
guild_commands = [cmd for cmd in commands if cmd.guild_ids is not None and guild_id in cmd.guild_ids]
registered_guild_commands[guild_id] = await self.register_commands(
app_cmds = await self.register_commands(
guild_commands, guild_id=guild_id, method=method, force=force, delete_existing=delete_existing
)
registered_guild_commands[guild_id] = app_cmds

for i in registered_commands:
cmd = get(
self.pending_application_commands,
name=i["name"],
guild_ids=None,
type=i["type"],
type=i.get("type"),
)
if cmd:
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

for guild_id, commands in registered_guild_commands.items():
for i in commands:
cmd = find(
lambda cmd: cmd.name == i["name"]
and cmd.type == i["type"]
and cmd.guild_ids is not None
and int(i["guild_id"]) in cmd.guild_ids,
self.pending_application_commands,
)
if not cmd:
# command has not been added yet
continue
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd
if register_guild_commands and registered_guild_commands:
for guild_id, guild_cmds in registered_guild_commands.items():
for i in guild_cmds:
cmd = find(
lambda cmd: cmd.name == i["name"]
and cmd.type == i.get("type")
and cmd.guild_ids is not None
# TODO: fix this type error (guild_id is not defined in ApplicationCommand Typed Dict)
and int(i["guild_id"]) in cmd.guild_ids, # type: ignore
self.pending_application_commands,
)
if not cmd:
# command has not been added yet
continue
cmd.id = i["id"]
self._application_commands[cmd.id] = cmd

async def process_application_commands(self, interaction: Interaction, auto_sync: bool = None) -> None:
async def process_application_commands(self, interaction: Interaction, auto_sync: Optional[bool] = None) -> None:
"""|coro|
This function processes the commands that have been registered
Expand All @@ -698,33 +708,37 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
-----------
interaction: :class:`discord.Interaction`
The interaction to process
auto_sync: :class:`bool`
auto_sync: Optional[:class:`bool`]
Whether to automatically sync and unregister the command if it is not found in the internal cache. This will
invoke the :meth:`~.Bot.sync_commands` method on the context of the command, either globally or per-guild,
based on the type of the command, respectively. Defaults to :attr:`.Bot.auto_sync_commands`.
"""
if auto_sync is None:
auto_sync = self._bot.auto_sync_commands
# TODO: find out why the isinstance check below doesn't stop the type errors below
if interaction.type not in (
InteractionType.application_command,
InteractionType.auto_complete,
):
) and isinstance(interaction.data, interactions.ComponentInteractionData):
return

command: Optional[ApplicationCommand] = None
try:
command = self._application_commands[interaction.data["id"]]
if interaction.data:
command = self._application_commands[interaction.data["id"]] # type: ignore
except KeyError:
for cmd in self.application_commands + self.pending_application_commands:
guild_id = interaction.data.get("guild_id")
if guild_id:
guild_id = int(guild_id)
if cmd.name == interaction.data["name"] and (
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
):
command = cmd
break
if interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id:
guild_id = int(guild_id)
if cmd.name == interaction.data["name"] and ( # type: ignore
guild_id == cmd.guild_ids or (isinstance(cmd.guild_ids, list) and guild_id in cmd.guild_ids)
):
command = cmd
break
else:
if auto_sync:
if auto_sync and interaction.data:
guild_id = interaction.data.get("guild_id")
if guild_id is None:
await self.sync_commands()
Expand All @@ -734,26 +748,28 @@ async def process_application_commands(self, interaction: Interaction, auto_sync
return self._bot.dispatch("unknown_application_command", interaction)

if interaction.type is InteractionType.auto_complete:
return self.dispatch("application_command_auto_complete", interaction, command)
return self._bot.dispatch("application_command_auto_complete", interaction, command)

ctx = await self.get_application_context(interaction)
ctx.command = command
if command:
ctx.command = command
await self.invoke_application_command(ctx)

async def on_application_command_auto_complete(self, interaction: Interaction, command: ApplicationCommand) -> None:
async def callback() -> None:
ctx = await self.get_autocomplete_context(interaction)
ctx.command = command
return await command.invoke_autocomplete_callback(ctx)
if isinstance(command, SlashCommand):
async def callback() -> None:
ctx = await self.get_autocomplete_context(interaction)
ctx.command = command
return await command.invoke_autocomplete_callback(ctx)

autocomplete_task = self.loop.create_task(callback())
try:
await self.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
except asyncio.TimeoutError:
return
else:
if not autocomplete_task.done():
autocomplete_task.cancel()
autocomplete_task = self._bot.loop.create_task(callback())
try:
await self._bot.wait_for("application_command_auto_complete", check=lambda i, c: c == command, timeout=3)
except asyncio.TimeoutError:
return
else:
if not autocomplete_task.done():
autocomplete_task.cancel()

def slash_command(self, **kwargs):
"""A shortcut decorator that invokes :func:`command` and adds it to
Expand Down Expand Up @@ -924,7 +940,7 @@ def walk_application_commands(self) -> Generator[ApplicationCommand, None, None]
yield from command.walk_commands()
yield command

async def get_application_context(self, interaction: Interaction, cls=None) -> ApplicationContext:
async def get_application_context(self, interaction: Interaction, cls: Any = ApplicationContext) -> ApplicationContext:
r"""|coro|
Returns the invocation context from the interaction.
Expand All @@ -948,11 +964,9 @@ class be provided, it must be similar enough to
The invocation context. The type of this can change via the
``cls`` parameter.
"""
if cls is None:
cls = ApplicationContext
return cls(self, interaction)

async def get_autocomplete_context(self, interaction: Interaction, cls=None) -> AutocompleteContext:
async def get_autocomplete_context(self, interaction: Interaction, cls: Any = AutocompleteContext) -> AutocompleteContext:
r"""|coro|
Returns the autocomplete context from the interaction.
Expand All @@ -976,8 +990,6 @@ class be provided, it must be similar enough to
The autocomplete context. The type of this can change via the
``cls`` parameter.
"""
if cls is None:
cls = AutocompleteContext
return cls(self, interaction)

async def invoke_application_command(self, ctx: ApplicationContext) -> None:
Expand Down
3 changes: 3 additions & 0 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -565,6 +565,9 @@ def qualified_name(self) -> str:
else:
return self.name

def to_dict(self) -> Dict[str, Any]:
raise NotImplementedError

def __str__(self) -> str:
return self.qualified_name

Expand Down
2 changes: 1 addition & 1 deletion docs/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1447,7 +1447,7 @@ to handle it, which defaults to print a traceback and ignoring the exception.
:param rule: The deleted rule.
:type rule: :class:`AutoModRule`

.. function:: on_auto_moderation_action_execution(guild, action)
.. function:: on_auto_moderation_action_execution(payload)

Called when an auto moderation action is executed.

Expand Down

0 comments on commit 19e4e90

Please sign in to comment.