Skip to content

Commit

Permalink
Update core.py
Browse files Browse the repository at this point in the history
  • Loading branch information
baronkobama committed Jun 27, 2022
1 parent a0ee549 commit 8b1497b
Showing 1 changed file with 21 additions and 14 deletions.
35 changes: 21 additions & 14 deletions discord/ext/bridge/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
DEALINGS IN THE SOFTWARE.
"""
from typing import Any, Union
import asyncio

import discord.commands.options
from discord.commands import Option, SlashCommand
Expand Down Expand Up @@ -77,8 +78,9 @@ def __init__(self, callback, **kwargs):
"""
self.callback = callback
self.kwargs = kwargs
self.ext_command = None # will be defined when self.add_to is called
self.application_command = None # will be defined when self.add_to is called

self.ext_command = BridgeExtCommand(self.callback, **self.kwargs)
self.application_command = BridgeSlashCommand(self.callback, **self.kwargs)

def get_ext_command(self):
"""A method to get the ext.commands version of this command.
Expand All @@ -88,8 +90,7 @@ def get_ext_command(self):
:class:`BridgeExtCommand`
The respective traditional (prefix-based) version of the command.
"""
command = BridgeExtCommand(self.callback, **self.kwargs)
return command
return self.ext_command

def get_application_command(self):
"""A method to get the discord.commands version of this command.
Expand All @@ -99,8 +100,7 @@ def get_application_command(self):
:class:`BridgeSlashCommand`
The respective slash command version of the command.
"""
command = BridgeSlashCommand(self.callback, **self.kwargs)
return command
return self.application_command

def add_to(self, bot: Union[ExtBot, ExtAutoShardedBot]) -> None:
"""Adds the command to a bot.
Expand All @@ -110,8 +110,6 @@ def add_to(self, bot: Union[ExtBot, ExtAutoShardedBot]) -> None:
bot: Union[:class:`ExtBot`, :class:`ExtAutoShardedBot`]
The bot to add the command to.
"""
self.ext_command = self.get_ext_command()
self.application_command = self.get_application_command()

bot.add_command(self.ext_command)
bot.add_application_command(self.application_command)
Expand All @@ -138,8 +136,11 @@ def error(self, coro):
The coroutine passed is not actually a coroutine.
"""

self.ext_command.error(coro)
self.application_command.error(coro)
if not asyncio.iscoroutinefunction(coro):
raise TypeError("The error handler must be a coroutine.")

self.ext_command.on_error = coro
self.application_command.on_error = coro

return coro

Expand All @@ -163,8 +164,11 @@ def before_invoke(self, coro):
The coroutine passed is not actually a coroutine.
"""

self.ext_command.before_invoke(coro)
self.application_command.before_invoke(coro)
if not asyncio.iscoroutinefunction(coro):
raise TypeError("The pre-invoke hook must be a coroutine.")

self.ext_command.before_invoke = coro
self.application_command.before_invoke = coro

return coro

Expand All @@ -188,8 +192,11 @@ def after_invoke(self, coro):
The coroutine passed is not actually a coroutine.
"""

self.ext_command.after_invoke(coro)
self.application_command.after_invoke(coro)
if not asyncio.iscoroutinefunction(coro):
raise TypeError("The post-invoke hook must be a coroutine.")

self.ext_command.after_invoke = coro
self.application_command.after_invoke = coro

return coro

Expand Down

0 comments on commit 8b1497b

Please sign in to comment.