Skip to content

Commit

Permalink
Extension fixes (Pycord-Development#817)
Browse files Browse the repository at this point in the history
  • Loading branch information
plun1331 authored Jan 21, 2022
1 parent fe92268 commit a9f6bd5
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 16 deletions.
12 changes: 11 additions & 1 deletion discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,10 @@ def __init__(self, *args, **kwargs) -> None:
def pending_application_commands(self):
return self._pending_application_commands

@property
def all_commands(self):
return self._application_commands

@property
def commands(self) -> List[Union[ApplicationCommand, Any]]:
commands = self.application_commands
Expand Down Expand Up @@ -149,9 +153,15 @@ def remove_application_command(
Returns
--------
Optional[:class:`.ApplicationCommand`]
The command that was removed. If the name is not valid then
The command that was removed. If the command is not valid then
``None`` is returned instead.
"""
if command.id is None:
try:
index = self._pending_application_commands.index(command)
except ValueError:
return None
return self._pending_application_commands.pop(index)
return self._application_commands.pop(command.id)

@property
Expand Down
35 changes: 20 additions & 15 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1159,25 +1159,30 @@ class GroupMixin(Generic[CogT]):
Attributes
-----------
all_commands: :class:`dict`
prefixed_commands: :class:`dict`
A mapping of command name to :class:`.Command`
objects.
case_insensitive: :class:`bool`
Whether the commands should be case insensitive. Defaults to ``False``.
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)

@property
def all_commands(self):
# merge app and prefixed commands
return {**self._application_commands, **self.prefixed_commands}

@property
def commands(self) -> Set[Command[CogT, Any, Any]]:
"""Set[:class:`.Command`]: A unique set of commands without aliases that are registered."""
return set(self.all_commands.values())

def recursively_remove_all_commands(self) -> None:
for command in self.all_commands.copy().values():
for command in self.prefixed_commands.copy().values():
if isinstance(command, GroupMixin):
command.recursively_remove_all_commands()
self.remove_command(command.name)
Expand Down Expand Up @@ -1210,15 +1215,15 @@ def add_command(self, command: Command[CogT, Any, Any]) -> None:
if isinstance(self, Command):
command.parent = self

if command.name in self.all_commands:
if command.name in self.prefixed_commands:
raise CommandRegistrationError(command.name)

self.all_commands[command.name] = command
self.prefixed_commands[command.name] = command
for alias in command.aliases:
if alias in self.all_commands:
if alias in self.prefixed_commands:
self.remove_command(command.name)
raise CommandRegistrationError(alias, alias_conflict=True)
self.all_commands[alias] = command
self.prefixed_commands[alias] = command

def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
"""Remove a :class:`.Command` from the internal list
Expand All @@ -1237,7 +1242,7 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
The command that was removed. If the name is not valid then
``None`` is returned instead.
"""
command = self.all_commands.pop(name, None)
command = self.prefixed_commands.pop(name, None)

# does not exist
if command is None:
Expand All @@ -1249,12 +1254,12 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:

# we're not removing the alias so let's delete the rest of them.
for alias in command.aliases:
cmd = self.all_commands.pop(alias, None)
cmd = self.prefixed_commands.pop(alias, None)
# in the case of a CommandRegistrationError, an alias might conflict
# with an already existing command. If this is the case, we want to
# make sure the pre-existing command is not removed.
if cmd is not None and cmd != command:
self.all_commands[alias] = cmd
self.prefixed_commands[alias] = cmd
return command

def walk_commands(self) -> Generator[Command[CogT, Any, Any], None, None]:
Expand Down Expand Up @@ -1296,18 +1301,18 @@ def get_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:

# fast path, no space in name.
if ' ' not in name:
return self.all_commands.get(name)
return self.prefixed_commands.get(name)

names = name.split()
if not names:
return None
obj = self.all_commands.get(names[0])
obj = self.prefixed_commands.get(names[0])
if not isinstance(obj, GroupMixin):
return obj

for name in names[1:]:
try:
obj = obj.all_commands[name] # type: ignore
obj = obj.prefixed_commands[name] # type: ignore
except (AttributeError, KeyError):
return None

Expand Down Expand Up @@ -1463,7 +1468,7 @@ async def invoke(self, ctx: Context) -> None:

if trigger:
ctx.subcommand_passed = trigger
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)

if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback)
Expand Down Expand Up @@ -1497,7 +1502,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:

if trigger:
ctx.subcommand_passed = trigger
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)

if early_invoke:
try:
Expand Down

0 comments on commit a9f6bd5

Please sign in to comment.