From 7e93a47c65290b75e33f73786fc7c60903b439d5 Mon Sep 17 00:00:00 2001 From: Krittick Date: Mon, 14 Feb 2022 16:42:07 -0800 Subject: [PATCH] re-run Black after #1015 --- discord/__init__.py | 4 +- discord/__main__.py | 34 +- discord/abc.py | 114 ++----- discord/activity.py | 58 +--- discord/appinfo.py | 4 +- discord/asset.py | 20 +- discord/audit_logs.py | 88 ++--- discord/bot.py | 111 ++---- discord/channel.py | 152 ++------- discord/client.py | 69 +--- discord/cog.py | 47 +-- discord/colour.py | 4 +- discord/commands/context.py | 9 +- discord/commands/core.py | 152 +++------ discord/commands/errors.py | 4 +- discord/commands/options.py | 15 +- discord/commands/permissions.py | 16 +- discord/components.py | 12 +- discord/embeds.py | 30 +- discord/emoji.py | 8 +- discord/enums.py | 28 +- discord/errors.py | 8 +- discord/ext/commands/_types.py | 4 +- discord/ext/commands/bot.py | 12 +- discord/ext/commands/context.py | 4 +- discord/ext/commands/converter.py | 128 ++----- discord/ext/commands/cooldowns.py | 16 +- discord/ext/commands/core.py | 124 ++----- discord/ext/commands/errors.py | 50 +-- discord/ext/commands/flags.py | 63 +--- discord/ext/commands/help.py | 75 +---- discord/ext/pages/pagination.py | 77 ++--- discord/ext/tasks/__init__.py | 54 +-- discord/file.py | 10 +- discord/flags.py | 12 +- discord/gateway.py | 32 +- discord/guild.py | 257 ++++---------- discord/http.py | 355 +++++--------------- discord/integrations.py | 12 +- discord/interactions.py | 22 +- discord/invite.py | 68 +--- discord/iterators.py | 24 +- discord/member.py | 85 ++--- discord/mentions.py | 8 +- discord/message.py | 118 ++----- discord/object.py | 4 +- discord/opus.py | 29 +- discord/partial_emoji.py | 8 +- discord/permissions.py | 16 +- discord/player.py | 55 +-- discord/raw_models.py | 8 +- discord/reaction.py | 8 +- discord/role.py | 28 +- discord/scheduled_events.py | 32 +- discord/shard.py | 40 +-- discord/sinks/core.py | 4 +- discord/sinks/m4a.py | 16 +- discord/sinks/mka.py | 8 +- discord/sinks/mkv.py | 8 +- discord/sinks/mp3.py | 8 +- discord/sinks/mp4.py | 16 +- discord/sinks/ogg.py | 8 +- discord/sinks/wave.py | 4 +- discord/stage_instance.py | 16 +- discord/state.py | 176 +++------- discord/sticker.py | 20 +- discord/team.py | 8 +- discord/template.py | 24 +- discord/threads.py | 10 +- discord/types/interactions.py | 24 +- discord/types/voice.py | 4 +- discord/ui/button.py | 8 +- discord/ui/input_text.py | 4 +- discord/ui/item.py | 4 +- discord/ui/modal.py | 4 +- discord/ui/view.py | 42 +-- discord/user.py | 12 +- discord/utils.py | 39 +-- discord/voice_client.py | 41 +-- discord/webhook/async_.py | 117 ++----- discord/webhook/sync.py | 77 ++--- discord/welcome_screen.py | 11 +- discord/widget.py | 22 +- examples/app_commands/info.py | 4 +- examples/app_commands/slash_autocomplete.py | 20 +- examples/app_commands/slash_basic.py | 8 +- examples/app_commands/slash_cog_groups.py | 8 +- examples/app_commands/slash_groups.py | 4 +- examples/app_commands/slash_options.py | 4 +- examples/audio_recording.py | 9 +- examples/basic_voice.py | 16 +- examples/converters.py | 20 +- examples/cooldown.py | 4 +- examples/custom_context.py | 4 +- examples/guessing_game.py | 4 +- examples/modal_dialogs.py | 8 +- examples/reaction_roles.py | 16 +- examples/secret.py | 24 +- examples/views/button_roles.py | 4 +- examples/views/confirm.py | 4 +- examples/views/dropdown.py | 16 +- examples/views/ephemeral.py | 8 +- examples/views/paginator.py | 70 +--- examples/views/persistent.py | 8 +- 104 files changed, 923 insertions(+), 2900 deletions(-) diff --git a/discord/__init__.py b/discord/__init__.py index da17997f19..ade2f583ba 100644 --- a/discord/__init__.py +++ b/discord/__init__.py @@ -74,8 +74,6 @@ class VersionInfo(NamedTuple): serial: int -version_info: VersionInfo = VersionInfo( - major=2, minor=0, micro=0, releaselevel="beta", serial=4 -) +version_info: VersionInfo = VersionInfo(major=2, minor=0, micro=0, releaselevel="beta", serial=4) logging.getLogger(__name__).addHandler(logging.NullHandler()) diff --git a/discord/__main__.py b/discord/__main__.py index 12b5b5a1c3..07c00ef91d 100644 --- a/discord/__main__.py +++ b/discord/__main__.py @@ -36,16 +36,10 @@ def show_version() -> None: - entries = [ - "- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format( - sys.version_info - ) - ] + entries = ["- Python v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(sys.version_info)] version_info = discord.version_info - entries.append( - "- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info) - ) + entries.append("- py-cord v{0.major}.{0.minor}.{0.micro}-{0.releaselevel}".format(version_info)) if version_info.releaselevel != "final": pkg = pkg_resources.get_distribution("py-cord") if pkg: @@ -299,9 +293,7 @@ def newcog(parser, args) -> None: def add_newbot_args(subparser: argparse._SubParsersAction) -> None: - parser = subparser.add_parser( - "newbot", help="creates a command bot project quickly" - ) + parser = subparser.add_parser("newbot", help="creates a command bot project quickly") parser.set_defaults(func=newbot) parser.add_argument("name", help="the bot project name") @@ -311,12 +303,8 @@ def add_newbot_args(subparser: argparse._SubParsersAction) -> None: nargs="?", default=Path.cwd(), ) - parser.add_argument( - "--prefix", help="the bot prefix (default: $)", default="$", metavar="" - ) - parser.add_argument( - "--sharded", help="whether to use AutoShardedBot", action="store_true" - ) + parser.add_argument("--prefix", help="the bot prefix (default: $)", default="$", metavar="") + parser.add_argument("--sharded", help="whether to use AutoShardedBot", action="store_true") parser.add_argument( "--no-git", help="do not create a .gitignore file", @@ -347,18 +335,12 @@ def add_newcog_args(subparser: argparse._SubParsersAction) -> None: help="whether to hide all commands in the cog", action="store_true", ) - parser.add_argument( - "--full", help="add all special methods as well", action="store_true" - ) + parser.add_argument("--full", help="add all special methods as well", action="store_true") def parse_args() -> Tuple[argparse.ArgumentParser, argparse.Namespace]: - parser = argparse.ArgumentParser( - prog="discord", description="Tools for helping with Pycord" - ) - parser.add_argument( - "-v", "--version", action="store_true", help="shows the library version" - ) + parser = argparse.ArgumentParser(prog="discord", description="Tools for helping with Pycord") + parser.add_argument("-v", "--version", action="store_true", help="shows the library version") parser.set_defaults(func=core) subparser = parser.add_subparsers(dest="subcommand", title="subcommands") diff --git a/discord/abc.py b/discord/abc.py index 013319d356..2dc2602bd2 100644 --- a/discord/abc.py +++ b/discord/abc.py @@ -95,9 +95,7 @@ from .ui.view import View from .user import ClientUser - PartialMessageableChannel = Union[ - TextChannel, Thread, DMChannel, PartialMessageable - ] + PartialMessageableChannel = Union[TextChannel, Thread, DMChannel, PartialMessageable] MessageableChannel = Union[PartialMessageableChannel, GroupChannel] SnowflakeTime = Union["Snowflake", datetime] @@ -262,9 +260,7 @@ class GuildChannel: if TYPE_CHECKING: - def __init__( - self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any] - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: Dict[str, Any]): ... def __str__(self) -> str: @@ -290,9 +286,7 @@ async def _move( http = self._state.http bucket = self._sorting_bucket - channels: List[GuildChannel] = [ - c for c in self.guild.channels if c._sorting_bucket == bucket - ] + channels: List[GuildChannel] = [c for c in self.guild.channels if c._sorting_bucket == bucket] channels.sort(key=lambda c: c.position) @@ -319,9 +313,7 @@ async def _move( await http.bulk_channel_update(self.guild.id, payload, reason=reason) - async def _edit( - self, options: Dict[str, Any], reason: Optional[str] - ) -> Optional[ChannelPayload]: + async def _edit(self, options: Dict[str, Any], reason: Optional[str]) -> Optional[ChannelPayload]: try: parent = options.pop("category") except KeyError: @@ -357,18 +349,14 @@ async def _edit( if lock_permissions: category = self.guild.get_channel(parent_id) if category: - options["permission_overwrites"] = [ - c._asdict() for c in category._overwrites - ] + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] options["parent_id"] = parent_id elif lock_permissions and self.category_id is not None: # if we're syncing permissions on a pre-existing channel category without changing it # we need to update the permissions to point to the pre-existing category category = self.guild.get_channel(self.category_id) if category: - options["permission_overwrites"] = [ - c._asdict() for c in category._overwrites - ] + options["permission_overwrites"] = [c._asdict() for c in category._overwrites] else: await self._move( position, @@ -382,18 +370,14 @@ async def _edit( perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument( - f"Expected PermissionOverwrite received {perm.__class__.__name__}" - ) + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") allow, deny = perm.pair() payload = { "allow": allow.value, "deny": deny.value, "id": target.id, - "type": _Overwrites.ROLE - if isinstance(target, Role) - else _Overwrites.MEMBER, + "type": _Overwrites.ROLE if isinstance(target, Role) else _Overwrites.MEMBER, } perms.append(payload) @@ -409,9 +393,7 @@ async def _edit( options["type"] = ch_type.value if options: - return await self._state.http.edit_channel( - self.id, reason=reason, **options - ) + return await self._state.http.edit_channel(self.id, reason=reason, **options) def _fill_overwrites(self, data: GuildChannelPayload) -> None: self._overwrites = [] @@ -617,9 +599,7 @@ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite( - allow=maybe_everyone.allow, deny=maybe_everyone.deny - ) + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) except IndexError: pass @@ -650,9 +630,7 @@ def permissions_for(self, obj: Union[Member, Role], /) -> Permissions: try: maybe_everyone = self._overwrites[0] if maybe_everyone.id == self.guild.id: - base.handle_overwrite( - allow=maybe_everyone.allow, deny=maybe_everyone.deny - ) + base.handle_overwrite(allow=maybe_everyone.allow, deny=maybe_everyone.deny) remaining_overwrites = self._overwrites[1:] else: remaining_overwrites = self._overwrites @@ -735,9 +713,7 @@ async def set_permissions( ) -> None: ... - async def set_permissions( - self, target, *, overwrite=_undefined, reason=None, **permissions - ): + async def set_permissions(self, target, *, overwrite=_undefined, reason=None, **permissions): r"""|coro| Sets the channel specific permission overwrites for a target in the @@ -831,9 +807,7 @@ async def set_permissions( await http.delete_channel_permissions(self.id, target.id, reason=reason) elif isinstance(overwrite, PermissionOverwrite): (allow, deny) = overwrite.pair() - await http.edit_channel_permissions( - self.id, target.id, allow.value, deny.value, perm_type, reason=reason - ) + await http.edit_channel_permissions(self.id, target.id, allow.value, deny.value, perm_type, reason=reason) else: raise InvalidArgument("Invalid overwrite type provided.") @@ -849,18 +823,14 @@ async def _clone_impl( base_attrs["name"] = name or self.name guild_id = self.guild.id cls = self.__class__ - data = await self._state.http.create_channel( - guild_id, self.type.value, reason=reason, **base_attrs - ) + data = await self._state.http.create_channel(guild_id, self.type.value, reason=reason, **base_attrs) obj = cls(state=self._state, guild=self.guild, data=data) # temporarily add it to the cache self.guild._channels[obj.id] = obj # type: ignore return obj - async def clone( - self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> GCH: + async def clone(self: GCH, *, name: Optional[str] = None, reason: Optional[str] = None) -> GCH: """|coro| Clones this channel. This creates a channel with the same properties @@ -1007,9 +977,7 @@ async def move(self, **kwargs) -> None: before, after = kwargs.get("before"), kwargs.get("after") offset = kwargs.get("offset", 0) if sum(bool(a) for a in (beginning, end, before, after)) > 1: - raise InvalidArgument( - "Only one of [before, after, end, beginning] can be used." - ) + raise InvalidArgument("Only one of [before, after, end, beginning] can be used.") bucket = self._sorting_bucket parent_id = kwargs.get("category", MISSING) @@ -1017,15 +985,11 @@ async def move(self, **kwargs) -> None: if parent_id not in (MISSING, None): parent_id = parent_id.id channels = [ - ch - for ch in self.guild.channels - if ch._sorting_bucket == bucket and ch.category_id == parent_id + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == parent_id ] else: channels = [ - ch - for ch in self.guild.channels - if ch._sorting_bucket == bucket and ch.category_id == self.category_id + ch for ch in self.guild.channels if ch._sorting_bucket == bucket and ch.category_id == self.category_id ] channels.sort(key=lambda c: (c.position, c.id)) @@ -1045,9 +1009,7 @@ async def move(self, **kwargs) -> None: elif before: index = next((i for i, c in enumerate(channels) if c.id == before.id), None) elif after: - index = next( - (i + 1 for i, c in enumerate(channels) if c.id == after.id), None - ) + index = next((i + 1 for i, c in enumerate(channels) if c.id == after.id), None) if index is None: raise InvalidArgument("Could not resolve appropriate move position") @@ -1062,9 +1024,7 @@ async def move(self, **kwargs) -> None: d.update(parent_id=parent_id, lock_permissions=lock_permissions) payload.append(d) - await self._state.http.bulk_channel_update( - self.guild.id, payload, reason=reason - ) + await self._state.http.bulk_channel_update(self.guild.id, payload, reason=reason) async def create_invite( self, @@ -1180,10 +1140,7 @@ async def invites(self) -> List[Invite]: state = self._state data = await state.http.invites_from_channel(self.id) guild = self.guild - return [ - Invite(state=state, data=invite, channel=self, guild=guild) - for invite in data - ] + return [Invite(state=state, data=invite, channel=self, guild=guild) for invite in data] class Messageable: @@ -1390,27 +1347,21 @@ async def send( content = str(content) if content is not None else None if embed is not None and embeds is not None: - raise InvalidArgument( - "cannot pass both embed and embeds parameter to send()" - ) + raise InvalidArgument("cannot pass both embed and embeds parameter to send()") if embed is not None: embed = embed.to_dict() elif embeds is not None: if len(embeds) > 10: - raise InvalidArgument( - "embeds parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("embeds parameter must be a list of up to 10 elements") embeds = [embed.to_dict() for embed in embeds] if stickers is not None: stickers = [sticker.id for sticker in stickers] if allowed_mentions is None: - allowed_mentions = ( - state.allowed_mentions and state.allowed_mentions.to_dict() - ) + allowed_mentions = state.allowed_mentions and state.allowed_mentions.to_dict() elif state.allowed_mentions is not None: allowed_mentions = state.allowed_mentions.merge(allowed_mentions).to_dict() @@ -1430,9 +1381,7 @@ async def send( if view: if not hasattr(view, "__discord_ui_view__"): - raise InvalidArgument( - f"view parameter must be View not {view.__class__!r}" - ) + raise InvalidArgument(f"view parameter must be View not {view.__class__!r}") components = view.to_components() else: @@ -1464,9 +1413,7 @@ async def send( elif files is not None: if len(files) > 10: - raise InvalidArgument( - "files parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") @@ -1635,15 +1582,10 @@ def can_send(self, *objects) -> bool: if obj is None: permission = mapping["Message"] else: - permission = ( - mapping.get(type(obj).__name__) or mapping[obj.__name__] - ) + permission = mapping.get(type(obj).__name__) or mapping[obj.__name__] if type(obj).__name__ == "Emoji": - if ( - obj._to_partial().is_unicode_emoji - or obj.guild_id == channel.guild.id - ): + if obj._to_partial().is_unicode_emoji or obj.guild_id == channel.guild.id: continue elif type(obj).__name__ == "GuildSticker": if obj.guild_id == channel.guild.id: diff --git a/discord/activity.py b/discord/activity.py index c7e1b55bb8..8202d468ae 100644 --- a/discord/activity.py +++ b/discord/activity.py @@ -132,9 +132,7 @@ def created_at(self) -> Optional[datetime.datetime]: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp( - self._created_at / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) def to_dict(self) -> ActivityPayload: raise NotImplementedError @@ -236,15 +234,11 @@ def __init__(self, **kwargs): activity_type = kwargs.pop("type", -1) self.type: ActivityType = ( - activity_type - if isinstance(activity_type, ActivityType) - else try_enum(ActivityType, activity_type) + activity_type if isinstance(activity_type, ActivityType) else try_enum(ActivityType, activity_type) ) emoji = kwargs.pop("emoji", None) - self.emoji: Optional[PartialEmoji] = ( - PartialEmoji.from_dict(emoji) if emoji is not None else None - ) + self.emoji: Optional[PartialEmoji] = PartialEmoji.from_dict(emoji) if emoji is not None else None def __repr__(self) -> str: attrs = ( @@ -393,18 +387,14 @@ def type(self) -> ActivityType: def start(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user started playing this game in UTC, if applicable.""" if self._start: - return datetime.datetime.fromtimestamp( - self._start / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._start / 1000, tz=datetime.timezone.utc) return None @property def end(self) -> Optional[datetime.datetime]: """Optional[:class:`datetime.datetime`]: When the user will stop playing this game in UTC, if applicable.""" if self._end: - return datetime.datetime.fromtimestamp( - self._end / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._end / 1000, tz=datetime.timezone.utc) return None def __str__(self) -> str: @@ -534,11 +524,7 @@ def to_dict(self) -> Dict[str, Any]: return ret def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, Streaming) - and other.name == self.name - and other.url == self.url - ) + return isinstance(other, Streaming) and other.name == self.name and other.url == self.url def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -606,9 +592,7 @@ def created_at(self) -> Optional[datetime.datetime]: .. versionadded:: 1.3 """ if self._created_at is not None: - return datetime.datetime.fromtimestamp( - self._created_at / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._created_at / 1000, tz=datetime.timezone.utc) @property def colour(self) -> Colour: @@ -711,16 +695,12 @@ def track_url(self) -> str: @property def start(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user started playing this song in UTC.""" - return datetime.datetime.fromtimestamp( - self._timestamps["start"] / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._timestamps["start"] / 1000, tz=datetime.timezone.utc) @property def end(self) -> datetime.datetime: """:class:`datetime.datetime`: When the user will stop playing this song in UTC.""" - return datetime.datetime.fromtimestamp( - self._timestamps["end"] / 1000, tz=datetime.timezone.utc - ) + return datetime.datetime.fromtimestamp(self._timestamps["end"] / 1000, tz=datetime.timezone.utc) @property def duration(self) -> datetime.timedelta: @@ -766,9 +746,7 @@ class CustomActivity(BaseActivity): __slots__ = ("name", "emoji", "state") - def __init__( - self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any - ): + def __init__(self, name: Optional[str], *, emoji: Optional[PartialEmoji] = None, **extra: Any): super().__init__(**extra) self.name: Optional[str] = name self.state: Optional[str] = extra.pop("state", None) @@ -785,9 +763,7 @@ def __init__( elif isinstance(emoji, PartialEmoji): self.emoji = emoji else: - raise TypeError( - f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead." - ) + raise TypeError(f"Expected str, PartialEmoji, or None, received {type(emoji)!r} instead.") @property def type(self) -> ActivityType: @@ -815,11 +791,7 @@ def to_dict(self) -> Dict[str, Any]: return o def __eq__(self, other: Any) -> bool: - return ( - isinstance(other, CustomActivity) - and other.name == self.name - and other.emoji == self.emoji - ) + return isinstance(other, CustomActivity) and other.name == self.name and other.emoji == self.emoji def __ne__(self, other: Any) -> bool: return not self.__eq__(other) @@ -873,10 +845,6 @@ def create_activity(data: Optional[ActivityPayload]) -> Optional[ActivityTypes]: # the url won't be None here return Streaming(**data) # type: ignore return Activity(**data) - elif ( - game_type is ActivityType.listening - and "sync_id" in data - and "session_id" in data - ): + elif game_type is ActivityType.listening and "sync_id" in data and "session_id" in data: return Spotify(**data) return Activity(**data) diff --git a/discord/appinfo.py b/discord/appinfo.py index 0c5e79e248..11565deec7 100644 --- a/discord/appinfo.py +++ b/discord/appinfo.py @@ -155,9 +155,7 @@ def __init__(self, state: ConnectionState, data: AppInfoPayload): self.guild_id: Optional[int] = utils._get_as_snowflake(data, "guild_id") - self.primary_sku_id: Optional[int] = utils._get_as_snowflake( - data, "primary_sku_id" - ) + self.primary_sku_id: Optional[int] = utils._get_as_snowflake(data, "primary_sku_id") self.slug: Optional[str] = data.get("slug") self._cover_image: Optional[str] = data.get("cover_image") self.terms_of_service_url: Optional[str] = data.get("terms_of_service_url") diff --git a/discord/asset.py b/discord/asset.py index 4132f8f123..829340361a 100644 --- a/discord/asset.py +++ b/discord/asset.py @@ -183,9 +183,7 @@ def _from_avatar(cls, state, user_id: int, avatar: str) -> Asset: ) @classmethod - def _from_guild_avatar( - cls, state, guild_id: int, member_id: int, avatar: str - ) -> Asset: + def _from_guild_avatar(cls, state, guild_id: int, member_id: int, avatar: str) -> Asset: animated = avatar.startswith("a_") format = "gif" if animated else "png" return cls( @@ -260,9 +258,7 @@ def _from_user_banner(cls, state, user_id: int, banner_hash: str) -> Asset: ) @classmethod - def _from_scheduled_event_cover( - cls, state, event_id: int, cover_hash: str - ) -> Asset: + def _from_scheduled_event_cover(cls, state, event_id: int, cover_hash: str) -> Asset: return cls( state, url=f"{cls.BASE}/guild-events/{event_id}/{cover_hash}.png", @@ -336,22 +332,16 @@ def replace( if format is not MISSING: if self._animated: if format not in VALID_ASSET_FORMATS: - raise InvalidArgument( - f"format must be one of {VALID_ASSET_FORMATS}" - ) + raise InvalidArgument(f"format must be one of {VALID_ASSET_FORMATS}") url = url.with_path(f"{path}.{format}") elif static_format is MISSING: if format not in VALID_STATIC_FORMATS: - raise InvalidArgument( - f"format must be one of {VALID_STATIC_FORMATS}" - ) + raise InvalidArgument(f"format must be one of {VALID_STATIC_FORMATS}") url = url.with_path(f"{path}.{format}") if static_format is not MISSING and not self._animated: if static_format not in VALID_STATIC_FORMATS: - raise InvalidArgument( - f"static_format must be one of {VALID_STATIC_FORMATS}" - ) + raise InvalidArgument(f"static_format must be one of {VALID_STATIC_FORMATS}") url = url.with_path(f"{path}.{static_format}") if size is not MISSING: diff --git a/discord/audit_logs.py b/discord/audit_logs.py index 323257c635..b50f496af1 100644 --- a/discord/audit_logs.py +++ b/discord/audit_logs.py @@ -88,25 +88,19 @@ def _transform_snowflake(entry: AuditLogEntry, data: Snowflake) -> int: return int(data) -def _transform_channel( - entry: AuditLogEntry, data: Optional[Snowflake] -) -> Optional[Union[abc.GuildChannel, Object]]: +def _transform_channel(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Union[abc.GuildChannel, Object]]: if data is None: return None return entry.guild.get_channel(int(data)) or Object(id=data) -def _transform_member_id( - entry: AuditLogEntry, data: Optional[Snowflake] -) -> Union[Member, User, None]: +def _transform_member_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Union[Member, User, None]: if data is None: return None return entry._get_member(int(data)) -def _transform_guild_id( - entry: AuditLogEntry, data: Optional[Snowflake] -) -> Optional[Guild]: +def _transform_guild_id(entry: AuditLogEntry, data: Optional[Snowflake]) -> Optional[Guild]: if data is None: return None return entry._state._get_guild(data) @@ -170,9 +164,7 @@ def _transform(entry: AuditLogEntry, data: int) -> T: return _transform -def _transform_type( - entry: AuditLogEntry, data: int -) -> Union[enums.ChannelType, enums.StickerType]: +def _transform_type(entry: AuditLogEntry, data: int) -> Union[enums.ChannelType, enums.StickerType]: if entry.action.name.startswith("sticker_"): return enums.try_enum(enums.StickerType, data) else: @@ -290,15 +282,10 @@ def __init__( if attr == "location" and hasattr(self.before, "location_type"): from .scheduled_events import ScheduledEventLocation - if ( - self.before.location_type - is enums.ScheduledEventLocationType.external - ): + if self.before.location_type is enums.ScheduledEventLocationType.external: before = ScheduledEventLocation(state=state, value=before) elif hasattr(self.before, "channel"): - before = ScheduledEventLocation( - state=state, value=self.before.channel - ) + before = ScheduledEventLocation(state=state, value=self.before.channel) setattr(self.before, attr, before) @@ -313,15 +300,10 @@ def __init__( if attr == "location" and hasattr(self.after, "location_type"): from .scheduled_events import ScheduledEventLocation - if ( - self.after.location_type - is enums.ScheduledEventLocationType.external - ): + if self.after.location_type is enums.ScheduledEventLocationType.external: after = ScheduledEventLocation(state=state, value=after) elif hasattr(self.after, "channel"): - after = ScheduledEventLocation( - state=state, value=self.after.channel - ) + after = ScheduledEventLocation(state=state, value=self.after.channel) setattr(self.after, attr, after) @@ -428,9 +410,7 @@ class AuditLogEntry(Hashable): which actions have this field filled out. """ - def __init__( - self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild - ): + def __init__(self, *, users: Dict[int, User], data: AuditLogEntryPayload, guild: Guild): self._state = guild._state self.guild = guild self._users = users @@ -450,38 +430,27 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: self.extra: _AuditLogProxyMemberPrune = type( "_AuditLogProxy", (), {k: int(v) for k, v in self.extra.items()} )() - elif ( - self.action is enums.AuditLogAction.member_move - or self.action is enums.AuditLogAction.message_delete - ): + elif self.action is enums.AuditLogAction.member_move or self.action is enums.AuditLogAction.message_delete: channel_id = int(self.extra["channel_id"]) elems = { "count": int(self.extra["count"]), - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), } - self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyMemberMoveOrMessageDelete = type("_AuditLogProxy", (), elems)() elif self.action is enums.AuditLogAction.member_disconnect: # The member disconnect action has a dict with some information elems = { "count": int(self.extra["count"]), } - self.extra: _AuditLogProxyMemberDisconnect = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyMemberDisconnect = type("_AuditLogProxy", (), elems)() elif self.action.name.endswith("pin"): # the pin actions have a dict with some information channel_id = int(self.extra["channel_id"]) elems = { - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id), + "channel": self.guild.get_channel(channel_id) or Object(id=channel_id), "message_id": int(self.extra["message_id"]), } - self.extra: _AuditLogProxyPinAction = type( - "_AuditLogProxy", (), elems - )() + self.extra: _AuditLogProxyPinAction = type("_AuditLogProxy", (), elems)() elif self.action.name.startswith("overwrite_"): # the overwrite_ actions have a dict with some information instance_id = int(self.extra["id"]) @@ -496,13 +465,8 @@ def _from_data(self, data: AuditLogEntryPayload) -> None: self.extra: Role = role elif self.action.name.startswith("stage_instance"): channel_id = int(self.extra["channel_id"]) - elems = { - "channel": self.guild.get_channel(channel_id) - or Object(id=channel_id) - } - self.extra: _AuditLogProxyStageInstanceAction = type( - "_AuditLogProxy", (), elems - )() + elems = {"channel": self.guild.get_channel(channel_id) or Object(id=channel_id)} + self.extra: _AuditLogProxyStageInstanceAction = type("_AuditLogProxy", (), elems)() self.extra: Union[ _AuditLogProxyMemberPrune, @@ -586,9 +550,7 @@ def after(self) -> AuditLogDiff: def _convert_target_guild(self, target_id: int) -> Guild: return self.guild - def _convert_target_channel( - self, target_id: int - ) -> Union[abc.GuildChannel, Object]: + def _convert_target_channel(self, target_id: int) -> Union[abc.GuildChannel, Object]: return self.guild.get_channel(target_id) or Object(id=target_id) def _convert_target_user(self, target_id: int) -> Union[Member, User, None]: @@ -600,11 +562,7 @@ def _convert_target_role(self, target_id: int) -> Union[Role, Object]: def _convert_target_invite(self, target_id: int) -> Invite: # invites have target_id set to null # so figure out which change has the full invite data - changeset = ( - self.before - if self.action is enums.AuditLogAction.invite_delete - else self.after - ) + changeset = self.before if self.action is enums.AuditLogAction.invite_delete else self.after fake_payload = { "max_age": changeset.max_age, @@ -627,9 +585,7 @@ def _convert_target_emoji(self, target_id: int) -> Union[Emoji, Object]: def _convert_target_message(self, target_id: int) -> Union[Member, User, None]: return self._get_member(target_id) - def _convert_target_stage_instance( - self, target_id: int - ) -> Union[StageInstance, Object]: + def _convert_target_stage_instance(self, target_id: int) -> Union[StageInstance, Object]: return self.guild.get_stage_instance(target_id) or Object(id=target_id) def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object]: @@ -638,7 +594,5 @@ def _convert_target_sticker(self, target_id: int) -> Union[GuildSticker, Object] def _convert_target_thread(self, target_id: int) -> Union[Thread, Object]: return self.guild.get_thread(target_id) or Object(id=target_id) - def _convert_target_scheduled_event( - self, target_id: int - ) -> Union[ScheduledEvent, None]: + def _convert_target_scheduled_event(self, target_id: int) -> Union[ScheduledEvent, None]: return self.guild.get_scheduled_event(target_id) or Object(id=target_id) diff --git a/discord/bot.py b/discord/bot.py index ae1e712dab..12d34ea140 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -138,9 +138,7 @@ def add_application_command(self, command: ApplicationCommand) -> None: break self._pending_application_commands.append(command) - def remove_application_command( - self, command: ApplicationCommand - ) -> Optional[ApplicationCommand]: + def remove_application_command(self, command: ApplicationCommand) -> Optional[ApplicationCommand]: """Remove a :class:`.ApplicationCommand` from the internal list of commands. @@ -210,9 +208,7 @@ def get_application_command( return return command - async def get_desynced_commands( - self, guild_id: Optional[int] = None - ) -> List[Dict[str, Any]]: + async def get_desynced_commands(self, guild_id: Optional[int] = None) -> List[Dict[str, Any]]: """|coro| Gets the list of commands that are desynced from discord. If ``guild_id`` is specified, it will only return @@ -246,14 +242,8 @@ async def get_desynced_commands( registered_commands = await self.http.get_global_commands(self.user.id) pending = [cmd for cmd in cmds if cmd.guild_ids is None] else: - registered_commands = await self.http.get_guild_commands( - self.user.id, guild_id - ) - pending = [ - cmd - for cmd in cmds - if cmd.guild_ids is not None and guild_id in cmd.guild_ids - ] + registered_commands = await self.http.get_guild_commands(self.user.id, guild_id) + pending = [cmd for cmd in cmds if cmd.guild_ids is not None and guild_id in cmd.guild_ids] registered_commands_dict = {cmd["name"]: cmd for cmd in registered_commands} to_check = { @@ -279,11 +269,7 @@ async def get_desynced_commands( falsy_vals = (False, []) for opt in value: - cmd_vals = ( - [val.get(opt, MISSING) for val in as_dict[check]] - if check in as_dict - else [] - ) + cmd_vals = [val.get(opt, MISSING) for val in as_dict[check]] if check in as_dict else [] for i, val in enumerate(cmd_vals): if val in falsy_vals: cmd_vals[i] = MISSING @@ -422,9 +408,7 @@ def register(method: str, *args, **kwargs): } def register(method: str, *args, **kwargs): - return registration_methods[method]( - self.user.id, guild_id, *args, **kwargs - ) + return registration_methods[method](self.user.id, guild_id, *args, **kwargs) pending_actions = [] @@ -463,9 +447,7 @@ def register(method: str, *args, **kwargs): else: raise ValueError(f"Unknown action: {cmd['action']}") - filtered_deleted = list( - filter(lambda a: a["action"] != "delete", pending_actions) - ) + filtered_deleted = list(filter(lambda a: a["action"] != "delete", pending_actions)) if len(filtered_deleted) == len(pending): # It appears that all the commands need to be modified, so we can just do a bulk upsert data = [cmd["command"].to_dict() for cmd in filtered_deleted] @@ -478,13 +460,9 @@ def register(method: str, *args, **kwargs): await register("delete", cmd["command"]) continue if cmd["action"] == "edit": - registered.append( - await register("edit", cmd["id"], cmd["command"].to_dict()) - ) + registered.append(await register("edit", cmd["id"], cmd["command"].to_dict())) elif cmd["action"] == "upsert": - registered.append( - await register("upsert", cmd["command"].to_dict()) - ) + registered.append(await register("upsert", cmd["command"].to_dict())) else: raise ValueError(f"Unknown action: {cmd['action']}") else: @@ -504,9 +482,7 @@ def register(method: str, *args, **kwargs): type=i["type"], ) if not cmd: - raise ValueError( - f"Registered command {i['name']}, type {i['type']} not found in pending commands" - ) + raise ValueError(f"Registered command {i['name']}, type {i['type']} not found in pending commands") cmd.id = i["id"] self._application_commands[cmd.id] = cmd @@ -577,11 +553,7 @@ async def sync_commands( if unregister_guilds is not None: cmd_guild_ids.extend(unregister_guilds) for guild_id in set(cmd_guild_ids): - guild_commands = [ - cmd - for cmd in commands - if cmd.guild_ids is not None and guild_id in cmd.guild_ids - ] + guild_commands = [cmd for cmd in commands if cmd.guild_ids is not None and guild_id in cmd.guild_ids] registered_guild_commands[guild_id] = await self.register_commands( guild_commands, guild_id=guild_id, force=force ) @@ -602,9 +574,7 @@ async def sync_commands( self._application_commands[cmd.id] = cmd # Permissions (Roles will be converted to IDs just before Upsert for Global Commands) - global_permissions.append( - {"id": i["id"], "permissions": cmd.permissions} - ) + global_permissions.append({"id": i["id"], "permissions": cmd.permissions}) for guild_id, commands in registered_guild_commands.items(): guild_permissions: List = [] @@ -628,11 +598,7 @@ async def sync_commands( perm.to_dict() for perm in cmd.permissions if perm.guild_id is None - or ( - perm.guild_id == guild_id - and cmd.guild_ids is not None - and perm.guild_id in cmd.guild_ids - ) + or (perm.guild_id == guild_id and cmd.guild_ids is not None and perm.guild_id in cmd.guild_ids) ] guild_permissions.append({"id": i["id"], "permissions": permissions}) @@ -641,15 +607,9 @@ async def sync_commands( perm.to_dict() for perm in global_command["permissions"] if perm.guild_id is None - or ( - perm.guild_id == guild_id - and cmd.guild_ids is not None - and perm.guild_id in cmd.guild_ids - ) + or (perm.guild_id == guild_id and cmd.guild_ids is not None and perm.guild_id in cmd.guild_ids) ] - guild_permissions.append( - {"id": global_command["id"], "permissions": permissions} - ) + guild_permissions.append({"id": global_command["id"], "permissions": permissions}) # Collect & Upsert Permissions for Each Guild # Command Permissions for this Guild @@ -721,18 +681,14 @@ async def sync_commands( # Upsert try: - await self.http.bulk_upsert_command_permissions( - self.user.id, guild_id, guild_cmd_perms - ) + await self.http.bulk_upsert_command_permissions(self.user.id, guild_id, guild_cmd_perms) except Forbidden: raise RuntimeError( f"Failed to add command permissions to guild {guild_id}", file=sys.stderr, ) - async def process_application_commands( - self, interaction: Interaction, auto_sync: bool = None - ) -> None: + async def process_application_commands(self, interaction: Interaction, auto_sync: bool = None) -> None: """|coro| This function processes the commands that have been registered @@ -772,10 +728,7 @@ async def process_application_commands( for cmd in self.application_commands: if cmd.name == interaction.data["name"] and ( interaction.data.get("guild_id") == cmd.guild_ids - or ( - isinstance(cmd.guild_ids, list) - and interaction.data.get("guild_id") in cmd.guild_ids - ) + or (isinstance(cmd.guild_ids, list) and interaction.data.get("guild_id") in cmd.guild_ids) ): command = cmd break @@ -976,9 +929,7 @@ def walk_application_commands(self) -> Generator[ApplicationCommand, None, None] yield from command.walk_commands() yield command - async def get_application_context( - self, interaction: Interaction, cls=None - ) -> ApplicationContext: + async def get_application_context(self, interaction: Interaction, cls=None) -> ApplicationContext: r"""|coro| Returns the invocation context from the interaction. @@ -1006,9 +957,7 @@ class be provided, it must be similar enough to cls = ApplicationContext return cls(self, interaction) - async def get_autocomplete_context( - self, interaction: Interaction, cls=None - ) -> AutocompleteContext: + async def get_autocomplete_context(self, interaction: Interaction, cls=None) -> AutocompleteContext: r"""|coro| Returns the autocomplete context from the interaction. @@ -1062,12 +1011,8 @@ def __init__(self, description=None, *args, **options): if self.owner_id and self.owner_ids: raise TypeError("Both owner_id and owner_ids are set.") - if self.owner_ids and not isinstance( - self.owner_ids, collections.abc.Collection - ): - raise TypeError( - f"owner_ids must be a collection not {self.owner_ids.__class__!r}" - ) + if self.owner_ids and not isinstance(self.owner_ids, collections.abc.Collection): + raise TypeError(f"owner_ids must be a collection not {self.owner_ids.__class__!r}") self._checks = [] self._check_once = [] @@ -1081,9 +1026,7 @@ async def on_connect(self): async def on_interaction(self, interaction): await self.process_application_commands(interaction) - async def on_application_command_error( - self, context: ApplicationContext, exception: DiscordException - ) -> None: + async def on_application_command_error(self, context: ApplicationContext, exception: DiscordException) -> None: """|coro| The default command error handler provided by the bot. @@ -1105,9 +1048,7 @@ async def on_application_command_error( return print(f"Ignoring exception in command {context.command}:", file=sys.stderr) - traceback.print_exception( - type(exception), exception, exception.__traceback__, file=sys.stderr - ) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) # global check registration # TODO: Remove these from commands.Bot @@ -1204,9 +1145,7 @@ def whitelist(ctx): self.add_check(func, call_once=True) return func - async def can_run( - self, ctx: ApplicationContext, *, call_once: bool = False - ) -> bool: + async def can_run(self, ctx: ApplicationContext, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks if len(data) == 0: diff --git a/discord/channel.py b/discord/channel.py index a3fe0ea58b..095bf528be 100644 --- a/discord/channel.py +++ b/discord/channel.py @@ -176,9 +176,7 @@ class TextChannel(discord.abc.Messageable, discord.abc.GuildChannel, Hashable): "default_auto_archive_duration", ) - def __init__( - self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: TextChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self._type: int = data["type"] @@ -205,13 +203,9 @@ def _update(self, guild: Guild, data: TextChannelPayload) -> None: self.nsfw: bool = data.get("nsfw", False) # Does this need coercion into `int`? No idea yet. self.slowmode_delay: int = data.get("rate_limit_per_user", 0) - self.default_auto_archive_duration: ThreadArchiveDuration = data.get( - "default_auto_archive_duration", 1440 - ) + self.default_auto_archive_duration: ThreadArchiveDuration = data.get("default_auto_archive_duration", 1440) self._type: int = data.get("type", self._type) - self.last_message_id: Optional[int] = utils._get_as_snowflake( - data, "last_message_id" - ) + self.last_message_id: Optional[int] = utils._get_as_snowflake(data, "last_message_id") self._fill_overwrites(data) async def _get_channel(self): @@ -246,11 +240,7 @@ def threads(self) -> List[Thread]: .. versionadded:: 2.0 """ - return [ - thread - for thread in self.guild._threads.values() - if thread.parent_id == self.id - ] + return [thread for thread in self.guild._threads.values() if thread.parent_id == self.id] def is_nsfw(self) -> bool: """:class:`bool`: Checks if the channel is NSFW.""" @@ -279,11 +269,7 @@ def last_message(self) -> Optional[Message]: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return ( - self._state._get_message(self.last_message_id) - if self.last_message_id - else None - ) + return self._state._get_message(self.last_message_id) if self.last_message_id else None @overload async def edit( @@ -379,9 +365,7 @@ async def edit(self, *, reason=None, **options): return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> TextChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> TextChannel: return await self._clone_impl( { "topic": self.topic, @@ -521,9 +505,7 @@ def is_me(m): ret: List[Message] = [] count = 0 - minimum_time = ( - int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 - ) + minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 strategy = self.delete_messages if bulk else _single_delete_strategy async for message in iterator: @@ -624,14 +606,10 @@ async def create_webhook( if avatar is not None: avatar = utils._bytes_to_base64_data(avatar) # type: ignore - data = await self._state.http.create_webhook( - self.id, name=str(name), avatar=avatar, reason=reason - ) + data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason) return Webhook.from_state(data, state=self._state) - async def follow( - self, *, destination: TextChannel, reason: Optional[str] = None - ) -> Webhook: + async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook: """ Follows a channel using a webhook. @@ -670,15 +648,11 @@ async def follow( raise ClientException("The channel must be a news channel.") if not isinstance(destination, TextChannel): - raise InvalidArgument( - f"Expected TextChannel received {destination.__class__.__name__}" - ) + raise InvalidArgument(f"Expected TextChannel received {destination.__class__.__name__}") from .webhook import Webhook - data = await self._state.http.follow_webhook( - self.id, webhook_channel_id=destination.id, reason=reason - ) + data = await self._state.http.follow_webhook(self.id, webhook_channel_id=destination.id, reason=reason) return Webhook._as_follower(data, channel=destination, user=self._state.user) def get_partial_message(self, message_id: int, /) -> PartialMessage: @@ -777,8 +751,7 @@ async def create_thread( data = await self._state.http.start_thread_without_message( self.id, name=name, - auto_archive_duration=auto_archive_duration - or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, type=type.value, reason=reason, ) @@ -787,8 +760,7 @@ async def create_thread( self.id, message.id, name=name, - auto_archive_duration=auto_archive_duration - or self.default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or self.default_auto_archive_duration, reason=reason, ) @@ -877,18 +849,12 @@ def _get_voice_client_key(self) -> Tuple[int, str]: def _get_voice_state_pair(self) -> Tuple[int, int]: return self.guild.id, self.id - def _update( - self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload] - ) -> None: + def _update(self, guild: Guild, data: Union[VoiceChannelPayload, StageChannelPayload]) -> None: self.guild = guild self.name: str = data["name"] rtc = data.get("rtc_region") - self.rtc_region: Optional[VoiceRegion] = ( - try_enum(VoiceRegion, rtc) if rtc is not None else None - ) - self.video_quality_mode: VideoQualityMode = try_enum( - VideoQualityMode, data.get("video_quality_mode", 1) - ) + self.rtc_region: Optional[VoiceRegion] = try_enum(VoiceRegion, rtc) if rtc is not None else None + self.video_quality_mode: VideoQualityMode = try_enum(VideoQualityMode, data.get("video_quality_mode", 1)) self.category_id: Optional[int] = utils._get_as_snowflake(data, "parent_id") self.position: int = data["position"] self.bitrate: int = data.get("bitrate") @@ -1016,9 +982,7 @@ def type(self) -> ChannelType: return ChannelType.voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> VoiceChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> VoiceChannel: return await self._clone_impl( {"bitrate": self.bitrate, "user_limit": self.user_limit}, name=name, @@ -1112,9 +1076,7 @@ async def edit(self, *, reason=None, **options): # the payload will always be the proper channel payload return self.__class__(state=self._state, guild=self.guild, data=payload) # type: ignore - async def create_activity_invite( - self, activity: Union[EmbeddedActivity, int], **kwargs - ) -> Invite: + async def create_activity_invite(self, activity: Union[EmbeddedActivity, int], **kwargs) -> Invite: """|coro| A shortcut method that creates an instant activity invite. @@ -1244,11 +1206,7 @@ def _update(self, guild: Guild, data: StageChannelPayload) -> None: @property def requesting_to_speak(self) -> List[Member]: """List[:class:`Member`]: A list of members who are requesting to speak in the stage channel.""" - return [ - member - for member in self.members - if member.voice and member.voice.requested_to_speak_at is not None - ] + return [member for member in self.members if member.voice and member.voice.requested_to_speak_at is not None] @property def speakers(self) -> List[Member]: @@ -1259,9 +1217,7 @@ def speakers(self) -> List[Member]: return [ member for member in self.members - if member.voice - and not member.voice.suppress - and member.voice.requested_to_speak_at is None + if member.voice and not member.voice.suppress and member.voice.requested_to_speak_at is None ] @property @@ -1270,9 +1226,7 @@ def listeners(self) -> List[Member]: .. versionadded:: 2.0 """ - return [ - member for member in self.members if member.voice and member.voice.suppress - ] + return [member for member in self.members if member.voice and member.voice.suppress] @property def moderators(self) -> List[Member]: @@ -1281,11 +1235,7 @@ def moderators(self) -> List[Member]: .. versionadded:: 2.0 """ required_permissions = Permissions.stage_moderator() - return [ - member - for member in self.members - if self.permissions_for(member) >= required_permissions - ] + return [member for member in self.members if self.permissions_for(member) >= required_permissions] @property def type(self) -> ChannelType: @@ -1293,9 +1243,7 @@ def type(self) -> ChannelType: return ChannelType.stage_voice @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> StageChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StageChannel: return await self._clone_impl({}, name=name, reason=reason) @property @@ -1350,9 +1298,7 @@ async def create_instance( if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument( - "privacy_level field must be of type PrivacyLevel" - ) + raise InvalidArgument("privacy_level field must be of type PrivacyLevel") payload["privacy_level"] = privacy_level.value @@ -1515,9 +1461,7 @@ class CategoryChannel(discord.abc.GuildChannel, Hashable): "category_id", ) - def __init__( - self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: CategoryChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self._update(guild, data) @@ -1547,9 +1491,7 @@ def is_nsfw(self) -> bool: return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> CategoryChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> CategoryChannel: return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload @@ -1639,22 +1581,14 @@ def comparator(channel): @property def text_channels(self) -> List[TextChannel]: """List[:class:`TextChannel`]: Returns the text channels that are under this category.""" - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, TextChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, TextChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @property def voice_channels(self) -> List[VoiceChannel]: """List[:class:`VoiceChannel`]: Returns the voice channels that are under this category.""" - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, VoiceChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, VoiceChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @@ -1664,11 +1598,7 @@ def stage_channels(self) -> List[StageChannel]: .. versionadded:: 1.7 """ - ret = [ - c - for c in self.guild.channels - if c.category_id == self.id and isinstance(c, StageChannel) - ] + ret = [c for c in self.guild.channels if c.category_id == self.id and isinstance(c, StageChannel)] ret.sort(key=lambda c: (c.position, c.id)) return ret @@ -1764,9 +1694,7 @@ class StoreChannel(discord.abc.GuildChannel, Hashable): "_overwrites", ) - def __init__( - self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload - ): + def __init__(self, *, state: ConnectionState, guild: Guild, data: StoreChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self._update(guild, data) @@ -1805,9 +1733,7 @@ def is_nsfw(self) -> bool: return self.nsfw @utils.copy_doc(discord.abc.GuildChannel.clone) - async def clone( - self, *, name: Optional[str] = None, reason: Optional[str] = None - ) -> StoreChannel: + async def clone(self, *, name: Optional[str] = None, reason: Optional[str] = None) -> StoreChannel: return await self._clone_impl({"nsfw": self.nsfw}, name=name, reason=reason) @overload @@ -1922,9 +1848,7 @@ class DMChannel(discord.abc.Messageable, Hashable): __slots__ = ("id", "recipient", "me", "_state") - def __init__( - self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload - ): + def __init__(self, *, me: ClientUser, state: ConnectionState, data: DMChannelPayload): self._state: ConnectionState = state self.recipient: Optional[User] = state.store_user(data["recipients"][0]) self.me: ClientUser = me @@ -2065,9 +1989,7 @@ class GroupChannel(discord.abc.Messageable, Hashable): "_state", ) - def __init__( - self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload - ): + def __init__(self, *, me: ClientUser, state: ConnectionState, data: GroupChannelPayload): self._state: ConnectionState = state self.id: int = int(data["id"]) self.me: ClientUser = me @@ -2077,9 +1999,7 @@ def _update_group(self, data: GroupChannelPayload) -> None: self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_id") self._icon: Optional[str] = data.get("icon") self.name: Optional[str] = data.get("name") - self.recipients: List[User] = [ - self._state.store_user(u) for u in data.get("recipients", []) - ] + self.recipients: List[User] = [self._state.store_user(u) for u in data.get("recipients", [])] self.owner: Optional[BaseUser] if self.owner_id == self.me.id: @@ -2203,9 +2123,7 @@ class PartialMessageable(discord.abc.Messageable, Hashable): The channel type associated with this partial messageable, if given. """ - def __init__( - self, state: ConnectionState, id: int, type: Optional[ChannelType] = None - ): + def __init__(self, state: ConnectionState, id: int, type: Optional[ChannelType] = None): self._state: ConnectionState = state self._channel: Object = Object(id=id) self.id: int = id diff --git a/discord/client.py b/discord/client.py index a5cc0e1907..e457a8ea9a 100644 --- a/discord/client.py +++ b/discord/client.py @@ -225,12 +225,8 @@ def __init__( ): # self.ws is set in the connect method self.ws: DiscordWebSocket = None # type: ignore - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) - self._listeners: Dict[ - str, List[Tuple[asyncio.Future, Callable[..., bool]]] - ] = {} + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop + self._listeners: Dict[str, List[Tuple[asyncio.Future, Callable[..., bool]]]] = {} self.shard_id: Optional[int] = options.get("shard_id") self.shard_count: Optional[int] = options.get("shard_count") @@ -248,9 +244,7 @@ def __init__( self._handlers: Dict[str, Callable] = {"ready": self._handle_ready} - self._hooks: Dict[str, Callable] = { - "before_identify": self._call_before_identify_hook - } + self._hooks: Dict[str, Callable] = {"before_identify": self._call_before_identify_hook} self._enable_debug_events: bool = options.pop("enable_debug_events", False) self._connection: ConnectionState = self._get_state(**options) @@ -266,9 +260,7 @@ def __init__( # internals - def _get_websocket( - self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None - ) -> DiscordWebSocket: + def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: return self.ws def _get_state(self, **options: Any) -> ConnectionState: @@ -461,17 +453,13 @@ async def on_error(self, event_method: str, *args: Any, **kwargs: Any) -> None: # hooks - async def _call_before_identify_hook( - self, shard_id: Optional[int], *, initial: bool = False - ) -> None: + async def _call_before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: # This hook is an internal hook that actually calls the public one. # It allows the library to have its own hook without stepping on the # toes of those who need to override their own hook. await self.before_identify_hook(shard_id, initial=initial) - async def before_identify_hook( - self, shard_id: Optional[int], *, initial: bool = False - ) -> None: + async def before_identify_hook(self, shard_id: Optional[int], *, initial: bool = False) -> None: """|coro| A hook that is called before IDENTIFYing a session. This is useful @@ -519,9 +507,7 @@ async def login(self, token: str) -> None: passing status code. """ if not isinstance(token, str): - raise TypeError( - f"token must be of type str, not {token.__class__.__name__}" - ) + raise TypeError(f"token must be of type str, not {token.__class__.__name__}") _log.info("logging in using static token") @@ -621,9 +607,7 @@ async def connect(self, *, reconnect: bool = True) -> None: # Always try to RESUME the connection # If the connection is not RESUME-able then the gateway will invalidate the session. # This is apparently what the official Discord client does. - ws_params.update( - sequence=self.ws.sequence, resume=True, session=self.ws.session_id - ) + ws_params.update(sequence=self.ws.sequence, resume=True, session=self.ws.session_id) async def close(self) -> None: """|coro| @@ -789,9 +773,7 @@ def allowed_mentions(self, value: Optional[AllowedMentions]) -> None: if value is None or isinstance(value, AllowedMentions): self._connection.allowed_mentions = value else: - raise TypeError( - f"allowed_mentions must be AllowedMentions not {value.__class__!r}" - ) + raise TypeError(f"allowed_mentions must be AllowedMentions not {value.__class__!r}") @property def intents(self) -> Intents: @@ -808,9 +790,7 @@ def users(self) -> List[User]: """List[:class:`~discord.User`]: Returns a list of all the users the bot can see.""" return list(self._connection._users.values()) - def get_channel( - self, id: int, / - ) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: + def get_channel(self, id: int, /) -> Optional[Union[GuildChannel, Thread, PrivateChannel]]: """Returns a channel or thread with the given ID. Parameters @@ -825,9 +805,7 @@ def get_channel( """ return self._connection.get_channel(id) - def get_partial_messageable( - self, id: int, *, type: Optional[ChannelType] = None - ) -> PartialMessageable: + def get_partial_messageable(self, id: int, *, type: Optional[ChannelType] = None) -> PartialMessageable: """Returns a partial messageable with the given channel ID. This is useful if you have a channel_id but don't want to do an API call @@ -1369,9 +1347,7 @@ async def create_guild( region_value = str(region) if code: - data = await self.http.create_from_template( - code, name, region_value, icon_base64 - ) + data = await self.http.create_from_template(code, name, region_value, icon_base64) else: data = await self.http.create_guild(name, region_value, icon_base64) return Guild(data=data, state=self._connection) @@ -1576,9 +1552,7 @@ async def fetch_user(self, user_id: int, /) -> User: data = await self.http.get_user(user_id) return User(state=self._connection, data=data) - async def fetch_channel( - self, channel_id: int, / - ) -> Union[GuildChannel, PrivateChannel, Thread]: + async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, PrivateChannel, Thread]: """|coro| Retrieves a :class:`.abc.GuildChannel`, :class:`.abc.PrivateChannel`, or :class:`.Thread` with the specified ID. @@ -1609,9 +1583,7 @@ async def fetch_channel( factory, ch_type = _threaded_channel_factory(data["type"]) if factory is None: - raise InvalidData( - "Unknown channel type {type} for channel ID {id}.".format_map(data) - ) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): # the factory will be a DMChannel or GroupChannel here @@ -1644,9 +1616,7 @@ async def fetch_webhook(self, webhook_id: int, /) -> Webhook: data = await self.http.get_webhook(webhook_id) return Webhook.from_state(data, state=self._connection) - async def fetch_sticker( - self, sticker_id: int, / - ) -> Union[StandardSticker, GuildSticker]: + async def fetch_sticker(self, sticker_id: int, /) -> Union[StandardSticker, GuildSticker]: """|coro| Retrieves a :class:`.Sticker` with the specified ID. @@ -1687,10 +1657,7 @@ async def fetch_premium_sticker_packs(self) -> List[StickerPack]: All available premium sticker packs. """ data = await self.http.list_premium_sticker_packs() - return [ - StickerPack(state=self._connection, data=pack) - for pack in data["sticker_packs"] - ] + return [StickerPack(state=self._connection, data=pack) for pack in data["sticker_packs"]] async def create_dm(self, user: Snowflake) -> DMChannel: """|coro| @@ -1750,9 +1717,7 @@ def add_view(self, view: View, *, message_id: Optional[int] = None) -> None: raise TypeError(f"expected an instance of View not {view.__class__!r}") if not view.is_persistent(): - raise ValueError( - "View is not persistent. Items need to have a custom_id set and View must have no timeout" - ) + raise ValueError("View is not persistent. Items need to have a custom_id set and View must have no timeout") self._connection.store_view(view, message_id) diff --git a/discord/cog.py b/discord/cog.py index 8941770c64..7ee8f276ce 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -150,8 +150,7 @@ def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: new_cls = super().__new__(cls, name, bases, attrs, **kwargs) valid_commands = [ - (c for i, c in j.__dict__.items() if isinstance(c, _BaseCommand)) - for j in reversed(new_cls.__mro__) + (c for i, c in j.__dict__.items() if isinstance(c, _BaseCommand)) for j in reversed(new_cls.__mro__) ] if any(isinstance(i, ApplicationCommand) for i in valid_commands) and any( not isinstance(i, _BaseCommand) for i in valid_commands @@ -168,9 +167,7 @@ def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: del listeners[elem] try: - if getattr(value, "parent") is not None and isinstance( - value, ApplicationCommand - ): + if getattr(value, "parent") is not None and isinstance(value, ApplicationCommand): # Skip commands if they are a part of a group continue except AttributeError: @@ -181,9 +178,7 @@ def __new__(cls: Type[CogMeta], *args: Any, **kwargs: Any) -> CogMeta: value = value.__func__ if isinstance(value, _filter): if is_static_method: - raise TypeError( - f"Command in method {base}.{elem!r} must not be staticmethod." - ) + raise TypeError(f"Command in method {base}.{elem!r} must not be staticmethod.") if elem.startswith(("cog_", "bot_")): raise TypeError(no_bot_cog.format(base, elem)) commands[elem] = value @@ -278,11 +273,7 @@ def get_commands(self) -> List[ApplicationCommand]: This does not include subcommands. """ - return [ - c - for c in self.__cog_commands__ - if isinstance(c, ApplicationCommand) and c.parent is None - ] + return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None] @property def qualified_name(self) -> str: @@ -318,17 +309,12 @@ def get_listeners(self) -> List[Tuple[str, Callable[..., Any]]]: List[Tuple[:class:`str`, :ref:`coroutine `]] The listeners defined in this cog. """ - return [ - (name, getattr(self, method_name)) - for name, method_name in self.__cog_listeners__ - ] + return [(name, getattr(self, method_name)) for name, method_name in self.__cog_listeners__] @classmethod def _get_overridden_method(cls, method: FuncT) -> Optional[FuncT]: """Return None if the method is not overridden. Otherwise returns the overridden method.""" - return getattr( - getattr(method, "__func__", method), "__cog_special_method__", method - ) + return getattr(getattr(method, "__func__", method), "__cog_special_method__", method) @classmethod def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: @@ -350,9 +336,7 @@ def listener(cls, name: str = MISSING) -> Callable[[FuncT], FuncT]: """ if name is not MISSING and not isinstance(name, str): - raise TypeError( - f"Cog.listener expected str but received {name.__class__.__name__!r} instead." - ) + raise TypeError(f"Cog.listener expected str but received {name.__class__.__name__!r} instead.") def decorator(func: FuncT) -> FuncT: actual = func @@ -423,9 +407,7 @@ def cog_check(self, ctx: ApplicationContext) -> bool: return True @_cog_special_method - async def cog_command_error( - self, ctx: ApplicationContext, error: Exception - ) -> None: + async def cog_command_error(self, ctx: ApplicationContext, error: Exception) -> None: """A special method that is called whenever an error is dispatched inside this cog. @@ -670,8 +652,7 @@ def _remove_module_references(self, name: str) -> None: remove = [ index for index, event in enumerate(event_list) - if event.__module__ is not None - and _is_submodule(name, event.__module__) + if event.__module__ is not None and _is_submodule(name, event.__module__) ] for index in reversed(remove): @@ -695,9 +676,7 @@ def _call_module_finalizers(self, lib: types.ModuleType, key: str) -> None: if _is_submodule(name, module): del sys.modules[module] - def _load_from_module_spec( - self, spec: importlib.machinery.ModuleSpec, key: str - ) -> None: + def _load_from_module_spec(self, spec: importlib.machinery.ModuleSpec, key: str) -> None: # precondition: key not in self.__extensions lib = importlib.util.module_from_spec(spec) sys.modules[key] = lib @@ -858,11 +837,7 @@ def reload_extension(self, name: str, *, package: Optional[str] = None) -> None: raise errors.ExtensionNotLoaded(name) # get the previous module states from sys modules - modules = { - name: module - for name, module in sys.modules.items() - if _is_submodule(lib.__name__, name) - } + modules = {name: module for name, module in sys.modules.items() if _is_submodule(lib.__name__, name)} try: # Unload and then load the module... diff --git a/discord/colour.py b/discord/colour.py index 50bf8fb381..010c8cfa63 100644 --- a/discord/colour.py +++ b/discord/colour.py @@ -73,9 +73,7 @@ class Colour: def __init__(self, value: int): if not isinstance(value, int): - raise TypeError( - f"Expected int parameter, received {value.__class__.__name__} instead." - ) + raise TypeError(f"Expected int parameter, received {value.__class__.__name__} instead.") self.value: int = value diff --git a/discord/commands/context.py b/discord/commands/context.py index bc3a3b023a..716cb49f2f 100644 --- a/discord/commands/context.py +++ b/discord/commands/context.py @@ -159,11 +159,7 @@ def guild_locale(self) -> Optional[str]: @cached_property def me(self) -> Optional[Union[Member, ClientUser]]: - return ( - self.interaction.guild.me - if self.interaction.guild is not None - else self.bot.user - ) + return self.interaction.guild.me if self.interaction.guild is not None else self.bot.user @cached_property def message(self) -> Optional[Message]: @@ -213,8 +209,7 @@ def unselected_options(self) -> Optional[List[Option]]: return [ option for option in self.command.options # type: ignore - if option.to_dict()["name"] - not in [opt["name"] for opt in self.selected_options] + if option.to_dict()["name"] not in [opt["name"] for opt in self.selected_options] ] else: return self.command.options # type: ignore diff --git a/discord/commands/core.py b/discord/commands/core.py index a1b88d231b..157740dd25 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -118,10 +118,7 @@ async def wrapped(arg): except Exception as exc: raise ApplicationCommandInvokeError(exc) from exc finally: - if ( - hasattr(command, "_max_concurrency") - and command._max_concurrency is not None - ): + if hasattr(command, "_max_concurrency") and command._max_concurrency is not None: await command._max_concurrency.release(ctx) await command.call_after_hooks(ctx) return ret @@ -160,9 +157,7 @@ def __init__(self, func: Callable, **kwargs) -> None: elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets try: @@ -183,9 +178,7 @@ def __eq__(self, other) -> bool: check = self.id == other.id else: check = self.name == other.name and self.guild_ids == self.guild_ids - return ( - isinstance(other, self.__class__) and self.parent == other.parent and check - ) + return isinstance(other, self.__class__) and self.parent == other.parent and check async def __call__(self, ctx, *args, **kwargs): """|coro| @@ -236,9 +229,7 @@ async def prepare(self, ctx: ApplicationContext) -> None: ctx.command = self if not await self.can_run(ctx): - raise CheckFailure( - f"The check functions for the command {self.name} failed" - ) + raise CheckFailure(f"The check functions for the command {self.name} failed") if hasattr(self, "_max_concurrency"): if self._max_concurrency is not None: @@ -323,9 +314,7 @@ async def invoke(self, ctx: ApplicationContext) -> None: async def can_run(self, ctx: ApplicationContext) -> bool: if not await ctx.bot.can_run(ctx): - raise CheckFailure( - f"The global check functions for command {self.name} failed." - ) + raise CheckFailure(f"The global check functions for command {self.name} failed.") predicates = self.checks if not predicates: @@ -585,9 +574,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.id = None description = kwargs.get("description") or ( - inspect.cleandoc(func.__doc__).splitlines()[0] - if func.__doc__ is not None - else "No description provided" + inspect.cleandoc(func.__doc__).splitlines()[0] if func.__doc__ is not None else "No description provided" ) validate_chat_input_description(description) self.description: str = description @@ -615,9 +602,9 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: # Permissions self.default_permission = kwargs.get("default_permission", True) - self.permissions: List[CommandPermission] = getattr( - func, "__app_cmd_perms__", [] - ) + kwargs.get("permissions", []) + self.permissions: List[CommandPermission] = getattr(func, "__app_cmd_perms__", []) + kwargs.get( + "permissions", [] + ) if self.permissions and self.default_permission: self.default_permission = False @@ -632,9 +619,7 @@ def _parse_options(self, params) -> List[Option]: try: next(params) except StopIteration: - raise ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') final_options = [] for p_name, p_obj in params: @@ -645,9 +630,7 @@ def _parse_options(self, params) -> List[Option]: if self._is_typing_union(option): if self._is_typing_optional(option): - option = Option( - option.__args__[0], "No description provided", required=False - ) + option = Option(option.__args__[0], "No description provided", required=False) else: option = Option(option.__args__, "No description provided") @@ -683,22 +666,15 @@ def _match_option_param_names(self, params, options): try: next(params) except StopIteration: - raise ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') check_annotations = [ lambda o, a: o.input_type == SlashCommandOptionType.string and o.converter is not None, # pass on converters - lambda o, a: isinstance( - o.input_type, SlashCommandOptionType - ), # pass on slash cmd option type enums + lambda o, a: isinstance(o.input_type, SlashCommandOptionType), # pass on slash cmd option type enums lambda o, a: isinstance(o._raw_type, tuple) and a == Union[o._raw_type], # type: ignore # union types - lambda o, a: self._is_typing_optional(a) - and not o.required - and o._raw_type in a.__args__, # optional - lambda o, a: inspect.isclass(a) - and issubclass(a, o._raw_type), # 'normal' types + lambda o, a: self._is_typing_optional(a) and not o.required and o._raw_type in a.__args__, # optional + lambda o, a: inspect.isclass(a) and issubclass(a, o._raw_type), # 'normal' types ] for o in options: validate_chat_input_name(o.name) @@ -706,15 +682,11 @@ def _match_option_param_names(self, params, options): try: p_name, p_obj = next(params) except StopIteration: # not enough params for all the options - raise ClientException( - f"Too many arguments passed to the options kwarg." - ) + raise ClientException(f"Too many arguments passed to the options kwarg.") p_obj = p_obj.annotation if not any(c(o, p_obj) for c in check_annotations): - raise TypeError( - f"Parameter {p_name} does not match input type of {o.name}." - ) + raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.") o._parameter_name = p_name left_out_params = OrderedDict() @@ -726,9 +698,7 @@ def _match_option_param_names(self, params, options): return options def _is_typing_union(self, annotation): - return getattr(annotation, "__origin__", None) is Union or type( - annotation - ) is getattr( + return getattr(annotation, "__origin__", None) is Union or type(annotation) is getattr( types, "UnionType", Union ) # type: ignore @@ -759,22 +729,14 @@ async def _invoke(self, ctx: ApplicationContext) -> None: arg = arg["value"] # Checks if input_type is user, role or channel - if ( - SlashCommandOptionType.user.value - <= op.input_type.value - <= SlashCommandOptionType.role.value - ): + if SlashCommandOptionType.user.value <= op.input_type.value <= SlashCommandOptionType.role.value: if ctx.guild is None and op.input_type.name == "user": _data = ctx.interaction.data["resolved"]["users"][arg] _data["id"] = int(arg) arg = User(state=ctx.interaction._state, data=_data) else: - name = ( - "member" if op.input_type.name == "user" else op.input_type.name - ) - arg = await get_or_fetch( - ctx.guild, name, int(arg), default=int(arg) - ) + name = "member" if op.input_type.name == "user" else op.input_type.name + arg = await get_or_fetch(ctx.guild, name, int(arg), default=int(arg)) elif op.input_type == SlashCommandOptionType.mentionable: arg_id = int(arg) @@ -782,10 +744,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: if arg is None: arg = ctx.guild.get_role(arg_id) or arg_id - elif ( - op.input_type == SlashCommandOptionType.string - and (converter := op.converter) is not None - ): + elif op.input_type == SlashCommandOptionType.string and (converter := op.converter) is not None: arg = await converter.convert(converter, ctx, arg) elif op.input_type == SlashCommandOptionType.attachment: @@ -812,9 +771,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): for op in ctx.interaction.data.get("options", []): if op.get("focused", False): option = find(lambda o: o.name == op["name"], self.options) - values.update( - {i["name"]: i["value"] for i in ctx.interaction.data["options"]} - ) + values.update({i["name"]: i["value"] for i in ctx.interaction.data["options"]}) ctx.command = self ctx.focused = option ctx.value = op.get("value") @@ -829,13 +786,8 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): if asyncio.iscoroutinefunction(option.autocomplete): result = await result - choices = [ - o if isinstance(o, OptionChoice) else OptionChoice(o) - for o in result - ][:25] - return await ctx.interaction.response.send_autocomplete_result( - choices=choices - ) + choices = [o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result][:25] + return await ctx.interaction.response.send_autocomplete_result(choices=choices) def copy(self): """Creates a copy of this command. @@ -941,9 +893,7 @@ def __init__( self.name = name self.description = description self.input_type = SlashCommandOptionType.sub_command_group - self.subcommands: List[ - Union[SlashCommand, SlashCommandGroup] - ] = self.__initial_commands__ + self.subcommands: List[Union[SlashCommand, SlashCommandGroup]] = self.__initial_commands__ self.guild_ids = guild_ids self.parent = parent self.checks = [] @@ -1207,9 +1157,9 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: self.validate_parameters() self.default_permission = kwargs.get("default_permission", True) - self.permissions: List[CommandPermission] = getattr( - func, "__app_cmd_perms__", [] - ) + kwargs.get("permissions", []) + self.permissions: List[CommandPermission] = getattr(func, "__app_cmd_perms__", []) + kwargs.get( + "permissions", [] + ) if self.permissions and self.default_permission: self.default_permission = False @@ -1228,25 +1178,19 @@ def validate_parameters(self): try: next(params) except StopIteration: - raise ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') # next we have the 'user/message' as the next parameter try: next(params) except StopIteration: cmd = "user" if type(self) == UserCommand else "message" - raise ClientException( - f'Callback for {self.name} command is missing "{cmd}" parameter.' - ) + raise ClientException(f'Callback for {self.name} command is missing "{cmd}" parameter.') # next there should be no more parameters try: next(params) - raise ClientException( - f"Callback for {self.name} command has too many parameters." - ) + raise ClientException(f"Callback for {self.name} command has too many parameters.") except StopIteration: pass @@ -1400,9 +1344,7 @@ async def _invoke(self, ctx: ApplicationContext): message = v channel = ctx.interaction._state.get_channel(int(message["channel_id"])) if channel is None: - data = await ctx.interaction._state.http.start_private_message( - int(message["author"]["id"]) - ) + data = await ctx.interaction._state.http.start_private_message(int(message["author"]["id"])) channel = ctx.interaction._state.add_dm_channel(data) target = Message(state=ctx.interaction._state, channel=channel, data=message) @@ -1511,9 +1453,7 @@ def decorator(func: Callable) -> cls: if isinstance(func, ApplicationCommand): func = func.callback elif not callable(func): - raise TypeError( - "func needs to be a callable or a subclass of ApplicationCommand." - ) + raise TypeError("func needs to be a callable or a subclass of ApplicationCommand.") return cls(func, **attrs) return decorator @@ -1539,9 +1479,7 @@ def command(**kwargs): def validate_chat_input_name(name: Any): # Must meet the regex ^[\w-]{1,32}$ if not isinstance(name, str): - raise TypeError( - f"Chat input command names and options must be of type str. Received {name}" - ) + raise TypeError(f"Chat input command names and options must be of type str. Received {name}") if not re.match(r"^[\w-]{1,32}$", name): raise ValidationError( r'Chat input command names and options must follow the regex "^[\w-]{1,32}$". For more information, see ' @@ -1549,23 +1487,13 @@ def validate_chat_input_name(name: Any): f"{name}" ) if not 1 <= len(name) <= 32: - raise ValidationError( - f"Chat input command names and options must be 1-32 characters long. Received {name}" - ) - if ( - not name.lower() == name - ): # Can't use islower() as it fails if none of the chars can be lower. See #512. - raise ValidationError( - f"Chat input command names and options must be lowercase. Received {name}" - ) + raise ValidationError(f"Chat input command names and options must be 1-32 characters long. Received {name}") + if not name.lower() == name: # Can't use islower() as it fails if none of the chars can be lower. See #512. + raise ValidationError(f"Chat input command names and options must be lowercase. Received {name}") def validate_chat_input_description(description: Any): if not isinstance(description, str): - raise TypeError( - f"Command description must be of type str. Received {description}" - ) + raise TypeError(f"Command description must be of type str. Received {description}") if not 1 <= len(description) <= 100: - raise ValidationError( - f"Command description must be 1-100 characters long. Received {description}" - ) + raise ValidationError(f"Command description must be 1-100 characters long. Received {description}") diff --git a/discord/commands/errors.py b/discord/commands/errors.py index 966ab6f71d..9922a531d1 100644 --- a/discord/commands/errors.py +++ b/discord/commands/errors.py @@ -66,6 +66,4 @@ class ApplicationCommandInvokeError(ApplicationCommandError): def __init__(self, e: Exception) -> None: self.original: Exception = e - super().__init__( - f"Application Command raised an exception: {e.__class__.__name__}: {e}" - ) + super().__init__(f"Application Command raised an exception: {e.__class__.__name__}: {e}") diff --git a/discord/commands/options.py b/discord/commands/options.py index 135db9a7d0..e7e243db2e 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -92,13 +92,10 @@ def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> Non self.channel_types.append(channel_type) input_type = _type self.input_type = input_type - self.required: bool = ( - kwargs.pop("required", True) if "default" not in kwargs else False - ) + self.required: bool = kwargs.pop("required", True) if "default" not in kwargs else False self.default = kwargs.pop("default", None) self.choices: List[OptionChoice] = [ - o if isinstance(o, OptionChoice) else OptionChoice(o) - for o in kwargs.pop("choices", list()) + o if isinstance(o, OptionChoice) else OptionChoice(o) for o in kwargs.pop("choices", list()) ] if self.input_type == SlashCommandOptionType.integer: @@ -113,13 +110,9 @@ def __init__(self, input_type: Any, /, description: str = None, **kwargs) -> Non self.max_value: minmax_typehint = kwargs.pop("max_value", None) if not isinstance(self.min_value, minmax_types) and self.min_value is not None: - raise TypeError( - f'Expected {minmax_typehint} for min_value, got "{type(self.min_value).__name__}"' - ) + raise TypeError(f'Expected {minmax_typehint} for min_value, got "{type(self.min_value).__name__}"') if not (isinstance(self.max_value, minmax_types) or self.min_value is None): - raise TypeError( - f'Expected {minmax_typehint} for max_value, got "{type(self.max_value).__name__}"' - ) + raise TypeError(f'Expected {minmax_typehint} for max_value, got "{type(self.max_value).__name__}"') self.autocomplete = kwargs.pop("autocomplete", None) diff --git a/discord/commands/permissions.py b/discord/commands/permissions.py index fc17537170..be8597e98e 100644 --- a/discord/commands/permissions.py +++ b/discord/commands/permissions.py @@ -143,9 +143,7 @@ def decorator(func: Callable): func.__app_cmd_perms__ = [] # Permissions (Will Convert ID later in register_commands if needed) - app_cmd_perm = CommandPermission( - item, 1, True, guild_id - ) # {"id": item, "type": 1, "permission": True} + app_cmd_perm = CommandPermission(item, 1, True, guild_id) # {"id": item, "type": 1, "permission": True} # Append func.__app_cmd_perms__.append(app_cmd_perm) @@ -180,9 +178,7 @@ def decorator(func: Callable): # Permissions (Will Convert ID later in register_commands if needed) for item in items: - app_cmd_perm = CommandPermission( - item, 1, True, guild_id - ) # {"id": item, "type": 1, "permission": True} + app_cmd_perm = CommandPermission(item, 1, True, guild_id) # {"id": item, "type": 1, "permission": True} # Append func.__app_cmd_perms__.append(app_cmd_perm) @@ -214,9 +210,7 @@ def decorator(func: Callable): func.__app_cmd_perms__ = [] # Permissions (Will Convert ID later in register_commands if needed) - app_cmd_perm = CommandPermission( - user, 2, True, guild_id - ) # {"id": user, "type": 2, "permission": True} + app_cmd_perm = CommandPermission(user, 2, True, guild_id) # {"id": user, "type": 2, "permission": True} # Append func.__app_cmd_perms__.append(app_cmd_perm) @@ -247,9 +241,7 @@ def decorator(func: Callable): func.__app_cmd_perms__ = [] # Permissions (Will Convert ID later in register_commands if needed) - app_cmd_perm = CommandPermission( - "owner", 2, True, guild_id - ) # {"id": "owner", "type": 2, "permission": True} + app_cmd_perm = CommandPermission("owner", 2, True, guild_id) # {"id": "owner", "type": 2, "permission": True} # Append func.__app_cmd_perms__.append(app_cmd_perm) diff --git a/discord/components.py b/discord/components.py index edaab91a20..60b2b51fcf 100644 --- a/discord/components.py +++ b/discord/components.py @@ -130,9 +130,7 @@ class ActionRow(Component): def __init__(self, data: ComponentPayload): self.type: ComponentType = try_enum(ComponentType, data["type"]) - self.children: List[Component] = [ - _component_factory(d) for d in data.get("components", []) - ] + self.children: List[Component] = [_component_factory(d) for d in data.get("components", [])] def to_dict(self) -> ActionRowPayload: return { @@ -337,9 +335,7 @@ def __init__(self, data: SelectMenuPayload): self.placeholder: Optional[str] = data.get("placeholder") self.min_values: int = data.get("min_values", 1) self.max_values: int = data.get("max_values", 1) - self.options: List[SelectOption] = [ - SelectOption.from_dict(option) for option in data.get("options", []) - ] + self.options: List[SelectOption] = [SelectOption.from_dict(option) for option in data.get("options", [])] self.disabled: bool = data.get("disabled", False) def to_dict(self) -> SelectMenuPayload: @@ -410,9 +406,7 @@ def __init__( elif isinstance(emoji, _EmojiTag): emoji = emoji._to_partial() else: - raise TypeError( - f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}" - ) + raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}") self.emoji = emoji self.default = default diff --git a/discord/embeds.py b/discord/embeds.py index fc69805c56..da72932373 100644 --- a/discord/embeds.py +++ b/discord/embeds.py @@ -67,9 +67,7 @@ def __len__(self) -> int: return len(self.__dict__) def __repr__(self) -> str: - inner = ", ".join( - (f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_")) - ) + inner = ", ".join((f"{k}={v!r}" for k, v in self.__dict__.items() if not k.startswith("_"))) return f"EmbedProxy({inner})" def __getattr__(self, attr: str) -> _EmptyEmbed: @@ -355,9 +353,7 @@ def timestamp(self, value: MaybeEmpty[datetime.datetime]): elif isinstance(value, _EmptyEmbed): self._timestamp = value else: - raise TypeError( - f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead" - ) + raise TypeError(f"Expected datetime.datetime or Embed.Empty received {value.__class__.__name__} instead") @property def footer(self) -> _EmbedFooterProxy: @@ -648,9 +644,7 @@ def add_field(self: E, *, name: Any, value: Any, inline: bool = True) -> E: return self - def insert_field_at( - self: E, index: int, *, name: Any, value: Any, inline: bool = True - ) -> E: + def insert_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: """Inserts a field before a specified index to the embed. This function returns the class instance to allow for fluent-style @@ -711,9 +705,7 @@ def remove_field(self, index: int) -> None: except (AttributeError, IndexError): pass - def set_field_at( - self: E, index: int, *, name: Any, value: Any, inline: bool = True - ) -> E: + def set_field_at(self: E, index: int, *, name: Any, value: Any, inline: bool = True) -> E: """Modifies a field to the embed object. The index must point to a valid pre-existing field. @@ -752,11 +744,7 @@ def to_dict(self) -> EmbedData: """Converts this embed object into a dict.""" # add in the raw data into the dict - result = { - key[1:]: getattr(self, key) - for key in self.__slots__ - if key[0] == "_" and hasattr(self, key) - } + result = {key[1:]: getattr(self, key) for key in self.__slots__ if key[0] == "_" and hasattr(self, key)} # deal with basic convenience wrappers @@ -775,13 +763,9 @@ def to_dict(self) -> EmbedData: else: if timestamp: if timestamp.tzinfo: - result["timestamp"] = timestamp.astimezone( - tz=datetime.timezone.utc - ).isoformat() + result["timestamp"] = timestamp.astimezone(tz=datetime.timezone.utc).isoformat() else: - result["timestamp"] = timestamp.replace( - tzinfo=datetime.timezone.utc - ).isoformat() + result["timestamp"] = timestamp.replace(tzinfo=datetime.timezone.utc).isoformat() # add in the non raw attribute ones if self.type: diff --git a/discord/emoji.py b/discord/emoji.py index d059e01de3..0b2cf788d3 100644 --- a/discord/emoji.py +++ b/discord/emoji.py @@ -211,9 +211,7 @@ async def delete(self, *, reason: Optional[str] = None) -> None: An error occurred deleting the emoji. """ - await self._state.http.delete_custom_emoji( - self.guild.id, self.id, reason=reason - ) + await self._state.http.delete_custom_emoji(self.guild.id, self.id, reason=reason) async def edit( self, @@ -260,7 +258,5 @@ async def edit( if roles is not MISSING: payload["roles"] = [role.id for role in roles] - data = await self._state.http.edit_custom_emoji( - self.guild.id, self.id, payload=payload, reason=reason - ) + data = await self._state.http.edit_custom_emoji(self.guild.id, self.id, payload=payload, reason=reason) return Emoji(guild=self.guild, data=data, state=self._state) diff --git a/discord/enums.py b/discord/enums.py index 18665bbf70..4d76987ef9 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -70,29 +70,15 @@ def _create_value_cls(name, comparable): cls.__repr__ = lambda self: f"<{name}.{self.name}: {self.value!r}>" cls.__str__ = lambda self: f"{name}.{self.name}" if comparable: - cls.__le__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value <= other.value - ) - cls.__ge__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value >= other.value - ) - cls.__lt__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value < other.value - ) - cls.__gt__ = ( - lambda self, other: isinstance(other, self.__class__) - and self.value > other.value - ) + cls.__le__ = lambda self, other: isinstance(other, self.__class__) and self.value <= other.value + cls.__ge__ = lambda self, other: isinstance(other, self.__class__) and self.value >= other.value + cls.__lt__ = lambda self, other: isinstance(other, self.__class__) and self.value < other.value + cls.__gt__ = lambda self, other: isinstance(other, self.__class__) and self.value > other.value return cls def _is_descriptor(obj): - return ( - hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") - ) + return hasattr(obj, "__get__") or hasattr(obj, "__set__") or hasattr(obj, "__delete__") class EnumMeta(type): @@ -144,9 +130,7 @@ def __iter__(cls): return (cls._enum_member_map_[name] for name in cls._enum_member_names_) def __reversed__(cls): - return ( - cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_) - ) + return (cls._enum_member_map_[name] for name in reversed(cls._enum_member_names_)) def __len__(cls): return len(cls._enum_member_names_) diff --git a/discord/errors.py b/discord/errors.py index 89d8eb0b01..7b633afe56 100644 --- a/discord/errors.py +++ b/discord/errors.py @@ -138,9 +138,7 @@ class HTTPException(DiscordException): The Discord specific error code for the failure. """ - def __init__( - self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]] - ): + def __init__(self, response: _ResponseType, message: Optional[Union[str, Dict[str, Any]]]): self.response: _ResponseType = response self.status: int = response.status # type: ignore self.code: int @@ -315,9 +313,7 @@ def __init__(self, message: Optional[str] = None, *args: Any, name: str) -> None self.name: str = name message = message or f"Extension {name!r} had an error." # clean-up @everyone and @here mentions - m = message.replace("@everyone", "@\u200beveryone").replace( - "@here", "@\u200bhere" - ) + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") super().__init__(m, *args) diff --git a/discord/ext/commands/_types.py b/discord/ext/commands/_types.py index 7f86ac6d47..232bd32a13 100644 --- a/discord/ext/commands/_types.py +++ b/discord/ext/commands/_types.py @@ -41,9 +41,7 @@ Callable[["Cog", "Context[Any]"], MaybeCoro[bool]], Callable[["Context[Any]"], MaybeCoro[bool]], ] -Hook = Union[ - Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]] -] +Hook = Union[Callable[["Cog", "Context[Any]"], Coro[Any]], Callable[["Context[Any]"], Coro[Any]]] Error = Union[ Callable[["Cog", "Context[Any]", "CommandError"], Coro[Any]], Callable[["Context[Any]", "CommandError"], Coro[Any]], diff --git a/discord/ext/commands/bot.py b/discord/ext/commands/bot.py index 765f8f2c2a..0d17c5c5ed 100644 --- a/discord/ext/commands/bot.py +++ b/discord/ext/commands/bot.py @@ -167,9 +167,7 @@ async def close(self) -> None: await super().close() # type: ignore - async def on_command_error( - self, context: Context, exception: errors.CommandError - ) -> None: + async def on_command_error(self, context: Context, exception: errors.CommandError) -> None: """|coro| The default command error handler provided by the bot. @@ -191,9 +189,7 @@ async def on_command_error( return print(f"Ignoring exception in command {context.command}:", file=sys.stderr) - traceback.print_exception( - type(exception), exception, exception.__traceback__, file=sys.stderr - ) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) async def can_run(self, ctx: Context, *, call_once: bool = False) -> bool: data = self._check_once if call_once else self._checks @@ -263,9 +259,7 @@ async def get_prefix(self, message: Message) -> Union[List[str], str]: ) if not ret: - raise ValueError( - "Iterable command_prefix must contain at least one prefix" - ) + raise ValueError("Iterable command_prefix must contain at least one prefix") return ret diff --git a/discord/ext/commands/context.py b/discord/ext/commands/context.py index c86a6a858f..d49b16fd88 100644 --- a/discord/ext/commands/context.py +++ b/discord/ext/commands/context.py @@ -151,9 +151,7 @@ def __init__( self.current_parameter: Optional[inspect.Parameter] = current_parameter self._state: ConnectionState = self.message._state - async def invoke( - self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs - ) -> T: + async def invoke(self, command: Command[CogT, P, T], /, *args: P.args, **kwargs: P.kwargs) -> T: r"""|coro| Calls a command with the arguments given. diff --git a/discord/ext/commands/converter.py b/discord/ext/commands/converter.py index 79939bfe86..617d8b5d9b 100644 --- a/discord/ext/commands/converter.py +++ b/discord/ext/commands/converter.py @@ -161,9 +161,7 @@ class ObjectConverter(IDConverter[discord.Object]): """ async def convert(self, ctx: Context, argument: str) -> discord.Object: - match = self._get_id_match(argument) or re.match( - r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<(?:@(?:!|&)?|#)([0-9]{15,20})>$", argument) if match is None: raise ObjectNotFound(argument) @@ -200,14 +198,10 @@ async def query_member_named(self, guild, argument): if len(argument) > 5 and argument[-5] == "#": username, _, discriminator = argument.rpartition("#") members = await guild.query_members(username, limit=100, cache=cache) - return discord.utils.get( - members, name=username, discriminator=discriminator - ) + return discord.utils.get(members, name=username, discriminator=discriminator) else: members = await guild.query_members(argument, limit=100, cache=cache) - return discord.utils.find( - lambda m: m.name == argument or m.nick == argument, members - ) + return discord.utils.find(lambda m: m.name == argument or m.nick == argument, members) async def query_member_by_id(self, bot, guild, user_id): ws = bot._get_websocket(shard_id=guild.shard_id) @@ -232,9 +226,7 @@ async def query_member_by_id(self, bot, guild, user_id): async def convert(self, ctx: Context, argument: str) -> discord.Member: bot = ctx.bot - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) guild = ctx.guild result = None user_id = None @@ -289,9 +281,7 @@ class UserConverter(IDConverter[discord.User]): """ async def convert(self, ctx: Context, argument: str) -> discord.User: - match = self._get_id_match(argument) or re.match( - r"<@!?([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@!?([0-9]{15,20})>$", argument) result = None state = ctx._state @@ -347,9 +337,7 @@ class PartialMessageConverter(Converter[discord.PartialMessage]): @staticmethod def _get_id_matches(ctx, argument): - id_regex = re.compile( - r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$" - ) + id_regex = re.compile(r"(?:(?P[0-9]{15,20})-)?(?P[0-9]{15,20})$") link_regex = re.compile( r"https?://(?:(ptb|canary|www)\.)?discord(?:app)?\.com/channels/" r"(?P[0-9]{15,20}|@me)" @@ -375,9 +363,7 @@ def _get_id_matches(ctx, argument): return guild_id, message_id, channel_id @staticmethod - def _resolve_channel( - ctx, guild_id, channel_id - ) -> Optional[PartialMessageableChannel]: + def _resolve_channel(ctx, guild_id, channel_id) -> Optional[PartialMessageableChannel]: if guild_id is not None: guild = ctx.bot.get_guild(guild_id) if guild is not None and channel_id is not None: @@ -411,9 +397,7 @@ class MessageConverter(IDConverter[discord.Message]): """ async def convert(self, ctx: Context, argument: str) -> discord.Message: - guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches( - ctx, argument - ) + guild_id, message_id, channel_id = PartialMessageConverter._get_id_matches(ctx, argument) message = ctx.bot._connection._get_message(message_id) if message: return message @@ -444,19 +428,13 @@ class GuildChannelConverter(IDConverter[discord.abc.GuildChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.abc.GuildChannel: - return self._resolve_channel( - ctx, argument, "channels", discord.abc.GuildChannel - ) + return self._resolve_channel(ctx, argument, "channels", discord.abc.GuildChannel) @staticmethod - def _resolve_channel( - ctx: Context, argument: str, attribute: str, type: Type[CT] - ) -> CT: + def _resolve_channel(ctx: Context, argument: str, attribute: str, type: Type[CT]) -> CT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -484,14 +462,10 @@ def check(c): return result @staticmethod - def _resolve_thread( - ctx: Context, argument: str, attribute: str, type: Type[TT] - ) -> TT: + def _resolve_thread(ctx: Context, argument: str, attribute: str, type: Type[TT]) -> TT: bot = ctx.bot - match = IDConverter._get_id_match(argument) or re.match( - r"<#([0-9]{15,20})>$", argument - ) + match = IDConverter._get_id_match(argument) or re.match(r"<#([0-9]{15,20})>$", argument) result = None guild = ctx.guild @@ -528,9 +502,7 @@ class TextChannelConverter(IDConverter[discord.TextChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.TextChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "text_channels", discord.TextChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "text_channels", discord.TextChannel) class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): @@ -550,9 +522,7 @@ class VoiceChannelConverter(IDConverter[discord.VoiceChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.VoiceChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "voice_channels", discord.VoiceChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "voice_channels", discord.VoiceChannel) class StageChannelConverter(IDConverter[discord.StageChannel]): @@ -571,9 +541,7 @@ class StageChannelConverter(IDConverter[discord.StageChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StageChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "stage_channels", discord.StageChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "stage_channels", discord.StageChannel) class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): @@ -593,9 +561,7 @@ class CategoryChannelConverter(IDConverter[discord.CategoryChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.CategoryChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "categories", discord.CategoryChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "categories", discord.CategoryChannel) class StoreChannelConverter(IDConverter[discord.StoreChannel]): @@ -614,9 +580,7 @@ class StoreChannelConverter(IDConverter[discord.StoreChannel]): """ async def convert(self, ctx: Context, argument: str) -> discord.StoreChannel: - return GuildChannelConverter._resolve_channel( - ctx, argument, "channels", discord.StoreChannel - ) + return GuildChannelConverter._resolve_channel(ctx, argument, "channels", discord.StoreChannel) class ThreadConverter(IDConverter[discord.Thread]): @@ -634,9 +598,7 @@ class ThreadConverter(IDConverter[discord.Thread]): """ async def convert(self, ctx: Context, argument: str) -> discord.Thread: - return GuildChannelConverter._resolve_thread( - ctx, argument, "threads", discord.Thread - ) + return GuildChannelConverter._resolve_thread(ctx, argument, "threads", discord.Thread) class ColourConverter(Converter[discord.Colour]): @@ -665,9 +627,7 @@ class ColourConverter(Converter[discord.Colour]): Added support for ``rgb`` function and 3-digit hex shortcuts """ - RGB_REGEX = re.compile( - r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)" - ) + RGB_REGEX = re.compile(r"rgb\s*\((?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*,\s*(?P[0-9]{1,3}%?)\s*\)") def parse_hex_number(self, argument): arg = "".join(i * 2 for i in argument) if len(argument) == 3 else argument @@ -748,9 +708,7 @@ async def convert(self, ctx: Context, argument: str) -> discord.Role: if not guild: raise NoPrivateMessage() - match = self._get_id_match(argument) or re.match( - r"<@&([0-9]{15,20})>$", argument - ) + match = self._get_id_match(argument) or re.match(r"<@&([0-9]{15,20})>$", argument) if match: result = guild.get_role(int(match.group(1))) else: @@ -829,9 +787,7 @@ class EmojiConverter(IDConverter[discord.Emoji]): """ async def convert(self, ctx: Context, argument: str) -> discord.Emoji: - match = self._get_id_match(argument) or re.match( - r"$", argument - ) + match = self._get_id_match(argument) or re.match(r"$", argument) result = None bot = ctx.bot guild = ctx.guild @@ -960,27 +916,17 @@ async def convert(self, ctx: Context, argument: str) -> str: if ctx.guild: def resolve_member(id: int) -> str: - m = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.guild.get_member(id) - return ( - f"@{m.display_name if self.use_nicknames else m.name}" - if m - else "@deleted-user" - ) + m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_member(id) + return f"@{m.display_name if self.use_nicknames else m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: - r = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.guild.get_role(id) + r = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.guild.get_role(id) return f"@{r.name}" if r else "@deleted-role" else: def resolve_member(id: int) -> str: - m = ( - None if msg is None else _utils_get(msg.mentions, id=id) - ) or ctx.bot.get_user(id) + m = (None if msg is None else _utils_get(msg.mentions, id=id)) or ctx.bot.get_user(id) return f"@{m.name}" if m else "@deleted-user" def resolve_role(id: int) -> str: @@ -1061,11 +1007,7 @@ def __class_getitem__(cls, params: Union[Tuple[T], T]) -> Greedy[T]: origin = getattr(converter, "__origin__", None) args = getattr(converter, "__args__", ()) - if not ( - callable(converter) - or isinstance(converter, Converter) - or origin is not None - ): + if not (callable(converter) or isinstance(converter, Converter) or origin is not None): raise TypeError("Greedy[...] expects a type or a Converter instance.") if converter in (str, type(None)) or origin is Greedy: @@ -1128,9 +1070,7 @@ def is_generic_type(tp: Any, *, _GenericAlias: Type = _GenericAlias) -> bool: } -async def _actual_conversion( - ctx: Context, converter, argument: str, param: inspect.Parameter -): +async def _actual_conversion(ctx: Context, converter, argument: str, param: inspect.Parameter): if converter is bool: return _convert_to_bool(argument) @@ -1139,9 +1079,7 @@ async def _actual_conversion( except AttributeError: pass else: - if module is not None and ( - module.startswith("discord.") and not module.endswith("converter") - ): + if module is not None and (module.startswith("discord.") and not module.endswith("converter")): converter = CONVERTER_MAPPING.get(converter, converter) try: @@ -1167,14 +1105,10 @@ async def _actual_conversion( except AttributeError: name = converter.__class__.__name__ - raise BadArgument( - f'Converting to "{name}" failed for parameter "{param.name}".' - ) from exc + raise BadArgument(f'Converting to "{name}" failed for parameter "{param.name}".') from exc -async def run_converters( - ctx: Context, converter, argument: str, param: inspect.Parameter -): +async def run_converters(ctx: Context, converter, argument: str, param: inspect.Parameter): """|coro| Runs converters for a given converter, argument, and parameter. diff --git a/discord/ext/commands/cooldowns.py b/discord/ext/commands/cooldowns.py index 945fb30171..7fb348833e 100644 --- a/discord/ext/commands/cooldowns.py +++ b/discord/ext/commands/cooldowns.py @@ -255,17 +255,13 @@ def get_bucket(self, message: Message, current: Optional[float] = None) -> Coold return bucket - def update_rate_limit( - self, message: Message, current: Optional[float] = None - ) -> Optional[float]: + def update_rate_limit(self, message: Message, current: Optional[float] = None) -> Optional[float]: bucket = self.get_bucket(message, current) return bucket.update_rate_limit(current) class DynamicCooldownMapping(CooldownMapping): - def __init__( - self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any] - ) -> None: + def __init__(self, factory: Callable[[Message], Cooldown], type: Callable[[Message], Any]) -> None: super().__init__(None, type) self._factory: Callable[[Message], Cooldown] = factory @@ -355,17 +351,13 @@ def __init__(self, number: int, *, per: BucketType, wait: bool) -> None: raise ValueError("max_concurrency 'number' cannot be less than 1") if not isinstance(per, BucketType): - raise TypeError( - f"max_concurrency 'per' must be of type BucketType not {type(per)!r}" - ) + raise TypeError(f"max_concurrency 'per' must be of type BucketType not {type(per)!r}") def copy(self: MC) -> MC: return self.__class__(self.number, per=self.per, wait=self.wait) def __repr__(self) -> str: - return ( - f"" - ) + return f"" def get_key(self, message: Message) -> Any: return self.per.get_key(message) diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index e2ca239081..676bd34b02 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -135,9 +135,7 @@ def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: return function -def get_signature_parameters( - function: Callable[..., Any], globalns: Dict[str, Any] -) -> Dict[str, inspect.Parameter]: +def get_signature_parameters(function: Callable[..., Any], globalns: Dict[str, Any]) -> Dict[str, inspect.Parameter]: signature = inspect.signature(function) params = {} cache: Dict[str, Any] = {} @@ -355,9 +353,7 @@ def __init__( self.extras: Dict[str, Any] = kwargs.get("extras", {}) if not isinstance(self.aliases, (list, tuple)): - raise TypeError( - "Aliases of a command must be a list or a tuple of strings." - ) + raise TypeError("Aliases of a command must be a list or a tuple of strings.") self.description: str = inspect.cleandoc(kwargs.get("description", "")) self.hidden: bool = kwargs.get("hidden", False) @@ -380,9 +376,7 @@ def __init__( elif isinstance(cooldown, CooldownMapping): buckets = cooldown else: - raise TypeError( - "Cooldown must be a an instance of CooldownMapping or None." - ) + raise TypeError("Cooldown must be a an instance of CooldownMapping or None.") self._buckets: CooldownMapping = buckets try: @@ -420,10 +414,7 @@ def __init__( @property def callback( self, - ) -> Union[ - Callable[Concatenate[CogT, Context, P], Coro[T]], - Callable[Concatenate[Context, P], Coro[T]], - ]: + ) -> Union[Callable[Concatenate[CogT, Context, P], Coro[T]], Callable[Concatenate[Context, P], Coro[T]],]: return self._callback @callback.setter @@ -569,9 +560,7 @@ async def dispatch_error(self, ctx: Context, error: Exception) -> None: async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: required = param.default is param.empty converter = get_converter(param) - consume_rest_is_special = ( - param.kind == param.KEYWORD_ONLY and not self.rest_is_raw - ) + consume_rest_is_special = param.kind == param.KEYWORD_ONLY and not self.rest_is_raw view = ctx.view view.skip_ws() @@ -579,13 +568,9 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: # it undos the view ready for the next parameter to use instead if isinstance(converter, Greedy): if param.kind in (param.POSITIONAL_OR_KEYWORD, param.POSITIONAL_ONLY): - return await self._transform_greedy_pos( - ctx, param, required, converter.converter - ) + return await self._transform_greedy_pos(ctx, param, required, converter.converter) elif param.kind == param.VAR_POSITIONAL: - return await self._transform_greedy_var_pos( - ctx, param, converter.converter - ) + return await self._transform_greedy_var_pos(ctx, param, converter.converter) else: # if we're here, then it's a KEYWORD_ONLY param type # since this is mostly useless, we'll helpfully transform Greedy[X] @@ -598,10 +583,7 @@ async def transform(self, ctx: Context, param: inspect.Parameter) -> Any: if required: if self._is_typing_optional(param.annotation): return None - if ( - hasattr(converter, "__commands_is_flag__") - and converter._can_be_constructible() - ): + if hasattr(converter, "__commands_is_flag__") and converter._can_be_constructible(): return await converter._construct_default(ctx) raise MissingRequiredArgument(param) return param.default @@ -645,9 +627,7 @@ async def _transform_greedy_pos( return param.default return result - async def _transform_greedy_var_pos( - self, ctx: Context, param: inspect.Parameter, converter: Any - ) -> Any: + async def _transform_greedy_var_pos(self, ctx: Context, param: inspect.Parameter, converter: Any) -> Any: view = ctx.view previous = view.index try: @@ -761,17 +741,13 @@ async def _parse_arguments(self, ctx: Context) -> None: try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "self" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "self" parameter.') # next we have the 'ctx' as the next parameter try: next(iterator) except StopIteration: - raise discord.ClientException( - f'Callback for {self.name} command is missing "ctx" parameter.' - ) + raise discord.ClientException(f'Callback for {self.name} command is missing "ctx" parameter.') for name, param in iterator: ctx.current_parameter = param @@ -798,9 +774,7 @@ async def _parse_arguments(self, ctx: Context) -> None: break if not self.ignore_extra and not view.eof: - raise TooManyArguments( - f"Too many arguments passed to {self.qualified_name}" - ) + raise TooManyArguments(f"Too many arguments passed to {self.qualified_name}") async def call_before_hooks(self, ctx: Context) -> None: # now that we're done preparing we can call the pre-command hooks @@ -860,9 +834,7 @@ async def prepare(self, ctx: Context) -> None: ctx.command = self if not await self.can_run(ctx): - raise CheckFailure( - f"The check functions for command {self.qualified_name} failed." - ) + raise CheckFailure(f"The check functions for command {self.qualified_name} failed.") if self._max_concurrency is not None: # For this application, context can be duck-typed as a Message @@ -1075,12 +1047,9 @@ def short_doc(self) -> str: return self.help.split("\n", 1)[0] return "" - def _is_typing_optional( - self, annotation: Union[T, Optional[T]] - ) -> TypeGuard[Optional[T]]: + def _is_typing_optional(self, annotation: Union[T, Optional[T]]) -> TypeGuard[Optional[T]]: return ( - getattr(annotation, "__origin__", None) is Union - or type(annotation) is getattr(types, "UnionType", Union) + getattr(annotation, "__origin__", None) is Union or type(annotation) is getattr(types, "UnionType", Union) ) and type( None ) in annotation.__args__ # type: ignore @@ -1113,24 +1082,13 @@ def signature(self) -> str: origin = getattr(annotation, "__origin__", None) if origin is Literal: - name = "|".join( - f'"{v}"' if isinstance(v, str) else str(v) - for v in annotation.__args__ - ) + name = "|".join(f'"{v}"' if isinstance(v, str) else str(v) for v in annotation.__args__) if param.default is not param.empty: # We don't want None or '' to trigger the [name=value] case and instead it should # do [name] since [name=None] or [name=] are not exactly useful for the user. - should_print = ( - param.default - if isinstance(param.default, str) - else param.default is not None - ) + should_print = param.default if isinstance(param.default, str) else param.default is not None if should_print: - result.append( - f"[{name}={param.default}]" - if not greedy - else f"[{name}={param.default}]..." - ) + result.append(f"[{name}={param.default}]" if not greedy else f"[{name}={param.default}]...") continue else: result.append(f"[{name}]") @@ -1184,9 +1142,7 @@ async def can_run(self, ctx: Context) -> bool: try: if not await ctx.bot.can_run(ctx): - raise CheckFailure( - f"The global check functions for command {self.qualified_name} failed." - ) + raise CheckFailure(f"The global check functions for command {self.qualified_name} failed.") cog = self.cog if cog is not None: @@ -1224,9 +1180,7 @@ class GroupMixin(Generic[CogT]): def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get("case_insensitive", False) - self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = ( - _CaseInsensitiveDict() if case_insensitive else {} - ) + self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) @@ -1893,9 +1847,7 @@ async def only_for_owners(ctx): try: pred = wrapped.predicate except AttributeError: - raise TypeError( - f"{wrapped!r} must be wrapped by commands.check decorator" - ) from None + raise TypeError(f"{wrapped!r} must be wrapped by commands.check decorator") from None else: unwrapped.append(pred) @@ -1997,10 +1949,7 @@ def predicate(ctx): # ctx.guild is None doesn't narrow ctx.author to Member getter = functools.partial(discord.utils.get, ctx.author.roles) # type: ignore if any( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - for item in items + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items ): return True raise MissingAnyRole(list(items)) @@ -2059,10 +2008,7 @@ def predicate(ctx): me = ctx.me getter = functools.partial(discord.utils.get, me.roles) if any( - getter(id=item) is not None - if isinstance(item, int) - else getter(name=item) is not None - for item in items + getter(id=item) is not None if isinstance(item, int) else getter(name=item) is not None for item in items ): return True raise BotMissingAnyRole(list(items)) @@ -2108,9 +2054,7 @@ def predicate(ctx: Context) -> bool: ch = ctx.channel permissions = ch.permissions_for(ctx.author) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2137,9 +2081,7 @@ def predicate(ctx: Context) -> bool: me = guild.me if guild is not None else ctx.bot.user permissions = ctx.channel.permissions_for(me) # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2168,9 +2110,7 @@ def predicate(ctx: Context) -> bool: raise NoPrivateMessage permissions = ctx.author.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2196,9 +2136,7 @@ def predicate(ctx: Context) -> bool: raise NoPrivateMessage permissions = ctx.me.guild_permissions # type: ignore - missing = [ - perm for perm, value in perms.items() if getattr(permissions, perm) != value - ] + missing = [perm for perm, value in perms.items() if getattr(permissions, perm) != value] if not missing: return True @@ -2276,9 +2214,7 @@ def is_nsfw() -> Callable[[T], T]: def pred(ctx: Context) -> bool: ch = ctx.channel - if ctx.guild is None or ( - isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw() - ): + if ctx.guild is None or (isinstance(ch, (discord.TextChannel, discord.Thread)) and ch.is_nsfw()): return True raise NSFWChannelRequired(ch) # type: ignore @@ -2371,9 +2307,7 @@ def decorator(func: Union[Command, CoroFunc]) -> Union[Command, CoroFunc]: return decorator # type: ignore -def max_concurrency( - number: int, per: BucketType = BucketType.default, *, wait: bool = False -) -> Callable[[T], T]: +def max_concurrency(number: int, per: BucketType = BucketType.default, *, wait: bool = False) -> Callable[[T], T]: """A decorator that adds a maximum concurrency to a command This enables you to only allow a certain number of command invocations at the same time, diff --git a/discord/ext/commands/errors.py b/discord/ext/commands/errors.py index b43dae5529..1b463a460f 100644 --- a/discord/ext/commands/errors.py +++ b/discord/ext/commands/errors.py @@ -110,9 +110,7 @@ class CommandError(DiscordException): def __init__(self, message: Optional[str] = None, *args: Any) -> None: if message is not None: # clean-up @everyone and @here mentions - m = message.replace("@everyone", "@\u200beveryone").replace( - "@here", "@\u200bhere" - ) + m = message.replace("@everyone", "@\u200beveryone").replace("@here", "@\u200bhere") super().__init__(m, *args) else: super().__init__(*args) @@ -221,9 +219,7 @@ class CheckAnyFailure(CheckFailure): A list of check predicates that failed. """ - def __init__( - self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]] - ) -> None: + def __init__(self, checks: List[CheckFailure], errors: List[Callable[[Context], bool]]) -> None: self.checks: List[CheckFailure] = checks self.errors: List[Callable[[Context], bool]] = errors super().__init__("You do not have permission to run this command.") @@ -237,9 +233,7 @@ class PrivateMessageOnly(CheckFailure): """ def __init__(self, message: Optional[str] = None) -> None: - super().__init__( - message or "This command can only be used in private messages." - ) + super().__init__(message or "This command can only be used in private messages.") class NoPrivateMessage(CheckFailure): @@ -577,9 +571,7 @@ class CommandOnCooldown(CommandError): The amount of seconds to wait before you can retry again. """ - def __init__( - self, cooldown: Cooldown, retry_after: float, type: BucketType - ) -> None: + def __init__(self, cooldown: Cooldown, retry_after: float, type: BucketType) -> None: self.cooldown: Cooldown = cooldown self.retry_after: float = retry_after self.type: BucketType = type @@ -606,9 +598,7 @@ def __init__(self, number: int, per: BucketType) -> None: suffix = f"per {name}" if per.name != "default" else "globally" plural = "%s times %s" if number > 1 else "%s time %s" fmt = plural % (number, suffix) - super().__init__( - f"Too many people are using this command. It can only be used {fmt} concurrently." - ) + super().__init__(f"Too many people are using this command. It can only be used {fmt} concurrently.") class MissingRole(CheckFailure): @@ -725,9 +715,7 @@ class NSFWChannelRequired(CheckFailure): def __init__(self, channel: Union[GuildChannel, Thread]) -> None: self.channel: Union[GuildChannel, Thread] = channel - super().__init__( - f"Channel '{channel}' needs to be NSFW for this command to work." - ) + super().__init__(f"Channel '{channel}' needs to be NSFW for this command to work.") class MissingPermissions(CheckFailure): @@ -745,10 +733,7 @@ class MissingPermissions(CheckFailure): def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [ - perm.replace("_", " ").replace("guild", "server").title() - for perm in missing_permissions - ] + missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions] if len(missing) > 2: fmt = f"{', '.join(missing[:-1])}, and {missing[-1]}" @@ -773,10 +758,7 @@ class BotMissingPermissions(CheckFailure): def __init__(self, missing_permissions: List[str], *args: Any) -> None: self.missing_permissions: List[str] = missing_permissions - missing = [ - perm.replace("_", " ").replace("guild", "server").title() - for perm in missing_permissions - ] + missing = [perm.replace("_", " ").replace("guild", "server").title() for perm in missing_permissions] if len(missing) > 2: fmt = f"{', '.join(missing[:-1])}, and {missing[-1]}" @@ -802,9 +784,7 @@ class BadUnionArgument(UserInputError): A list of errors that were caught from failing the conversion. """ - def __init__( - self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError] - ) -> None: + def __init__(self, param: Parameter, converters: Tuple[Type, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.converters: Tuple[Type, ...] = converters self.errors: List[CommandError] = errors @@ -844,9 +824,7 @@ class BadLiteralArgument(UserInputError): A list of errors that were caught from failing the conversion. """ - def __init__( - self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError] - ) -> None: + def __init__(self, param: Parameter, literals: Tuple[Any, ...], errors: List[CommandError]) -> None: self.param: Parameter = param self.literals: Tuple[Any, ...] = literals self.errors: List[CommandError] = errors @@ -902,9 +880,7 @@ class InvalidEndOfQuotedStringError(ArgumentParsingError): def __init__(self, char: str) -> None: self.char: str = char - super().__init__( - f"Expected space after closing quotation but received {char!r}" - ) + super().__init__(f"Expected space after closing quotation but received {char!r}") class ExpectedClosingQuoteError(ArgumentParsingError): @@ -975,9 +951,7 @@ class TooManyFlags(FlagError): def __init__(self, flag: Flag, values: List[str]) -> None: self.flag: Flag = flag self.values: List[str] = values - super().__init__( - f"Too many flag values, expected {flag.max_args} but received {len(values)}." - ) + super().__init__(f"Too many flag values, expected {flag.max_args} but received {len(values)}.") class BadFlagArgument(FlagError): diff --git a/discord/ext/commands/flags.py b/discord/ext/commands/flags.py index e607a4d004..d5c4718078 100644 --- a/discord/ext/commands/flags.py +++ b/discord/ext/commands/flags.py @@ -161,14 +161,10 @@ def validate_flag_name(name: str, forbidden: Set[str]): if ch == "\\": raise ValueError(f"flag name {name!r} cannot have backslashes") if ch in forbidden: - raise ValueError( - f"flag name {name!r} cannot have any of {forbidden!r} within them" - ) + raise ValueError(f"flag name {name!r} cannot have any of {forbidden!r} within them") -def get_flags( - namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any] -) -> Dict[str, Flag]: +def get_flags(namespace: Dict[str, Any], globals: Dict[str, Any], locals: Dict[str, Any]) -> Dict[str, Flag]: annotations = namespace.get("__annotations__", {}) case_insensitive = namespace["__commands_flag_case_insensitive__"] flags: Dict[str, Flag] = {} @@ -185,9 +181,7 @@ def get_flags( if flag.name is MISSING: flag.name = name - annotation = flag.annotation = resolve_annotation( - flag.annotation, globals, locals, cache - ) + annotation = flag.annotation = resolve_annotation(flag.annotation, globals, locals, cache) if ( flag.default is MISSING @@ -244,9 +238,7 @@ def get_flags( if flag.max_args is MISSING: flag.max_args = 1 else: - raise TypeError( - f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag" - ) + raise TypeError(f"Unsupported typing annotation {annotation!r} for {flag.name!r} flag") if flag.override is MISSING: flag.override = False @@ -254,9 +246,7 @@ def get_flags( # Validate flag names are unique name = flag.name.casefold() if case_insensitive else flag.name if name in names: - raise TypeError( - f"{flag.name!r} flag conflicts with previous flag or alias." - ) + raise TypeError(f"{flag.name!r} flag conflicts with previous flag or alias.") else: names.add(name) @@ -264,9 +254,7 @@ def get_flags( # Validate alias is unique alias = alias.casefold() if case_insensitive else alias if alias in names: - raise TypeError( - f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias." - ) + raise TypeError(f"{flag.name!r} flag alias {alias!r} conflicts with previous flag or alias.") else: names.add(alias) @@ -321,17 +309,11 @@ def __new__( flags.update(base.__dict__["__commands_flags__"]) aliases.update(base.__dict__["__commands_flag_aliases__"]) if case_insensitive is MISSING: - attrs["__commands_flag_case_insensitive__"] = base.__dict__[ - "__commands_flag_case_insensitive__" - ] + attrs["__commands_flag_case_insensitive__"] = base.__dict__["__commands_flag_case_insensitive__"] if delimiter is MISSING: - attrs["__commands_flag_delimiter__"] = base.__dict__[ - "__commands_flag_delimiter__" - ] + attrs["__commands_flag_delimiter__"] = base.__dict__["__commands_flag_delimiter__"] if prefix is MISSING: - attrs["__commands_flag_prefix__"] = base.__dict__[ - "__commands_flag_prefix__" - ] + attrs["__commands_flag_prefix__"] = base.__dict__["__commands_flag_prefix__"] if case_insensitive is not MISSING: attrs["__commands_flag_case_insensitive__"] = case_insensitive @@ -357,9 +339,7 @@ def __new__( regex_flags = 0 if case_insensitive: flags = {key.casefold(): value for key, value in flags.items()} - aliases = { - key.casefold(): value.casefold() for key, value in aliases.items() - } + aliases = {key.casefold(): value.casefold() for key, value in aliases.items()} regex_flags = re.IGNORECASE keys = list(re.escape(k) for k in flags) @@ -378,9 +358,7 @@ def __new__( return type.__new__(cls, name, bases, attrs) -async def tuple_convert_all( - ctx: Context, argument: str, flag: Flag, converter: Any -) -> Tuple[Any, ...]: +async def tuple_convert_all(ctx: Context, argument: str, flag: Flag, converter: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -405,9 +383,7 @@ async def tuple_convert_all( return tuple(results) -async def tuple_convert_flag( - ctx: Context, argument: str, flag: Flag, converters: Any -) -> Tuple[Any, ...]: +async def tuple_convert_flag(ctx: Context, argument: str, flag: Flag, converters: Any) -> Tuple[Any, ...]: view = StringView(argument) results = [] param: inspect.Parameter = ctx.current_parameter # type: ignore @@ -445,13 +421,9 @@ async def convert_flag(ctx, argument: str, flag: Flag, annotation: Any = None) - else: if origin is tuple: if annotation.__args__[-1] is Ellipsis: - return await tuple_convert_all( - ctx, argument, flag, annotation.__args__[0] - ) + return await tuple_convert_all(ctx, argument, flag, annotation.__args__[0]) else: - return await tuple_convert_flag( - ctx, argument, flag, annotation.__args__ - ) + return await tuple_convert_flag(ctx, argument, flag, annotation.__args__) elif origin is list: # typing.List[x] annotation = annotation.__args__[0] @@ -533,12 +505,7 @@ async def _construct_default(cls: Type[F], ctx: Context) -> F: return self def __repr__(self) -> str: - pairs = " ".join( - [ - f"{flag.attribute}={getattr(self, flag.attribute)!r}" - for flag in self.get_flags().values() - ] - ) + pairs = " ".join([f"{flag.attribute}={getattr(self, flag.attribute)!r}" for flag in self.get_flags().values()]) return f"<{self.__class__.__name__} {pairs}>" @classmethod diff --git a/discord/ext/commands/help.py b/discord/ext/commands/help.py index b9df9fe593..73b65c9145 100644 --- a/discord/ext/commands/help.py +++ b/discord/ext/commands/help.py @@ -136,16 +136,11 @@ def add_line(self, line="", *, empty=False): RuntimeError The line was too big for the current :attr:`max_size`. """ - max_page_size = ( - self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len - ) + max_page_size = self.max_size - self._prefix_len - self._suffix_len - 2 * self._linesep_len if len(line) > max_page_size: raise RuntimeError(f"Line exceeds maximum page size {max_page_size}") - if ( - self._count + len(line) + self._linesep_len - > self.max_size - self._suffix_len - ): + if self._count + len(line) + self._linesep_len > self.max_size - self._suffix_len: self.close_page() self._count += len(line) + self._linesep_len @@ -403,11 +398,7 @@ def invoked_with(self): """ command_name = self._command_impl.name ctx = self.context - if ( - ctx is None - or ctx.command is None - or ctx.command.qualified_name != command_name - ): + if ctx is None or ctx.command is None or ctx.command.qualified_name != command_name: return command_name return ctx.invoked_with @@ -536,9 +527,7 @@ def subcommand_not_found(self, command, string): The string to use when the command did not have the subcommand requested. """ if isinstance(command, Group) and len(command.all_commands) > 0: - return ( - f'Command "{command.qualified_name}" has no subcommand named {string}' - ) + return f'Command "{command.qualified_name}" has no subcommand named {string}' return f'Command "{command.qualified_name}" has no subcommands.' async def filter_commands(self, commands, *, sort=False, key=None): @@ -571,15 +560,9 @@ async def filter_commands(self, commands, *, sort=False, key=None): # Ignore Application Commands cause they dont have hidden/docs prefix_commands = [ - command - for command in commands - if not isinstance(command, discord.commands.ApplicationCommand) + command for command in commands if not isinstance(command, discord.commands.ApplicationCommand) ] - iterator = ( - prefix_commands - if self.show_hidden - else filter(lambda c: not c.hidden, prefix_commands) - ) + iterator = prefix_commands if self.show_hidden else filter(lambda c: not c.hidden, prefix_commands) if self.verify_checks is False: # if we do not need to verify the checks then we can just @@ -870,24 +853,18 @@ async def command_callback(self, ctx, *, command=None): keys = command.split(" ") cmd = bot.all_commands.get(keys[0]) if cmd is None: - string = await maybe_coro( - self.command_not_found, self.remove_mentions(keys[0]) - ) + string = await maybe_coro(self.command_not_found, self.remove_mentions(keys[0])) return await self.send_error_message(string) for key in keys[1:]: try: found = cmd.all_commands.get(key) except AttributeError: - string = await maybe_coro( - self.subcommand_not_found, cmd, self.remove_mentions(key) - ) + string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) else: if found is None: - string = await maybe_coro( - self.subcommand_not_found, cmd, self.remove_mentions(key) - ) + string = await maybe_coro(self.subcommand_not_found, cmd, self.remove_mentions(key)) return await self.send_error_message(string) cmd = found @@ -1060,11 +1037,7 @@ def get_category(command, *, no_category=no_category): # Now we can add the commands to the page. for category, commands in to_iterate: - commands = ( - sorted(commands, key=lambda c: c.name) - if self.sort_commands - else list(commands) - ) + commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands) self.add_indented_commands(commands, heading=category, max_size=max_size) note = self.get_ending_note() @@ -1097,9 +1070,7 @@ async def send_cog_help(self, cog): if cog.description: self.paginator.add_line(cog.description, empty=True) - filtered = await self.filter_commands( - cog.get_commands(), sort=self.sort_commands - ) + filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) self.add_indented_commands(filtered, heading=self.commands_heading) note = self.get_ending_note() @@ -1182,9 +1153,7 @@ def get_opening_note(self): ) def get_command_signature(self, command): - return ( - f"{self.context.clean_prefix}{command.qualified_name} {command.signature}" - ) + return f"{self.context.clean_prefix}{command.qualified_name} {command.signature}" def get_ending_note(self): """Return the help command's ending note. This is mainly useful to override for i18n purposes. @@ -1233,11 +1202,7 @@ def add_subcommand_formatting(self, command): The command to show information of. """ fmt = "{0}{1} \N{EN DASH} {2}" if command.short_doc else "{0}{1}" - self.paginator.add_line( - fmt.format( - self.context.clean_prefix, command.qualified_name, command.short_doc - ) - ) + self.paginator.add_line(fmt.format(self.context.clean_prefix, command.qualified_name, command.short_doc)) def add_aliases_formatting(self, aliases): """Adds the formatting information on a command's aliases. @@ -1254,9 +1219,7 @@ def add_aliases_formatting(self, aliases): aliases: Sequence[:class:`str`] A list of aliases to format. """ - self.paginator.add_line( - f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True - ) + self.paginator.add_line(f'**{self.aliases_heading}** {", ".join(aliases)}', empty=True) def add_command_formatting(self, command): """A utility function to format commands and groups. @@ -1319,11 +1282,7 @@ def get_category(command, *, no_category=no_category): to_iterate = itertools.groupby(filtered, key=get_category) for category, commands in to_iterate: - commands = ( - sorted(commands, key=lambda c: c.name) - if self.sort_commands - else list(commands) - ) + commands = sorted(commands, key=lambda c: c.name) if self.sort_commands else list(commands) self.add_bot_commands_formatting(commands, category) note = self.get_ending_note() @@ -1345,9 +1304,7 @@ async def send_cog_help(self, cog): if cog.description: self.paginator.add_line(cog.description, empty=True) - filtered = await self.filter_commands( - cog.get_commands(), sort=self.sort_commands - ) + filtered = await self.filter_commands(cog.get_commands(), sort=self.sort_commands) if filtered: self.paginator.add_line(f"**{cog.qualified_name} {self.commands_heading}**") for command in filtered: diff --git a/discord/ext/pages/pagination.py b/discord/ext/pages/pagination.py index b5d0e11c7c..f4bc79f12d 100644 --- a/discord/ext/pages/pagination.py +++ b/discord/ext/pages/pagination.py @@ -95,10 +95,7 @@ async def callback(self, interaction: discord.Interaction): else: self.paginator.current_page -= 1 elif self.button_type == "next": - if ( - self.paginator.loop_pages - and self.paginator.current_page == self.paginator.page_count - ): + if self.paginator.loop_pages and self.paginator.current_page == self.paginator.page_count: self.paginator.current_page = 0 else: self.paginator.current_page += 1 @@ -171,9 +168,7 @@ def __init__( self.label = label self.description = description self.emoji: Union[str, discord.Emoji, discord.PartialEmoji] = emoji - self.pages: Union[ - List[str], List[Union[List[discord.Embed], discord.Embed]] - ] = pages + self.pages: Union[List[str], List[Union[List[discord.Embed], discord.Embed]]] = pages self.show_disabled = show_disabled self.show_indicator = show_indicator self.author_check = author_check @@ -238,9 +233,7 @@ class Paginator(discord.ui.View): def __init__( self, - pages: Union[ - List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]] - ], + pages: Union[List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]]], show_disabled: bool = True, show_indicator=True, show_menu=False, @@ -255,9 +248,7 @@ def __init__( ) -> None: super().__init__(timeout=timeout) self.timeout: float = timeout - self.pages: Union[ - List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]] - ] = pages + self.pages: Union[List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]]] = pages self.current_page = 0 self.menu: Optional[PaginatorMenu] = None self.show_menu = show_menu @@ -265,9 +256,7 @@ def __init__( if all(isinstance(pg, PageGroup) for pg in pages): self.page_groups = self.pages if show_menu else None - self.pages: Union[ - List[str], List[Union[List[discord.Embed], discord.Embed]] - ] = self.page_groups[0].pages + self.pages: Union[List[str], List[Union[List[discord.Embed], discord.Embed]]] = self.page_groups[0].pages self.page_count = len(self.pages) - 1 self.buttons = {} @@ -295,9 +284,7 @@ def __init__( async def update( self, - pages: Optional[ - Union[List[str], List[Union[List[discord.Embed], discord.Embed]]] - ] = None, + pages: Optional[Union[List[str], List[Union[List[discord.Embed], discord.Embed]]]] = None, show_disabled: Optional[bool] = None, show_indicator: Optional[bool] = None, author_check: Optional[bool] = None, @@ -339,34 +326,18 @@ async def update( """ # Update pages and reset current_page to 0 (default) - self.pages: Union[ - List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]] - ] = (pages if pages is not None else self.pages) + self.pages: Union[List[PageGroup], List[str], List[Union[List[discord.Embed], discord.Embed]]] = ( + pages if pages is not None else self.pages + ) self.page_count = len(self.pages) - 1 self.current_page = 0 # Apply config changes, if specified - self.show_disabled = ( - show_disabled if show_disabled is not None else self.show_disabled - ) - self.show_indicator = ( - show_indicator if show_indicator is not None else self.show_indicator - ) + self.show_disabled = show_disabled if show_disabled is not None else self.show_disabled + self.show_indicator = show_indicator if show_indicator is not None else self.show_indicator self.usercheck = author_check if author_check is not None else self.usercheck - self.disable_on_timeout = ( - disable_on_timeout - if disable_on_timeout is not None - else self.disable_on_timeout - ) - self.use_default_buttons = ( - use_default_buttons - if use_default_buttons is not None - else self.use_default_buttons - ) - self.default_button_row = ( - default_button_row - if default_button_row is not None - else self.default_button_row - ) + self.disable_on_timeout = disable_on_timeout if disable_on_timeout is not None else self.disable_on_timeout + self.use_default_buttons = use_default_buttons if use_default_buttons is not None else self.use_default_buttons + self.default_button_row = default_button_row if default_button_row is not None else self.default_button_row self.loop_pages = loop_pages if loop_pages is not None else self.loop_pages self.custom_view: discord.ui.View = None if custom_view is None else custom_view self.timeout: float = timeout if timeout is not None else self.timeout @@ -462,9 +433,7 @@ async def goto_page(self, page_number=0) -> discord.Message: self.update_buttons() self.current_page = page_number if self.show_indicator: - self.buttons["page_indicator"][ - "object" - ].label = f"{self.current_page + 1}/{self.page_count + 1}" + self.buttons["page_indicator"]["object"].label = f"{self.current_page + 1}/{self.page_count + 1}" page = self.pages[page_number] page = self.get_page_content(page) @@ -543,9 +512,7 @@ def add_button(self, button: PaginatorButton): ), "label": button.label, "loop_label": button.loop_label, - "hidden": button.disabled - if button.button_type != "page_indicator" - else not self.show_indicator, + "hidden": button.disabled if button.button_type != "page_indicator" else not self.show_indicator, } self.buttons[button.button_type]["object"].callback = button.callback button.paginator = self @@ -553,9 +520,7 @@ def add_button(self, button: PaginatorButton): def remove_button(self, button_type: str): """Removes a :class:`PaginatorButton` from the paginator.""" if button_type not in self.buttons.keys(): - raise ValueError( - f"no button_type {button_type} was found in this paginator." - ) + raise ValueError(f"no button_type {button_type} was found in this paginator.") self.buttons.pop(button_type) def update_buttons(self) -> Dict: @@ -599,9 +564,7 @@ def update_buttons(self) -> Dict: button["object"].label = button["label"] self.clear_items() if self.show_indicator: - self.buttons["page_indicator"][ - "object" - ].label = f"{self.current_page + 1}/{self.page_count + 1}" + self.buttons["page_indicator"]["object"].label = f"{self.current_page + 1}/{self.page_count + 1}" for key, button in self.buttons.items(): if key != "page_indicator": if button["hidden"]: @@ -785,9 +748,7 @@ def __init__( ) for page_group in self.page_groups ] - super().__init__( - placeholder=placeholder, max_values=1, min_values=1, options=opts - ) + super().__init__(placeholder=placeholder, max_values=1, min_values=1, options=opts) async def callback(self, interaction: discord.Interaction): selection = self.values[0] diff --git a/discord/ext/tasks/__init__.py b/discord/ext/tasks/__init__.py index e8ec272c10..5d7cccb2cd 100644 --- a/discord/ext/tasks/__init__.py +++ b/discord/ext/tasks/__init__.py @@ -62,9 +62,7 @@ class SleepHandle: __slots__ = ("future", "loop", "handle") - def __init__( - self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop - ) -> None: + def __init__(self, dt: datetime.datetime, *, loop: asyncio.AbstractEventLoop) -> None: self.loop = loop self.future = future = loop.create_future() relative_delta = discord.utils.compute_timedelta(dt) @@ -136,9 +134,7 @@ def __init__( self._next_iteration = None if not inspect.iscoroutinefunction(self.coro): - raise TypeError( - f"Expected coroutine function, not {type(self.coro).__name__!r}." - ) + raise TypeError(f"Expected coroutine function, not {type(self.coro).__name__!r}.") async def _call_loop_function(self, name: str, *args: Any, **kwargs: Any) -> None: coro = getattr(self, f"_{name}") @@ -368,9 +364,7 @@ def stop(self) -> None: self._stop_next_iteration = True def _can_be_cancelled(self) -> bool: - return bool( - not self._is_being_cancelled and self._task and not self._task.done() - ) + return bool(not self._is_being_cancelled and self._task and not self._task.done()) def cancel(self) -> None: """Cancels the internal task, if it is running.""" @@ -393,9 +387,7 @@ def restart(self, *args: Any, **kwargs: Any) -> None: The keyword arguments to use. """ - def restart_when_over( - fut: Any, *, args: Any = args, kwargs: Any = kwargs - ) -> None: + def restart_when_over(fut: Any, *, args: Any = args, kwargs: Any = kwargs) -> None: self._task.remove_done_callback(restart_when_over) self.start(*args, **kwargs) @@ -455,9 +447,7 @@ def remove_exception_type(self, *exceptions: Type[BaseException]) -> bool: Whether all exceptions were successfully removed. """ old_length = len(self._valid_exception) - self._valid_exception = tuple( - x for x in self._valid_exception if x not in exceptions - ) + self._valid_exception = tuple(x for x in self._valid_exception if x not in exceptions) return len(self._valid_exception) == old_length - len(exceptions) def get_task(self) -> Optional[asyncio.Task[None]]: @@ -488,9 +478,7 @@ async def _error(self, *args: Any) -> None: f"Unhandled exception in internal background task {self.coro.__name__!r}.", file=sys.stderr, ) - traceback.print_exception( - type(exception), exception, exception.__traceback__, file=sys.stderr - ) + traceback.print_exception(type(exception), exception, exception.__traceback__, file=sys.stderr) def before_loop(self, coro: FT) -> FT: """A decorator that registers a coroutine to be called before the loop starts running. @@ -512,9 +500,7 @@ def before_loop(self, coro: FT) -> FT: """ if not inspect.iscoroutinefunction(coro): - raise TypeError( - f"Expected coroutine function, received {coro.__class__.__name__!r}." - ) + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._before_loop = coro return coro @@ -542,9 +528,7 @@ def after_loop(self, coro: FT) -> FT: """ if not inspect.iscoroutinefunction(coro): - raise TypeError( - f"Expected coroutine function, received {coro.__class__.__name__!r}." - ) + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._after_loop = coro return coro @@ -570,9 +554,7 @@ def error(self, coro: ET) -> ET: The function was not a coroutine. """ if not inspect.iscoroutinefunction(coro): - raise TypeError( - f"Expected coroutine function, received {coro.__class__.__name__!r}." - ) + raise TypeError(f"Expected coroutine function, received {coro.__class__.__name__!r}.") self._error = coro # type: ignore return coro @@ -586,8 +568,7 @@ def _get_next_sleep_time(self) -> datetime.datetime: if self._current_loop == 0: # if we're at the last index on the first iteration, we need to sleep until tomorrow return datetime.datetime.combine( - datetime.datetime.now(datetime.timezone.utc) - + datetime.timedelta(days=1), + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), self._time[0], ) @@ -596,13 +577,10 @@ def _get_next_sleep_time(self) -> datetime.datetime: if self._current_loop == 0: self._time_index += 1 if next_time > datetime.datetime.now(datetime.timezone.utc).timetz(): - return datetime.datetime.combine( - datetime.datetime.now(datetime.timezone.utc), next_time - ) + return datetime.datetime.combine(datetime.datetime.now(datetime.timezone.utc), next_time) else: return datetime.datetime.combine( - datetime.datetime.now(datetime.timezone.utc) - + datetime.timedelta(days=1), + datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(days=1), next_time, ) @@ -619,9 +597,7 @@ def _prepare_time_index(self, now: datetime.datetime = MISSING) -> None: # pre-condition: self._time is set time_now = ( - now - if now is not MISSING - else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) + now if now is not MISSING else datetime.datetime.now(datetime.timezone.utc).replace(microsecond=0) ).timetz() for idx, time in enumerate(self._time): if time >= time_now: @@ -716,9 +692,7 @@ def change_interval( self._time = self._get_time_parameter(time) self._sleep = self._seconds = self._minutes = self._hours = MISSING - if self.is_running() and not ( - self._before_loop_running or self._after_loop_running - ): + if self.is_running() and not (self._before_loop_running or self._after_loop_running): if self._time is not MISSING: # prepare the next time index starting from after the last iteration self._prepare_time_index(now=self._last_iteration) diff --git a/discord/file.py b/discord/file.py index 9b6e0d1b2f..a9c5116b08 100644 --- a/discord/file.py +++ b/discord/file.py @@ -115,16 +115,10 @@ def __init__( else: self.filename = filename - if ( - spoiler - and self.filename is not None - and not self.filename.startswith("SPOILER_") - ): + if spoiler and self.filename is not None and not self.filename.startswith("SPOILER_"): self.filename = f"SPOILER_{self.filename}" - self.spoiler = spoiler or ( - self.filename is not None and self.filename.startswith("SPOILER_") - ) + self.spoiler = spoiler or (self.filename is not None and self.filename.startswith("SPOILER_")) self.description = description def reset(self, *, seek: Union[int, bool] = True) -> None: diff --git a/discord/flags.py b/discord/flags.py index 5bfb20a333..6898cfb81e 100644 --- a/discord/flags.py +++ b/discord/flags.py @@ -85,11 +85,7 @@ class alias_flag_value(flag_value): def fill_with_flags(*, inverted: bool = False): def decorator(cls: Type[BF]): - cls.VALID_FLAGS = { - name: value.flag - for name, value in cls.__dict__.items() - if isinstance(value, flag_value) - } + cls.VALID_FLAGS = {name: value.flag for name, value in cls.__dict__.items() if isinstance(value, flag_value)} if inverted: max_bits = max(cls.VALID_FLAGS.values()).bit_length() @@ -458,11 +454,7 @@ def bot_http_interactions(self): def all(self) -> List[UserFlags]: """List[:class:`UserFlags`]: Returns all public flags the user has.""" - return [ - public_flag - for public_flag in UserFlags - if self._has_flag(public_flag.value) - ] + return [public_flag for public_flag in UserFlags if self._has_flag(public_flag.value)] @fill_with_flags() diff --git a/discord/gateway.py b/discord/gateway.py index 2f68a20d05..1f9094e0e5 100644 --- a/discord/gateway.py +++ b/discord/gateway.py @@ -150,9 +150,7 @@ def run(self): try: f.result() except Exception: - _log.exception( - "An error occurred while stopping the gateway. Ignoring." - ) + _log.exception("An error occurred while stopping the gateway. Ignoring.") finally: self.stop() return @@ -386,9 +384,7 @@ def wait_for(self, event, predicate, result=None): """ future = self.loop.create_future() - entry = EventListener( - event=event, predicate=predicate, result=result, future=future - ) + entry = EventListener(event=event, predicate=predicate, result=result, future=future) self._dispatch_listeners.append(entry) return future @@ -426,9 +422,7 @@ async def identify(self): if state._intents is not None: payload["d"]["intents"] = state._intents.value - await self.call_hooks( - "before_identify", self.shard_id, initial=self._initial_identify - ) + await self.call_hooks("before_identify", self.shard_id, initial=self._initial_identify) await self.send_as_json(payload) _log.info("Shard ID %s has sent the IDENTIFY payload.", self.shard_id) @@ -495,9 +489,7 @@ async def received_message(self, msg, /): if op == self.HELLO: interval = data["heartbeat_interval"] / 1000.0 - self._keep_alive = KeepAliveHandler( - ws=self, interval=interval, shard_id=self.shard_id - ) + self._keep_alive = KeepAliveHandler(ws=self, interval=interval, shard_id=self.shard_id) # send a heartbeat immediately await self.send_as_json(self._keep_alive.get_payload()) self._keep_alive.start() @@ -623,9 +615,7 @@ async def poll_event(self): raise ReconnectWebSocket(self.shard_id) from None else: _log.info("Websocket closed with %s, cannot reconnect.", code) - raise ConnectionClosed( - self.socket, shard_id=self.shard_id, code=code - ) from None + raise ConnectionClosed(self.socket, shard_id=self.shard_id, code=code) from None async def debug_send(self, data, /): await self._rate_limiter.block() @@ -676,9 +666,7 @@ async def change_presence(self, *, activity=None, status=None, since=0.0): _log.debug('Sending "%s" to change status', sent) await self.send(sent) - async def request_chunks( - self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None - ): + async def request_chunks(self, guild_id, query=None, *, limit, user_ids=None, presences=False, nonce=None): payload = { "op": self.REQUEST_MEMBERS, "d": {"guild_id": guild_id, "presences": presences, "limit": limit}, @@ -865,9 +853,7 @@ async def received_message(self, msg): await self.load_secret_key(data) elif op == self.HELLO: interval = data["heartbeat_interval"] / 1000.0 - self._keep_alive = VoiceKeepAliveHandler( - ws=self, interval=min(interval, 5.0) - ) + self._keep_alive = VoiceKeepAliveHandler(ws=self, interval=min(interval, 5.0)) self._keep_alive.start() elif op == self.SPEAKING: @@ -904,9 +890,7 @@ async def initial_connection(self, data): _log.debug("detected ip: %s port: %s", state.ip, state.port) # there *should* always be at least one supported mode (xsalsa20_poly1305) - modes = [ - mode for mode in data["modes"] if mode in self._connection.supported_modes - ] + modes = [mode for mode in data["modes"] if mode in self._connection.supported_modes] _log.debug("received supported encryption modes: %s", ", ".join(modes)) mode = modes[0] diff --git a/discord/guild.py b/discord/guild.py index 103ceb6b37..54b1e7eace 100644 --- a/discord/guild.py +++ b/discord/guild.py @@ -106,9 +106,7 @@ from .webhook import Webhook VocalGuildChannel = Union[VoiceChannel, StageChannel] - GuildChannel = Union[ - VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel - ] + GuildChannel = Union[VoiceChannel, StageChannel, TextChannel, CategoryChannel, StoreChannel] ByCategoryItem = Tuple[Optional[CategoryChannel], List[GuildChannel]] @@ -388,9 +386,7 @@ def _remove_threads_by_channel(self, channel_id: int) -> None: del self._threads[k] def _filter_threads(self, channel_ids: Set[int]) -> Dict[int, Thread]: - to_remove: Dict[int, Thread] = { - k: t for k, t in self._threads.items() if t.parent_id in channel_ids - } + to_remove: Dict[int, Thread] = {k: t for k, t in self._threads.items() if t.parent_id in channel_ids} for k in to_remove: del self._threads[k] return to_remove @@ -470,15 +466,11 @@ def _from_data(self, guild: GuildPayload) -> None: self.name: str = guild.get("name") self.region: VoiceRegion = try_enum(VoiceRegion, guild.get("region")) - self.verification_level: VerificationLevel = try_enum( - VerificationLevel, guild.get("verification_level") - ) + self.verification_level: VerificationLevel = try_enum(VerificationLevel, guild.get("verification_level")) self.default_notifications: NotificationLevel = try_enum( NotificationLevel, guild.get("default_message_notifications") ) - self.explicit_content_filter: ContentFilter = try_enum( - ContentFilter, guild.get("explicit_content_filter", 0) - ) + self.explicit_content_filter: ContentFilter = try_enum(ContentFilter, guild.get("explicit_content_filter", 0)) self.afk_timeout: int = guild.get("afk_timeout") self._icon: Optional[str] = guild.get("icon") self._banner: Optional[str] = guild.get("banner") @@ -491,39 +483,25 @@ def _from_data(self, guild: GuildPayload) -> None: self._roles[role.id] = role self.mfa_level: MFALevel = guild.get("mfa_level") - self.emojis: Tuple[Emoji, ...] = tuple( - map(lambda d: state.store_emoji(self, d), guild.get("emojis", [])) - ) + self.emojis: Tuple[Emoji, ...] = tuple(map(lambda d: state.store_emoji(self, d), guild.get("emojis", []))) self.stickers: Tuple[GuildSticker, ...] = tuple( map(lambda d: state.store_sticker(self, d), guild.get("stickers", [])) ) self.features: List[GuildFeature] = guild.get("features", []) self._splash: Optional[str] = guild.get("splash") - self._system_channel_id: Optional[int] = utils._get_as_snowflake( - guild, "system_channel_id" - ) + self._system_channel_id: Optional[int] = utils._get_as_snowflake(guild, "system_channel_id") self.description: Optional[str] = guild.get("description") self.max_presences: Optional[int] = guild.get("max_presences") self.max_members: Optional[int] = guild.get("max_members") - self.max_video_channel_users: Optional[int] = guild.get( - "max_video_channel_users" - ) + self.max_video_channel_users: Optional[int] = guild.get("max_video_channel_users") self.premium_tier: int = guild.get("premium_tier", 0) - self.premium_subscription_count: int = ( - guild.get("premium_subscription_count") or 0 - ) - self.premium_progress_bar_enabled: bool = ( - guild.get("premium_progress_bar_enabled") or False - ) + self.premium_subscription_count: int = guild.get("premium_subscription_count") or 0 + self.premium_progress_bar_enabled: bool = guild.get("premium_progress_bar_enabled") or False self._system_channel_flags: int = guild.get("system_channel_flags", 0) self.preferred_locale: Optional[str] = guild.get("preferred_locale") self._discovery_splash: Optional[str] = guild.get("discovery_splash") - self._rules_channel_id: Optional[int] = utils._get_as_snowflake( - guild, "rules_channel_id" - ) - self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake( - guild, "public_updates_channel_id" - ) + self._rules_channel_id: Optional[int] = utils._get_as_snowflake(guild, "rules_channel_id") + self._public_updates_channel_id: Optional[int] = utils._get_as_snowflake(guild, "public_updates_channel_id") self.nsfw_level: NSFWLevel = try_enum(NSFWLevel, guild.get("nsfw_level", 0)) self.approximate_presence_count = guild.get("approximate_presence_count") self.approximate_member_count = guild.get("approximate_member_count") @@ -542,22 +520,12 @@ def _from_data(self, guild: GuildPayload) -> None: events = [] for event in guild.get("guild_scheduled_events", []): - creator = ( - None - if not event.get("creator", None) - else self.get_member(event.get("creator_id")) - ) - events.append( - ScheduledEvent( - state=self._state, guild=self, creator=creator, data=event - ) - ) + creator = None if not event.get("creator", None) else self.get_member(event.get("creator_id")) + events.append(ScheduledEvent(state=self._state, guild=self, creator=creator, data=event)) self._scheduled_events_from_list(events) self._sync(guild) - self._large: Optional[bool] = ( - None if member_count is None else self._member_count >= 250 - ) + self._large: Optional[bool] = None if member_count is None else self._member_count >= 250 self.owner_id: Optional[int] = utils._get_as_snowflake(guild, "owner_id") self.afk_channel: Optional[VocalGuildChannel] = self.get_channel(utils._get_as_snowflake(guild, "afk_channel_id")) # type: ignore @@ -709,17 +677,13 @@ def key(t: ByCategoryItem) -> Tuple[Tuple[int, int], List[GuildChannel]]: channels.sort(key=lambda c: (c._sorting_bucket, c.position, c.id)) return as_list - def _resolve_channel( - self, id: Optional[int], / - ) -> Optional[Union[GuildChannel, Thread]]: + def _resolve_channel(self, id: Optional[int], /) -> Optional[Union[GuildChannel, Thread]]: if id is None: return return self._channels.get(id) or self._threads.get(id) - def get_channel_or_thread( - self, channel_id: int, / - ) -> Optional[Union[Thread, GuildChannel]]: + def get_channel_or_thread(self, channel_id: int, /) -> Optional[Union[Thread, GuildChannel]]: """Returns a channel or thread with the given ID. .. versionadded:: 2.0 @@ -824,18 +788,12 @@ def sticker_limit(self) -> int: .. versionadded:: 2.0 """ more_stickers = 60 if "MORE_STICKERS" in self.features else 0 - return max( - more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers - ) + return max(more_stickers, self._PREMIUM_GUILD_LIMITS[self.premium_tier].stickers) @property def bitrate_limit(self) -> float: """:class:`float`: The maximum bitrate for voice channels this guild can have.""" - vip_guild = ( - self._PREMIUM_GUILD_LIMITS[1].bitrate - if "VIP_REGIONS" in self.features - else 96e3 - ) + vip_guild = self._PREMIUM_GUILD_LIMITS[1].bitrate if "VIP_REGIONS" in self.features else 96e3 return max(vip_guild, self._PREMIUM_GUILD_LIMITS[self.premium_tier].bitrate) @property @@ -965,27 +923,21 @@ def banner(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image( - self._state, self.id, self._banner, path="banners" - ) + return Asset._from_guild_image(self._state, self.id, self._banner, path="banners") @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image( - self._state, self.id, self._splash, path="splashes" - ) + return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes") @property def discovery_splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's discovery splash asset, if available.""" if self._discovery_splash is None: return None - return Asset._from_guild_image( - self._state, self.id, self._discovery_splash, path="discovery-splashes" - ) + return Asset._from_guild_image(self._state, self.id, self._discovery_splash, path="discovery-splashes") @property def member_count(self) -> int: @@ -1064,9 +1016,7 @@ def get_member_named(self, name: str, /) -> Optional[Member]: # do the actual lookup and return if found # if it isn't found then we'll do a full name lookup below. - result = utils.get( - members, name=name[:-5], discriminator=potential_discriminator - ) + result = utils.get(members, name=name[:-5], discriminator=potential_discriminator) if result is not None: return result @@ -1091,18 +1041,14 @@ def _create_channel( perms = [] for target, perm in overwrites.items(): if not isinstance(perm, PermissionOverwrite): - raise InvalidArgument( - f"Expected PermissionOverwrite received {perm.__class__.__name__}" - ) + raise InvalidArgument(f"Expected PermissionOverwrite received {perm.__class__.__name__}") allow, deny = perm.pair() payload = { "allow": allow.value, "deny": deny.value, "id": target.id, - "type": abc._Overwrites.ROLE - if isinstance(target, Role) - else abc._Overwrites.MEMBER, + "type": abc._Overwrites.ROLE if isinstance(target, Role) else abc._Overwrites.MEMBER, } perms.append(payload) @@ -1641,15 +1587,11 @@ async def edit( if discovery_splash is None: fields["discovery_splash"] = discovery_splash else: - fields["discovery_splash"] = utils._bytes_to_base64_data( - discovery_splash - ) + fields["discovery_splash"] = utils._bytes_to_base64_data(discovery_splash) if default_notifications is not MISSING: if not isinstance(default_notifications, NotificationLevel): - raise InvalidArgument( - "default_notifications field must be of type NotificationLevel" - ) + raise InvalidArgument("default_notifications field must be of type NotificationLevel") fields["default_message_notifications"] = default_notifications.value if afk_channel is not MISSING: @@ -1678,9 +1620,7 @@ async def edit( if owner is not MISSING: if self.owner_id != self._state.self_id: - raise InvalidArgument( - "To transfer ownership you must be the owner of the guild." - ) + raise InvalidArgument("To transfer ownership you must be the owner of the guild.") fields["owner_id"] = owner.id @@ -1689,35 +1629,26 @@ async def edit( if verification_level is not MISSING: if not isinstance(verification_level, VerificationLevel): - raise InvalidArgument( - "verification_level field must be of type VerificationLevel" - ) + raise InvalidArgument("verification_level field must be of type VerificationLevel") fields["verification_level"] = verification_level.value if explicit_content_filter is not MISSING: if not isinstance(explicit_content_filter, ContentFilter): - raise InvalidArgument( - "explicit_content_filter field must be of type ContentFilter" - ) + raise InvalidArgument("explicit_content_filter field must be of type ContentFilter") fields["explicit_content_filter"] = explicit_content_filter.value if system_channel_flags is not MISSING: if not isinstance(system_channel_flags, SystemChannelFlags): - raise InvalidArgument( - "system_channel_flags field must be of type SystemChannelFlags" - ) + raise InvalidArgument("system_channel_flags field must be of type SystemChannelFlags") fields["system_channel_flags"] = system_channel_flags.value if community is not MISSING: features = [] if community: - if ( - "rules_channel_id" in fields - and "public_updates_channel_id" in fields - ): + if "rules_channel_id" in fields and "public_updates_channel_id" in fields: features.append("COMMUNITY") else: raise InvalidArgument( @@ -1760,9 +1691,7 @@ async def fetch_channels(self) -> Sequence[GuildChannel]: def convert(d): factory, ch_type = _guild_channel_factory(d["type"]) if factory is None: - raise InvalidData( - "Unknown channel type {type} for channel ID {id}.".format_map(d) - ) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(d)) channel = factory(guild=self, state=self._state, data=d) return channel @@ -1789,10 +1718,7 @@ async def active_threads(self) -> List[Thread]: The active threads """ data = await self._state.http.get_active_threads(self.id) - threads = [ - Thread(guild=self, state=self._state, data=d) - for d in data.get("threads", []) - ] + threads = [Thread(guild=self, state=self._state, data=d) for d in data.get("threads", [])] thread_lookup: Dict[int, Thread] = {thread.id: thread for thread in threads} for member in data.get("members", []): thread = thread_lookup.get(int(member["id"])) @@ -1802,9 +1728,7 @@ async def active_threads(self) -> List[Thread]: return threads # TODO: Remove Optional typing here when async iterators are refactored - def fetch_members( - self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None - ) -> MemberIterator: + def fetch_members(self, *, limit: int = 1000, after: Optional[SnowflakeTime] = None) -> MemberIterator: """Retrieves an :class:`.AsyncIterator` that enables receiving the guild's members. In order to use this, :meth:`Intents.members` must be enabled. @@ -1914,9 +1838,7 @@ async def fetch_ban(self, user: Snowflake) -> BanEntry: The :class:`BanEntry` object for the specified user. """ data: BanPayload = await self._state.http.get_ban(user.id, self.id) - return BanEntry( - user=User(state=self._state, data=data["user"]), reason=data["reason"] - ) + return BanEntry(user=User(state=self._state, data=data["user"]), reason=data["reason"]) async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread]: """|coro| @@ -1951,9 +1873,7 @@ async def fetch_channel(self, channel_id: int, /) -> Union[GuildChannel, Thread] factory, ch_type = _threaded_guild_channel_factory(data["type"]) if factory is None: - raise InvalidData( - "Unknown channel type {type} for channel ID {id}.".format_map(data) - ) + raise InvalidData("Unknown channel type {type} for channel ID {id}.".format_map(data)) if ch_type in (ChannelType.group, ChannelType.private): raise InvalidData("Channel ID resolved to a private channel") @@ -1987,10 +1907,7 @@ async def bans(self) -> List[BanEntry]: """ data: List[BanPayload] = await self._state.http.get_bans(self.id) - return [ - BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) - for e in data - ] + return [BanEntry(user=User(state=self._state, data=e["user"]), reason=e["reason"]) for e in data] async def prune_members( self, @@ -2050,9 +1967,7 @@ async def prune_members( """ if not isinstance(days, int): - raise InvalidArgument( - f"Expected int for ``days``, received {days.__class__.__name__} instead." - ) + raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.") role_ids = [str(role.id) for role in roles] if roles else [] data = await self._state.http.prune_members( @@ -2111,9 +2026,7 @@ async def webhooks(self) -> List[Webhook]: data = await self._state.http.guild_webhooks(self.id) return [Webhook.from_state(d, state=self._state) for d in data] - async def estimate_pruned_members( - self, *, days: int, roles: List[Snowflake] = MISSING - ) -> int: + async def estimate_pruned_members(self, *, days: int, roles: List[Snowflake] = MISSING) -> int: """|coro| Similar to :meth:`prune_members` except instead of actually @@ -2146,9 +2059,7 @@ async def estimate_pruned_members( """ if not isinstance(days, int): - raise InvalidArgument( - f"Expected int for ``days``, received {days.__class__.__name__} instead." - ) + raise InvalidArgument(f"Expected int for ``days``, received {days.__class__.__name__} instead.") role_ids = [str(role.id) for role in roles] if roles else [] data = await self._state.http.estimate_pruned_members(self.id, days, role_ids) @@ -2179,15 +2090,11 @@ async def invites(self) -> List[Invite]: result = [] for invite in data: channel = self.get_channel(int(invite["channel"]["id"])) - result.append( - Invite(state=self._state, data=invite, guild=self, channel=channel) - ) + result.append(Invite(state=self._state, data=invite, guild=self, channel=channel)) return result - async def create_template( - self, *, name: str, description: str = MISSING - ) -> Template: + async def create_template(self, *, name: str, description: str = MISSING) -> Template: """|coro| Creates a template for the guild. @@ -2268,11 +2175,7 @@ async def integrations(self) -> List[Integration]: def convert(d): factory, _ = _integration_factory(d["type"]) if factory is None: - raise InvalidData( - "Unknown integration type {type!r} for integration ID {id}".format_map( - d - ) - ) + raise InvalidData("Unknown integration type {type!r} for integration ID {id}".format_map(d)) return factory(guild=self, data=d) return [convert(d) for d in data] @@ -2392,14 +2295,10 @@ async def create_sticker( payload["tags"] = emoji - data = await self._state.http.create_guild_sticker( - self.id, payload, file, reason - ) + data = await self._state.http.create_guild_sticker(self.id, payload, file, reason) return self._state.store_sticker(self, data) - async def delete_sticker( - self, sticker: Snowflake, *, reason: Optional[str] = None - ) -> None: + async def delete_sticker(self, sticker: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Deletes the custom :class:`Sticker` from the guild. @@ -2522,14 +2421,10 @@ async def create_custom_emoji( img = utils._bytes_to_base64_data(image) role_ids = [role.id for role in roles] if roles else [] - data = await self._state.http.create_custom_emoji( - self.id, name, img, roles=role_ids, reason=reason - ) + data = await self._state.http.create_custom_emoji(self.id, name, img, roles=role_ids, reason=reason) return self._state.store_emoji(self, data) - async def delete_emoji( - self, emoji: Snowflake, *, reason: Optional[str] = None - ) -> None: + async def delete_emoji(self, emoji: Snowflake, *, reason: Optional[str] = None) -> None: """|coro| Deletes the custom :class:`Emoji` from the guild. @@ -2686,9 +2581,7 @@ async def create_role( # TODO: add to cache return role - async def edit_role_positions( - self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None - ) -> List[Role]: + async def edit_role_positions(self, positions: Dict[Snowflake, int], *, reason: Optional[str] = None) -> List[Role]: """|coro| Bulk edits a list of :class:`Role` in the guild. @@ -2742,9 +2635,7 @@ async def edit_role_positions( role_positions.append(payload) - data = await self._state.http.move_role_position( - self.id, role_positions, reason=reason - ) + data = await self._state.http.move_role_position(self.id, role_positions, reason=reason) roles: List[Role] = [] for d in data: role = Role(guild=self, data=d, state=self._state) @@ -2985,9 +2876,7 @@ async def widget(self) -> Widget: return Widget(state=self._state, data=data) - async def edit_widget( - self, *, enabled: bool = MISSING, channel: Optional[Snowflake] = MISSING - ) -> None: + async def edit_widget(self, *, enabled: bool = MISSING, channel: Optional[Snowflake] = MISSING) -> None: """|coro| Edits the widget of the guild. @@ -3240,23 +3129,17 @@ async def edit_welcome_screen(self, **options): for channel in welcome_channels: if not isinstance(channel, WelcomeScreenChannel): - raise TypeError( - "welcome_channels parameter must be a list of WelcomeScreenChannel." - ) + raise TypeError("welcome_channels parameter must be a list of WelcomeScreenChannel.") welcome_channels_data.append(channel.to_dict()) options["welcome_channels"] = welcome_channels_data if options: - new = await self._state.http.edit_welcome_screen( - self.id, options, reason=options.get("reason") - ) + new = await self._state.http.edit_welcome_screen(self.id, options, reason=options.get("reason")) return WelcomeScreen(data=new, guild=self) - async def fetch_scheduled_events( - self, *, with_user_count: bool = True - ) -> List[ScheduledEvent]: + async def fetch_scheduled_events(self, *, with_user_count: bool = True) -> List[ScheduledEvent]: """|coro| Returns a list of :class:`ScheduledEvent` in the guild. @@ -3284,21 +3167,11 @@ async def fetch_scheduled_events( List[:class:`ScheduledEvent`] The fetched scheduled events """ - data = await self._state.http.get_scheduled_events( - self.id, with_user_count=with_user_count - ) + data = await self._state.http.get_scheduled_events(self.id, with_user_count=with_user_count) result = [] for event in data: - creator = ( - None - if not event.get("creator", None) - else self.get_member(event.get("creator_id")) - ) - result.append( - ScheduledEvent( - state=self._state, guild=self, creator=creator, data=event - ) - ) + creator = None if not event.get("creator", None) else self.get_member(event.get("creator_id")) + result.append(ScheduledEvent(state=self._state, guild=self, creator=creator, data=event)) self._scheduled_events_from_list(result) return result @@ -3334,14 +3207,8 @@ async def fetch_scheduled_event( data = await self._state.http.get_scheduled_event( guild_id=self.id, event_id=event_id, with_user_count=with_user_count ) - creator = ( - None - if not data.get("creator", None) - else self.get_member(data.get("creator_id")) - ) - event = ScheduledEvent( - state=self._state, guild=self, creator=creator, data=data - ) + creator = None if not data.get("creator", None) else self.get_member(data.get("creator_id")) + event = ScheduledEvent(state=self._state, guild=self, creator=creator, data=data) old_event = self._scheduled_events.get(event.id) if old_event: @@ -3437,12 +3304,8 @@ async def create_scheduled_event( if end_time is not MISSING: payload["scheduled_end_time"] = end_time.isoformat() - data = await self._state.http.create_scheduled_event( - guild_id=self.id, reason=reason, **payload - ) - event = ScheduledEvent( - state=self._state, guild=self, creator=self.me, data=data - ) + data = await self._state.http.create_scheduled_event(guild_id=self.id, reason=reason, **payload) + event = ScheduledEvent(state=self._state, guild=self, creator=self.me, data=data) self._add_scheduled_event(event) return event diff --git a/discord/http.py b/discord/http.py index 2aeb76bc5a..1d137a52f6 100644 --- a/discord/http.py +++ b/discord/http.py @@ -119,12 +119,7 @@ def __init__(self, method: str, path: str, **parameters: Any) -> None: self.method: str = method url = self.BASE + self.path if parameters: - url = url.format_map( - { - k: _uriquote(v) if isinstance(v, str) else v - for k, v in parameters.items() - } - ) + url = url.format_map({k: _uriquote(v) if isinstance(v, str) else v for k, v in parameters.items()}) self.url: str = url # major parameters: @@ -177,9 +172,7 @@ def __init__( loop: Optional[asyncio.AbstractEventLoop] = None, unsync_clock: bool = True, ) -> None: - self.loop: asyncio.AbstractEventLoop = ( - asyncio.get_event_loop() if loop is None else loop - ) + self.loop: asyncio.AbstractEventLoop = asyncio.get_event_loop() if loop is None else loop self.connector = connector self.__session: aiohttp.ClientSession = MISSING # filled in static_login self._locks: weakref.WeakValueDictionary = weakref.WeakValueDictionary() @@ -192,9 +185,7 @@ def __init__( self.use_clock: bool = not unsync_clock user_agent = "DiscordBot (https://github.com/Pycord-Development/pycord {0}) Python/{1[0]}.{1[1]} aiohttp/{2}" - self.user_agent: str = user_agent.format( - __version__, sys.version_info, aiohttp.__version__ - ) + self.user_agent: str = user_agent.format(__version__, sys.version_info, aiohttp.__version__) def recreate(self) -> None: if self.__session.closed: @@ -284,9 +275,7 @@ async def request( kwargs["data"] = form_data try: - async with self.__session.request( - method, url, **kwargs - ) as response: + async with self.__session.request(method, url, **kwargs) as response: _log.debug( "%s %s with %s has returned %s", method, @@ -302,9 +291,7 @@ async def request( remaining = response.headers.get("X-Ratelimit-Remaining") if remaining == "0" and response.status != 429: # we've depleted our current bucket - delta = utils._parse_ratelimit_header( - response, use_clock=self.use_clock - ) + delta = utils._parse_ratelimit_header(response, use_clock=self.use_clock) _log.debug( "A rate limit bucket has been exhausted (bucket: %s, retry: %s).", bucket, @@ -424,21 +411,15 @@ def logout(self) -> Response[None]: # Group functionality - def start_group( - self, user_id: Snowflake, recipients: List[int] - ) -> Response[channel.GroupDMChannel]: + def start_group(self, user_id: Snowflake, recipients: List[int]) -> Response[channel.GroupDMChannel]: payload = { "recipients": recipients, } - return self.request( - Route("POST", "/users/{user_id}/channels", user_id=user_id), json=payload - ) + return self.request(Route("POST", "/users/{user_id}/channels", user_id=user_id), json=payload) def leave_group(self, channel_id) -> Response[None]: - return self.request( - Route("DELETE", "/channels/{channel_id}", channel_id=channel_id) - ) + return self.request(Route("DELETE", "/channels/{channel_id}", channel_id=channel_id)) # Message management @@ -496,9 +477,7 @@ def send_message( return self.request(r, json=payload) def send_typing(self, channel_id: Snowflake) -> Response[None]: - return self.request( - Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id) - ) + return self.request(Route("POST", "/channels/{channel_id}/typing", channel_id=channel_id)) def send_multipart_helper( self, @@ -675,18 +654,14 @@ def delete_messages( *, reason: Optional[str] = None, ) -> Response[None]: - r = Route( - "POST", "/channels/{channel_id}/messages/bulk-delete", channel_id=channel_id - ) + r = Route("POST", "/channels/{channel_id}/messages/bulk-delete", channel_id=channel_id) payload = { "messages": message_ids, } return self.request(r, json=payload, reason=reason) - def edit_message( - self, channel_id: Snowflake, message_id: Snowflake, **fields: Any - ) -> Response[message.Message]: + def edit_message(self, channel_id: Snowflake, message_id: Snowflake, **fields: Any) -> Response[message.Message]: r = Route( "PATCH", "/channels/{channel_id}/messages/{message_id}", @@ -695,9 +670,7 @@ def edit_message( ) return self.request(r, json=fields) - def add_reaction( - self, channel_id: Snowflake, message_id: Snowflake, emoji: str - ) -> Response[None]: + def add_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( "PUT", "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", @@ -724,9 +697,7 @@ def remove_reaction( ) return self.request(r) - def remove_own_reaction( - self, channel_id: Snowflake, message_id: Snowflake, emoji: str - ) -> Response[None]: + def remove_own_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( "DELETE", "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}/@me", @@ -759,9 +730,7 @@ def get_reaction_users( params["after"] = after return self.request(r, params=params) - def clear_reactions( - self, channel_id: Snowflake, message_id: Snowflake - ) -> Response[None]: + def clear_reactions(self, channel_id: Snowflake, message_id: Snowflake) -> Response[None]: r = Route( "DELETE", "/channels/{channel_id}/messages/{message_id}/reactions", @@ -771,9 +740,7 @@ def clear_reactions( return self.request(r) - def clear_single_reaction( - self, channel_id: Snowflake, message_id: Snowflake, emoji: str - ) -> Response[None]: + def clear_single_reaction(self, channel_id: Snowflake, message_id: Snowflake, emoji: str) -> Response[None]: r = Route( "DELETE", "/channels/{channel_id}/messages/{message_id}/reactions/{emoji}", @@ -783,9 +750,7 @@ def clear_single_reaction( ) return self.request(r) - def get_message( - self, channel_id: Snowflake, message_id: Snowflake - ) -> Response[message.Message]: + def get_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: r = Route( "GET", "/channels/{channel_id}/messages/{message_id}", @@ -822,9 +787,7 @@ def logs_from( params=params, ) - def publish_message( - self, channel_id: Snowflake, message_id: Snowflake - ) -> Response[message.Message]: + def publish_message(self, channel_id: Snowflake, message_id: Snowflake) -> Response[message.Message]: return self.request( Route( "POST", @@ -834,9 +797,7 @@ def publish_message( ) ) - def pin_message( - self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None - ) -> Response[None]: + def pin_message(self, channel_id: Snowflake, message_id: Snowflake, reason: Optional[str] = None) -> Response[None]: r = Route( "PUT", "/channels/{channel_id}/pins/{message_id}", @@ -857,15 +818,11 @@ def unpin_message( return self.request(r, reason=reason) def pins_from(self, channel_id: Snowflake) -> Response[List[message.Message]]: - return self.request( - Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id) - ) + return self.request(Route("GET", "/channels/{channel_id}/pins", channel_id=channel_id)) # Member management - def kick( - self, user_id: Snowflake, guild_id: Snowflake, reason: Optional[str] = None - ) -> Response[None]: + def kick(self, user_id: Snowflake, guild_id: Snowflake, reason: Optional[str] = None) -> Response[None]: r = Route( "DELETE", "/guilds/{guild_id}/members/{user_id}", @@ -897,9 +854,7 @@ def ban( return self.request(r, params=params, reason=reason) - def unban( - self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None - ) -> Response[None]: + def unban(self, user_id: Snowflake, guild_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: r = Route( "DELETE", "/guilds/{guild_id}/bans/{user_id}", @@ -967,15 +922,11 @@ def change_nickname( } return self.request(r, json=payload, reason=reason) - def edit_my_voice_state( - self, guild_id: Snowflake, payload: Dict[str, Any] - ) -> Response[None]: + def edit_my_voice_state(self, guild_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: r = Route("PATCH", "/guilds/{guild_id}/voice-states/@me", guild_id=guild_id) return self.request(r, json=payload) - def edit_voice_state( - self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any] - ) -> Response[None]: + def edit_voice_state(self, guild_id: Snowflake, user_id: Snowflake, payload: Dict[str, Any]) -> Response[None]: r = Route( "PATCH", "/guilds/{guild_id}/voice-states/{user_id}", @@ -1068,9 +1019,7 @@ def create_channel( "video_quality_mode", "auto_archive_duration", ) - payload.update( - {k: v for k, v in options.items() if k in valid_keys and v is not None} - ) + payload.update({k: v for k, v in options.items() if k in valid_keys and v is not None}) return self.request( Route("POST", "/guilds/{guild_id}/channels", guild_id=guild_id), @@ -1142,9 +1091,7 @@ def join_thread(self, channel_id: Snowflake) -> Response[None]: ) ) - def add_user_to_thread( - self, channel_id: Snowflake, user_id: Snowflake - ) -> Response[None]: + def add_user_to_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: return self.request( Route( "PUT", @@ -1163,9 +1110,7 @@ def leave_thread(self, channel_id: Snowflake) -> Response[None]: ) ) - def remove_user_from_thread( - self, channel_id: Snowflake, user_id: Snowflake - ) -> Response[None]: + def remove_user_from_thread(self, channel_id: Snowflake, user_id: Snowflake) -> Response[None]: route = Route( "DELETE", "/channels/{channel_id}/thread-members/{user_id}", @@ -1218,18 +1163,12 @@ def get_joined_private_archived_threads( params["limit"] = limit return self.request(route, params=params) - def get_active_threads( - self, guild_id: Snowflake - ) -> Response[threads.ThreadPaginationPayload]: + def get_active_threads(self, guild_id: Snowflake) -> Response[threads.ThreadPaginationPayload]: route = Route("GET", "/guilds/{guild_id}/threads/active", guild_id=guild_id) return self.request(route) - def get_thread_members( - self, channel_id: Snowflake - ) -> Response[List[threads.ThreadMember]]: - route = Route( - "GET", "/channels/{channel_id}/thread-members", channel_id=channel_id - ) + def get_thread_members(self, channel_id: Snowflake) -> Response[List[threads.ThreadMember]]: + route = Route("GET", "/channels/{channel_id}/thread-members", channel_id=channel_id) return self.request(route) # Webhook management @@ -1251,22 +1190,14 @@ def create_webhook( r = Route("POST", "/channels/{channel_id}/webhooks", channel_id=channel_id) return self.request(r, json=payload, reason=reason) - def channel_webhooks( - self, channel_id: Snowflake - ) -> Response[List[webhook.Webhook]]: - return self.request( - Route("GET", "/channels/{channel_id}/webhooks", channel_id=channel_id) - ) + def channel_webhooks(self, channel_id: Snowflake) -> Response[List[webhook.Webhook]]: + return self.request(Route("GET", "/channels/{channel_id}/webhooks", channel_id=channel_id)) def guild_webhooks(self, guild_id: Snowflake) -> Response[List[webhook.Webhook]]: - return self.request( - Route("GET", "/guilds/{guild_id}/webhooks", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/webhooks", guild_id=guild_id)) def get_webhook(self, webhook_id: Snowflake) -> Response[webhook.Webhook]: - return self.request( - Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id) - ) + return self.request(Route("GET", "/webhooks/{webhook_id}", webhook_id=webhook_id)) def follow_webhook( self, @@ -1303,24 +1234,16 @@ def get_guilds( return self.request(Route("GET", "/users/@me/guilds"), params=params) def leave_guild(self, guild_id: Snowflake) -> Response[None]: - return self.request( - Route("DELETE", "/users/@me/guilds/{guild_id}", guild_id=guild_id) - ) + return self.request(Route("DELETE", "/users/@me/guilds/{guild_id}", guild_id=guild_id)) - def get_guild( - self, guild_id: Snowflake, *, with_counts=True - ) -> Response[guild.Guild]: + def get_guild(self, guild_id: Snowflake, *, with_counts=True) -> Response[guild.Guild]: params = {"with_counts": int(with_counts)} - return self.request( - Route("GET", "/guilds/{guild_id}", guild_id=guild_id), params=params - ) + return self.request(Route("GET", "/guilds/{guild_id}", guild_id=guild_id), params=params) def delete_guild(self, guild_id: Snowflake) -> Response[None]: return self.request(Route("DELETE", "/guilds/{guild_id}", guild_id=guild_id)) - def create_guild( - self, name: str, region: str, icon: Optional[str] - ) -> Response[guild.Guild]: + def create_guild(self, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { "name": name, "region": region, @@ -1330,9 +1253,7 @@ def create_guild( return self.request(Route("POST", "/guilds"), json=payload) - def edit_guild( - self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any - ) -> Response[guild.Guild]: + def edit_guild(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[guild.Guild]: valid_keys = ( "name", "region", @@ -1368,21 +1289,15 @@ def get_template(self, code: str) -> Response[template.Template]: return self.request(Route("GET", "/guilds/templates/{code}", code=code)) def guild_templates(self, guild_id: Snowflake) -> Response[List[template.Template]]: - return self.request( - Route("GET", "/guilds/{guild_id}/templates", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/templates", guild_id=guild_id)) - def create_template( - self, guild_id: Snowflake, payload: template.CreateTemplate - ) -> Response[template.Template]: + def create_template(self, guild_id: Snowflake, payload: template.CreateTemplate) -> Response[template.Template]: return self.request( Route("POST", "/guilds/{guild_id}/templates", guild_id=guild_id), json=payload, ) - def sync_template( - self, guild_id: Snowflake, code: str - ) -> Response[template.Template]: + def sync_template(self, guild_id: Snowflake, code: str) -> Response[template.Template]: return self.request( Route( "PUT", @@ -1392,9 +1307,7 @@ def sync_template( ) ) - def edit_template( - self, guild_id: Snowflake, code: str, payload - ) -> Response[template.Template]: + def edit_template(self, guild_id: Snowflake, code: str, payload) -> Response[template.Template]: valid_keys = ( "name", "description", @@ -1420,18 +1333,14 @@ def delete_template(self, guild_id: Snowflake, code: str) -> Response[None]: ) ) - def create_from_template( - self, code: str, name: str, region: str, icon: Optional[str] - ) -> Response[guild.Guild]: + def create_from_template(self, code: str, name: str, region: str, icon: Optional[str]) -> Response[guild.Guild]: payload = { "name": name, "region": region, } if icon: payload["icon"] = icon - return self.request( - Route("POST", "/guilds/templates/{code}", code=code), json=payload - ) + return self.request(Route("POST", "/guilds/templates/{code}", code=code), json=payload) def get_bans(self, guild_id: Snowflake) -> Response[List[guild.Ban]]: return self.request(Route("GET", "/guilds/{guild_id}/bans", guild_id=guild_id)) @@ -1447,13 +1356,9 @@ def get_ban(self, user_id: Snowflake, guild_id: Snowflake) -> Response[guild.Ban ) def get_vanity_code(self, guild_id: Snowflake) -> Response[invite.VanityInvite]: - return self.request( - Route("GET", "/guilds/{guild_id}/vanity-url", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/vanity-url", guild_id=guild_id)) - def change_vanity_code( - self, guild_id: Snowflake, code: str, *, reason: Optional[str] = None - ) -> Response[None]: + def change_vanity_code(self, guild_id: Snowflake, code: str, *, reason: Optional[str] = None) -> Response[None]: payload: Dict[str, Any] = {"code": code} return self.request( Route("PATCH", "/guilds/{guild_id}/vanity-url", guild_id=guild_id), @@ -1461,12 +1366,8 @@ def change_vanity_code( reason=reason, ) - def get_all_guild_channels( - self, guild_id: Snowflake - ) -> Response[List[guild.GuildChannel]]: - return self.request( - Route("GET", "/guilds/{guild_id}/channels", guild_id=guild_id) - ) + def get_all_guild_channels(self, guild_id: Snowflake) -> Response[List[guild.GuildChannel]]: + return self.request(Route("GET", "/guilds/{guild_id}/channels", guild_id=guild_id)) def get_members( self, guild_id: Snowflake, limit: int, after: Optional[Snowflake] @@ -1480,9 +1381,7 @@ def get_members( r = Route("GET", "/guilds/{guild_id}/members", guild_id=guild_id) return self.request(r, params=params) - def get_member( - self, guild_id: Snowflake, member_id: Snowflake - ) -> Response[member.MemberWithUser]: + def get_member(self, guild_id: Snowflake, member_id: Snowflake) -> Response[member.MemberWithUser]: return self.request( Route( "GET", @@ -1526,28 +1425,18 @@ def estimate_pruned_members( if roles: params["include_roles"] = ", ".join(roles) - return self.request( - Route("GET", "/guilds/{guild_id}/prune", guild_id=guild_id), params=params - ) + return self.request(Route("GET", "/guilds/{guild_id}/prune", guild_id=guild_id), params=params) def get_sticker(self, sticker_id: Snowflake) -> Response[sticker.Sticker]: - return self.request( - Route("GET", "/stickers/{sticker_id}", sticker_id=sticker_id) - ) + return self.request(Route("GET", "/stickers/{sticker_id}", sticker_id=sticker_id)) def list_premium_sticker_packs(self) -> Response[sticker.ListPremiumStickerPacks]: return self.request(Route("GET", "/sticker-packs")) - def get_all_guild_stickers( - self, guild_id: Snowflake - ) -> Response[List[sticker.GuildSticker]]: - return self.request( - Route("GET", "/guilds/{guild_id}/stickers", guild_id=guild_id) - ) + def get_all_guild_stickers(self, guild_id: Snowflake) -> Response[List[sticker.GuildSticker]]: + return self.request(Route("GET", "/guilds/{guild_id}/stickers", guild_id=guild_id)) - def get_guild_sticker( - self, guild_id: Snowflake, sticker_id: Snowflake - ) -> Response[sticker.GuildSticker]: + def get_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake) -> Response[sticker.GuildSticker]: return self.request( Route( "GET", @@ -1618,9 +1507,7 @@ def modify_guild_sticker( reason=reason, ) - def delete_guild_sticker( - self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str] - ) -> Response[None]: + def delete_guild_sticker(self, guild_id: Snowflake, sticker_id: Snowflake, reason: Optional[str]) -> Response[None]: return self.request( Route( "DELETE", @@ -1632,13 +1519,9 @@ def delete_guild_sticker( ) def get_all_custom_emojis(self, guild_id: Snowflake) -> Response[List[emoji.Emoji]]: - return self.request( - Route("GET", "/guilds/{guild_id}/emojis", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/emojis", guild_id=guild_id)) - def get_custom_emoji( - self, guild_id: Snowflake, emoji_id: Snowflake - ) -> Response[emoji.Emoji]: + def get_custom_emoji(self, guild_id: Snowflake, emoji_id: Snowflake) -> Response[emoji.Emoji]: return self.request( Route( "GET", @@ -1697,16 +1580,12 @@ def edit_custom_emoji( ) return self.request(r, json=payload, reason=reason) - def get_all_integrations( - self, guild_id: Snowflake - ) -> Response[List[integration.Integration]]: + def get_all_integrations(self, guild_id: Snowflake) -> Response[List[integration.Integration]]: r = Route("GET", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r) - def create_integration( - self, guild_id: Snowflake, type: integration.IntegrationType, id: int - ) -> Response[None]: + def create_integration(self, guild_id: Snowflake, type: integration.IntegrationType, id: int) -> Response[None]: payload = { "type": type, "id": id, @@ -1715,9 +1594,7 @@ def create_integration( r = Route("POST", "/guilds/{guild_id}/integrations", guild_id=guild_id) return self.request(r, json=payload) - def edit_integration( - self, guild_id: Snowflake, integration_id: Snowflake, **payload: Any - ) -> Response[None]: + def edit_integration(self, guild_id: Snowflake, integration_id: Snowflake, **payload: Any) -> Response[None]: r = Route( "PATCH", "/guilds/{guild_id}/integrations/{integration_id}", @@ -1727,9 +1604,7 @@ def edit_integration( return self.request(r, json=payload) - def sync_integration( - self, guild_id: Snowflake, integration_id: Snowflake - ) -> Response[None]: + def sync_integration(self, guild_id: Snowflake, integration_id: Snowflake) -> Response[None]: r = Route( "POST", "/guilds/{guild_id}/integrations/{integration_id}/sync", @@ -1778,16 +1653,10 @@ def get_audit_logs( return self.request(r, params=params) def get_widget(self, guild_id: Snowflake) -> Response[widget.Widget]: - return self.request( - Route("GET", "/guilds/{guild_id}/widget.json", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/widget.json", guild_id=guild_id)) - def edit_widget( - self, guild_id: Snowflake, payload - ) -> Response[widget.WidgetSettings]: - return self.request( - Route("PATCH", "/guilds/{guild_id}/widget", guild_id=guild_id), json=payload - ) + def edit_widget(self, guild_id: Snowflake, payload) -> Response[widget.WidgetSettings]: + return self.request(Route("PATCH", "/guilds/{guild_id}/widget", guild_id=guild_id), json=payload) # Invite management @@ -1839,28 +1708,16 @@ def get_invite( if guild_scheduled_event_id is not None: params["guild_scheduled_event_id"] = int(guild_scheduled_event_id) - return self.request( - Route("GET", "/invites/{invite_id}", invite_id=invite_id), params=params - ) + return self.request(Route("GET", "/invites/{invite_id}", invite_id=invite_id), params=params) def invites_from(self, guild_id: Snowflake) -> Response[List[invite.Invite]]: - return self.request( - Route("GET", "/guilds/{guild_id}/invites", guild_id=guild_id) - ) + return self.request(Route("GET", "/guilds/{guild_id}/invites", guild_id=guild_id)) - def invites_from_channel( - self, channel_id: Snowflake - ) -> Response[List[invite.Invite]]: - return self.request( - Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id) - ) + def invites_from_channel(self, channel_id: Snowflake) -> Response[List[invite.Invite]]: + return self.request(Route("GET", "/channels/{channel_id}/invites", channel_id=channel_id)) - def delete_invite( - self, invite_id: str, *, reason: Optional[str] = None - ) -> Response[None]: - return self.request( - Route("DELETE", "/invites/{invite_id}", invite_id=invite_id), reason=reason - ) + def delete_invite(self, invite_id: str, *, reason: Optional[str] = None) -> Response[None]: + return self.request(Route("DELETE", "/invites/{invite_id}", invite_id=invite_id), reason=reason) # Role management @@ -1893,9 +1750,7 @@ def edit_role( payload = {k: v for k, v in fields.items() if k in valid_keys} return self.request(r, json=payload, reason=reason) - def delete_role( - self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None - ) -> Response[None]: + def delete_role(self, guild_id: Snowflake, role_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: r = Route( "DELETE", "/guilds/{guild_id}/roles/{role_id}", @@ -1912,13 +1767,9 @@ def replace_roles( *, reason: Optional[str] = None, ) -> Response[member.MemberWithUser]: - return self.edit_member( - guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason - ) + return self.edit_member(guild_id=guild_id, user_id=user_id, roles=role_ids, reason=reason) - def create_role( - self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any - ) -> Response[role.Role]: + def create_role(self, guild_id: Snowflake, *, reason: Optional[str] = None, **fields: Any) -> Response[role.Role]: r = Route("POST", "/guilds/{guild_id}/roles", guild_id=guild_id) return self.request(r, json=fields, reason=reason) @@ -2002,12 +1853,8 @@ def delete_channel_permissions( # Welcome Screen - def get_welcome_screen( - self, guild_id: Snowflake - ) -> Response[welcome_screen.WelcomeScreen]: - return self.request( - Route("GET", "/guilds/{guild_id}/welcome-screen", guild_id=guild_id) - ) + def get_welcome_screen(self, guild_id: Snowflake) -> Response[welcome_screen.WelcomeScreen]: + return self.request(Route("GET", "/guilds/{guild_id}/welcome-screen", guild_id=guild_id)) def edit_welcome_screen( self, guild_id: Snowflake, payload: Any, *, reason: Optional[str] = None @@ -2034,22 +1881,14 @@ def move_member( *, reason: Optional[str] = None, ) -> Response[member.MemberWithUser]: - return self.edit_member( - guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason - ) + return self.edit_member(guild_id=guild_id, user_id=user_id, channel_id=channel_id, reason=reason) # Stage instance management - def get_stage_instance( - self, channel_id: Snowflake - ) -> Response[channel.StageInstance]: - return self.request( - Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id) - ) + def get_stage_instance(self, channel_id: Snowflake) -> Response[channel.StageInstance]: + return self.request(Route("GET", "/stage-instances/{channel_id}", channel_id=channel_id)) - def create_stage_instance( - self, *, reason: Optional[str], **payload: Any - ) -> Response[channel.StageInstance]: + def create_stage_instance(self, *, reason: Optional[str], **payload: Any) -> Response[channel.StageInstance]: valid_keys = ( "channel_id", "topic", @@ -2057,9 +1896,7 @@ def create_stage_instance( ) payload = {k: v for k, v in payload.items() if k in valid_keys} - return self.request( - Route("POST", "/stage-instances"), json=payload, reason=reason - ) + return self.request(Route("POST", "/stage-instances"), json=payload, reason=reason) def edit_stage_instance( self, channel_id: Snowflake, *, reason: Optional[str] = None, **payload: Any @@ -2076,9 +1913,7 @@ def edit_stage_instance( reason=reason, ) - def delete_stage_instance( - self, channel_id: Snowflake, *, reason: Optional[str] = None - ) -> Response[None]: + def delete_stage_instance(self, channel_id: Snowflake, *, reason: Optional[str] = None) -> Response[None]: return self.request( Route("DELETE", "/stage-instances/{channel_id}", channel_id=channel_id), reason=reason, @@ -2136,9 +1971,7 @@ def create_scheduled_event( reason=reason, ) - def delete_scheduled_event( - self, guild_id: Snowflake, event_id: Snowflake - ) -> Response[None]: + def delete_scheduled_event(self, guild_id: Snowflake, event_id: Snowflake) -> Response[None]: return self.request( Route( "DELETE", @@ -2212,9 +2045,7 @@ def get_scheduled_event_users( # Application commands (global) - def get_global_commands( - self, application_id: Snowflake - ) -> Response[List[interactions.ApplicationCommand]]: + def get_global_commands(self, application_id: Snowflake) -> Response[List[interactions.ApplicationCommand]]: return self.request( Route( "GET", @@ -2234,9 +2065,7 @@ def get_global_command( ) return self.request(r) - def upsert_global_command( - self, application_id: Snowflake, payload - ) -> Response[interactions.ApplicationCommand]: + def upsert_global_command(self, application_id: Snowflake, payload) -> Response[interactions.ApplicationCommand]: r = Route( "POST", "/applications/{application_id}/commands", @@ -2264,9 +2093,7 @@ def edit_global_command( ) return self.request(r, json=payload) - def delete_global_command( - self, application_id: Snowflake, command_id: Snowflake - ) -> Response[None]: + def delete_global_command(self, application_id: Snowflake, command_id: Snowflake) -> Response[None]: r = Route( "DELETE", "/applications/{application_id}/commands/{command_id}", @@ -2489,9 +2316,7 @@ def edit_original_interaction_response( allowed_mentions=allowed_mentions, ) - def delete_original_interaction_response( - self, application_id: Snowflake, token: str - ) -> Response[None]: + def delete_original_interaction_response(self, application_id: Snowflake, token: str) -> Response[None]: r = Route( "DELETE", "/webhooks/{application_id}/{interaction_token}/messages/@original", @@ -2550,9 +2375,7 @@ def edit_followup_message( allowed_mentions=allowed_mentions, ) - def delete_followup_message( - self, application_id: Snowflake, token: str, message_id: Snowflake - ) -> Response[None]: + def delete_followup_message(self, application_id: Snowflake, token: str, message_id: Snowflake) -> Response[None]: r = Route( "DELETE", "/webhooks/{application_id}/{interaction_token}/messages/{message_id}", @@ -2636,9 +2459,7 @@ async def get_gateway(self, *, encoding: str = "json", zlib: bool = True) -> str value = "{0}?encoding={1}&v=10" return value.format(data["url"], encoding) - async def get_bot_gateway( - self, *, encoding: str = "json", zlib: bool = True - ) -> Tuple[int, str]: + async def get_bot_gateway(self, *, encoding: str = "json", zlib: bool = True) -> Tuple[int, str]: try: data = await self.request(Route("GET", "/gateway/bot")) except HTTPException as exc: diff --git a/discord/integrations.py b/discord/integrations.py index 6d9670cf33..060fcce816 100644 --- a/discord/integrations.py +++ b/discord/integrations.py @@ -201,9 +201,7 @@ class StreamIntegration(Integration): def _from_data(self, data: StreamIntegrationPayload) -> None: super()._from_data(data) self.revoked: bool = data["revoked"] - self.expire_behaviour: ExpireBehaviour = try_enum( - ExpireBehaviour, data["expire_behavior"] - ) + self.expire_behaviour: ExpireBehaviour = try_enum(ExpireBehaviour, data["expire_behavior"]) self.expire_grace_period: int = data["expire_grace_period"] self.synced_at: datetime.datetime = parse_time(data["synced_at"]) self._role_id: Optional[int] = _get_as_snowflake(data, "role_id") @@ -256,9 +254,7 @@ async def edit( payload: Dict[str, Any] = {} if expire_behaviour is not MISSING: if not isinstance(expire_behaviour, ExpireBehaviour): - raise InvalidArgument( - "expire_behaviour field must be of type ExpireBehaviour" - ) + raise InvalidArgument("expire_behaviour field must be of type ExpireBehaviour") payload["expire_behavior"] = expire_behaviour.value @@ -360,9 +356,7 @@ class BotIntegration(Integration): def _from_data(self, data: BotIntegrationPayload) -> None: super()._from_data(data) - self.application = IntegrationApplication( - data=data["application"], state=self._state - ) + self.application = IntegrationApplication(data=data["application"], state=self._state) def _integration_factory(value: str) -> Tuple[Type[Integration], str]: diff --git a/discord/interactions.py b/discord/interactions.py index 2f7e7d8a08..abca1a6ba5 100644 --- a/discord/interactions.py +++ b/discord/interactions.py @@ -209,14 +209,8 @@ def channel(self) -> Optional[InteractionChannel]: channel = guild and guild._resolve_channel(self.channel_id) if channel is None: if self.channel_id is not None: - type = ( - ChannelType.text - if self.guild_id is not None - else ChannelType.private - ) - return PartialMessageable( - state=self._state, id=self.channel_id, type=type - ) + type = ChannelType.text if self.guild_id is not None else ChannelType.private + return PartialMessageable(state=self._state, id=self.channel_id, type=type) return None return channel @@ -613,14 +607,10 @@ async def send_message( state = self._parent._state if allowed_mentions is None: - payload["allowed_mentions"] = ( - state.allowed_mentions and state.allowed_mentions.to_dict() - ) + payload["allowed_mentions"] = state.allowed_mentions and state.allowed_mentions.to_dict() elif state.allowed_mentions is not None: - payload["allowed_mentions"] = state.allowed_mentions.merge( - allowed_mentions - ).to_dict() + payload["allowed_mentions"] = state.allowed_mentions.merge(allowed_mentions).to_dict() else: payload["allowed_mentions"] = allowed_mentions.to_dict() if file is not None and files is not None: @@ -634,9 +624,7 @@ async def send_message( if files is not None: if len(files) > 10: - raise InvalidArgument( - "files parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") diff --git a/discord/invite.py b/discord/invite.py index 4e1a7d7165..18afa7baa8 100644 --- a/discord/invite.py +++ b/discord/invite.py @@ -103,9 +103,7 @@ def __str__(self) -> str: return self.name def __repr__(self) -> str: - return ( - f"" - ) + return f"" @property def mention(self) -> str: @@ -176,9 +174,7 @@ def __init__(self, state: ConnectionState, data: InviteGuildPayload, id: int): self._icon: Optional[str] = data.get("icon") self._banner: Optional[str] = data.get("banner") self._splash: Optional[str] = data.get("splash") - self.verification_level: VerificationLevel = try_enum( - VerificationLevel, data.get("verification_level") - ) + self.verification_level: VerificationLevel = try_enum(VerificationLevel, data.get("verification_level")) self.description: Optional[str] = data.get("description") def __str__(self) -> str: @@ -207,18 +203,14 @@ def banner(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's banner asset, if available.""" if self._banner is None: return None - return Asset._from_guild_image( - self._state, self.id, self._banner, path="banners" - ) + return Asset._from_guild_image(self._state, self.id, self._banner, path="banners") @property def splash(self) -> Optional[Asset]: """Optional[:class:`Asset`]: Returns the guild's invite splash asset, if available.""" if self._splash is None: return None - return Asset._from_guild_image( - self._state, self.id, self._splash, path="splashes" - ) + return Asset._from_guild_image(self._state, self.id, self._splash, path="splashes") I = TypeVar("I", bound="Invite") @@ -360,55 +352,35 @@ def __init__( self._state: ConnectionState = state self.max_age: Optional[int] = data.get("max_age") self.code: str = data["code"] - self.guild: Optional[InviteGuildType] = self._resolve_guild( - data.get("guild"), guild - ) + self.guild: Optional[InviteGuildType] = self._resolve_guild(data.get("guild"), guild) self.revoked: Optional[bool] = data.get("revoked") - self.created_at: Optional[datetime.datetime] = parse_time( - data.get("created_at") - ) + self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at")) self.temporary: Optional[bool] = data.get("temporary") self.uses: Optional[int] = data.get("uses") self.max_uses: Optional[int] = data.get("max_uses") - self.approximate_presence_count: Optional[int] = data.get( - "approximate_presence_count" - ) - self.approximate_member_count: Optional[int] = data.get( - "approximate_member_count" - ) + self.approximate_presence_count: Optional[int] = data.get("approximate_presence_count") + self.approximate_member_count: Optional[int] = data.get("approximate_member_count") expires_at = data.get("expires_at", None) - self.expires_at: Optional[datetime.datetime] = ( - parse_time(expires_at) if expires_at else None - ) + self.expires_at: Optional[datetime.datetime] = parse_time(expires_at) if expires_at else None inviter_data = data.get("inviter") - self.inviter: Optional[User] = ( - None if inviter_data is None else self._state.create_user(inviter_data) - ) + self.inviter: Optional[User] = None if inviter_data is None else self._state.create_user(inviter_data) - self.channel: Optional[InviteChannelType] = self._resolve_channel( - data.get("channel"), channel - ) + self.channel: Optional[InviteChannelType] = self._resolve_channel(data.get("channel"), channel) target_user_data = data.get("target_user") self.target_user: Optional[User] = ( - None - if target_user_data is None - else self._state.create_user(target_user_data) + None if target_user_data is None else self._state.create_user(target_user_data) ) - self.target_type: InviteTarget = try_enum( - InviteTarget, data.get("target_type", 0) - ) + self.target_type: InviteTarget = try_enum(InviteTarget, data.get("target_type", 0)) from .scheduled_events import ScheduledEvent scheduled_event: ScheduledEventPayload = data.get("guild_scheduled_event") self.scheduled_event: Optional[ScheduledEvent] = ( - ScheduledEvent(state=state, data=scheduled_event) - if scheduled_event - else None + ScheduledEvent(state=state, data=scheduled_event) if scheduled_event else None ) application = data.get("target_application") @@ -417,9 +389,7 @@ def __init__( ) @classmethod - def from_incomplete( - cls: Type[I], *, state: ConnectionState, data: InvitePayload - ) -> I: + def from_incomplete(cls: Type[I], *, state: ConnectionState, data: InvitePayload) -> I: guild: Optional[Union[Guild, PartialInviteGuild]] try: guild_data = data["guild"] @@ -435,9 +405,7 @@ def from_incomplete( # As far as I know, invites always need a channel # So this should never raise. - channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel( - data["channel"] - ) + channel: Union[PartialInviteChannel, GuildChannel] = PartialInviteChannel(data["channel"]) if guild is not None and not isinstance(guild, PartialInviteGuild): # Upgrade the partial data if applicable channel = guild.get_channel(channel.id) or channel @@ -445,9 +413,7 @@ def from_incomplete( return cls(state=state, data=data, guild=guild, channel=channel) @classmethod - def from_gateway( - cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload - ) -> I: + def from_gateway(cls: Type[I], *, state: ConnectionState, data: GatewayInvitePayload) -> I: guild_id: Optional[int] = _get_as_snowflake(data, "guild_id") guild: Optional[Union[Guild, Object]] = state._get_guild(guild_id) channel_id = int(data["channel_id"]) diff --git a/discord/iterators.py b/discord/iterators.py index 824601fa83..f4c1492a95 100644 --- a/discord/iterators.py +++ b/discord/iterators.py @@ -300,9 +300,7 @@ def __init__( if self.limit is None: raise ValueError("history does not support around with limit=None") if self.limit > 101: - raise ValueError( - "history max limit 101 when specifying around parameter" - ) + raise ValueError("history max limit 101 when specifying around parameter") elif self.limit == 101: self.limit = 100 # Thanks discord @@ -358,9 +356,7 @@ async def fill_messages(self): channel = self.channel for element in data: - await self.messages.put( - self.state.create_message(channel=channel, data=element) - ) + await self.messages.put(self.state.create_message(channel=channel, data=element)) async def _retrieve_messages(self, retrieve) -> List[Message]: """Retrieve messages and update next parameters.""" @@ -369,9 +365,7 @@ async def _retrieve_messages(self, retrieve) -> List[Message]: async def _retrieve_messages_before_strategy(self, retrieve): """Retrieve messages using before parameter.""" before = self.before.id if self.before else None - data: List[MessagePayload] = await self.logs_from( - self.channel.id, retrieve, before=before - ) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, before=before) if len(data): if self.limit is not None: self.limit -= retrieve @@ -381,9 +375,7 @@ async def _retrieve_messages_before_strategy(self, retrieve): async def _retrieve_messages_after_strategy(self, retrieve): """Retrieve messages using after parameter.""" after = self.after.id if self.after else None - data: List[MessagePayload] = await self.logs_from( - self.channel.id, retrieve, after=after - ) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, after=after) if len(data): if self.limit is not None: self.limit -= retrieve @@ -394,9 +386,7 @@ async def _retrieve_messages_around_strategy(self, retrieve): """Retrieve messages using around parameter.""" if self.around: around = self.around.id if self.around else None - data: List[MessagePayload] = await self.logs_from( - self.channel.id, retrieve, around=around - ) + data: List[MessagePayload] = await self.logs_from(self.channel.id, retrieve, around=around) self.around = None return data return [] @@ -516,9 +506,7 @@ async def _fill(self): if element["action_type"] is None: continue - await self.entries.put( - AuditLogEntry(data=element, users=self._users, guild=self.guild) - ) + await self.entries.put(AuditLogEntry(data=element, users=self._users, guild=self.guild)) class GuildIterator(_AsyncIterator["Guild"]): diff --git a/discord/member.py b/discord/member.py index 36f11f3633..694e6d2744 100644 --- a/discord/member.py +++ b/discord/member.py @@ -135,9 +135,7 @@ class VoiceState: "suppress", ) - def __init__( - self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None - ): + def __init__(self, *, data: VoiceStatePayload, channel: Optional[VocalGuildChannel] = None): self.session_id: str = data.get("session_id") self._update(data, channel) @@ -169,9 +167,7 @@ def __repr__(self) -> str: def flatten_user(cls): - for attr, value in itertools.chain( - BaseUser.__dict__.items(), User.__dict__.items() - ): + for attr, value in itertools.chain(BaseUser.__dict__.items(), User.__dict__.items()): # ignore private/special methods if attr.startswith("_"): continue @@ -184,9 +180,7 @@ def flatten_user(cls): # slotted members are implemented as member_descriptors in Type.__dict__ if not hasattr(value, "__annotations__"): getter = attrgetter(f"_user.{attr}") - setattr( - cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`") - ) + setattr(cls, attr, property(getter, doc=f"Equivalent to :attr:`User.{attr}`")) else: # Technically, this can also use attrgetter # However I'm not sure how I feel about "functions" returning properties @@ -307,27 +301,21 @@ class Member(discord.abc.Messageable, _UserTag): accent_colour: Optional[Colour] communication_disabled_until: Optional[datetime.datetime] - def __init__( - self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState - ): + def __init__(self, *, data: MemberWithUserPayload, guild: Guild, state: ConnectionState): self._state: ConnectionState = state self._user: User = state.store_user(data["user"]) self.guild: Guild = guild - self.joined_at: Optional[datetime.datetime] = utils.parse_time( - data.get("joined_at") - ) - self.premium_since: Optional[datetime.datetime] = utils.parse_time( - data.get("premium_since") - ) + self.joined_at: Optional[datetime.datetime] = utils.parse_time(data.get("joined_at")) + self.premium_since: Optional[datetime.datetime] = utils.parse_time(data.get("premium_since")) self._roles: utils.SnowflakeList = utils.SnowflakeList(map(int, data["roles"])) self._client_status: Dict[Optional[str], str] = {None: "offline"} self.activities: Tuple[ActivityTypes, ...] = tuple() self.nick: Optional[str] = data.get("nick", None) self.pending: bool = data.get("pending", False) self._avatar: Optional[str] = data.get("avatar") - self.communication_disabled_until: Optional[ - datetime.datetime - ] = utils.parse_time(data.get("communication_disabled_until")) + self.communication_disabled_until: Optional[datetime.datetime] = utils.parse_time( + data.get("communication_disabled_until") + ) def __str__(self) -> str: return str(self._user) @@ -418,13 +406,9 @@ def _update(self, data: MemberPayload) -> None: self.premium_since = utils.parse_time(data.get("premium_since")) self._roles = utils.SnowflakeList(map(int, data["roles"])) self._avatar = data.get("avatar") - self.communication_disabled_until = utils.parse_time( - data.get("communication_disabled_until") - ) + self.communication_disabled_until = utils.parse_time(data.get("communication_disabled_until")) - def _presence_update( - self, data: PartialPresenceUpdate, user: UserPayload - ) -> Optional[Tuple[User, User]]: + def _presence_update(self, data: PartialPresenceUpdate, user: UserPayload) -> Optional[Tuple[User, User]]: self.activities = tuple(map(create_activity, data["activities"])) self._client_status = { sys.intern(key): sys.intern(value) for key, value in data.get("client_status", {}).items() # type: ignore @@ -575,9 +559,7 @@ def guild_avatar(self) -> Optional[Asset]: """ if self._avatar is None: return None - return Asset._from_guild_avatar( - self._state, self.guild.id, self.id, self._avatar - ) + return Asset._from_guild_avatar(self._state, self.guild.id, self.id, self._avatar) @property def activity(self) -> Optional[ActivityTypes]: @@ -669,8 +651,7 @@ def timed_out(self) -> bool: """ return ( self.communication_disabled_until is not None - and self.communication_disabled_until - > datetime.datetime.now(datetime.timezone.utc) + and self.communication_disabled_until > datetime.datetime.now(datetime.timezone.utc) ) async def ban( @@ -683,9 +664,7 @@ async def ban( Bans this member. Equivalent to :meth:`Guild.ban`. """ - await self.guild.ban( - self, reason=reason, delete_message_days=delete_message_days - ) + await self.guild.ban(self, reason=reason, delete_message_days=delete_message_days) async def unban(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -812,9 +791,7 @@ async def edit( await http.edit_my_voice_state(guild_id, voice_state_payload) else: if not suppress: - voice_state_payload[ - "request_to_speak_timestamp" - ] = datetime.datetime.utcnow().isoformat() + voice_state_payload["request_to_speak_timestamp"] = datetime.datetime.utcnow().isoformat() await http.edit_voice_state(guild_id, self.id, voice_state_payload) if voice_channel is not MISSING: @@ -825,9 +802,7 @@ async def edit( if communication_disabled_until is not MISSING: if communication_disabled_until is not None: - payload[ - "communication_disabled_until" - ] = communication_disabled_until.isoformat() + payload["communication_disabled_until"] = communication_disabled_until.isoformat() else: payload["communication_disabled_until"] = communication_disabled_until @@ -835,9 +810,7 @@ async def edit( data = await http.edit_member(guild_id, self.id, reason=reason, **payload) return Member(data=data, guild=self.guild, state=self._state) - async def timeout( - self, until: Optional[datetime.datetime], *, reason: Optional[str] = None - ) -> None: + async def timeout(self, until: Optional[datetime.datetime], *, reason: Optional[str] = None) -> None: """|coro| Applies a timeout to a member in the guild until a set datetime. @@ -860,9 +833,7 @@ async def timeout( """ await self.edit(communication_disabled_until=until, reason=reason) - async def timeout_for( - self, duration: datetime.timedelta, *, reason: Optional[str] = None - ) -> None: + async def timeout_for(self, duration: datetime.timedelta, *, reason: Optional[str] = None) -> None: """|coro| Applies a timeout to a member in the guild for a set duration. A shortcut method for :meth:`~.timeout`, and @@ -885,9 +856,7 @@ async def timeout_for( HTTPException An error occurred doing the request. """ - await self.timeout( - datetime.datetime.now(datetime.timezone.utc) + duration, reason=reason - ) + await self.timeout(datetime.datetime.now(datetime.timezone.utc) + duration, reason=reason) async def remove_timeout(self, *, reason: Optional[str] = None) -> None: """|coro| @@ -945,9 +914,7 @@ async def request_to_speak(self) -> None: else: await self._state.http.edit_my_voice_state(self.guild.id, payload) - async def move_to( - self, channel: VocalGuildChannel, *, reason: Optional[str] = None - ) -> None: + async def move_to(self, channel: VocalGuildChannel, *, reason: Optional[str] = None) -> None: """|coro| Moves a member to a new voice channel (they must be connected first). @@ -970,9 +937,7 @@ async def move_to( """ await self.edit(voice_channel=channel, reason=reason) - async def add_roles( - self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True - ) -> None: + async def add_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: r"""|coro| Gives the member a number of :class:`Role`\s. @@ -1002,9 +967,7 @@ async def add_roles( """ if not atomic: - new_roles = utils._unique( - Object(id=r.id) for s in (self.roles[1:], roles) for r in s - ) + new_roles = utils._unique(Object(id=r.id) for s in (self.roles[1:], roles) for r in s) await self.edit(roles=new_roles, reason=reason) else: req = self._state.http.add_role @@ -1013,9 +976,7 @@ async def add_roles( for role in roles: await req(guild_id, user_id, role.id, reason=reason) - async def remove_roles( - self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True - ) -> None: + async def remove_roles(self, *roles: Snowflake, reason: Optional[str] = None, atomic: bool = True) -> None: r"""|coro| Removes :class:`Role`\s from this member. diff --git a/discord/mentions.py b/discord/mentions.py index 35d0e84816..2e3338e878 100644 --- a/discord/mentions.py +++ b/discord/mentions.py @@ -141,12 +141,8 @@ def merge(self, other: AllowedMentions) -> AllowedMentions: everyone = self.everyone if other.everyone is default else other.everyone users = self.users if other.users is default else other.users roles = self.roles if other.roles is default else other.roles - replied_user = ( - self.replied_user if other.replied_user is default else other.replied_user - ) - return AllowedMentions( - everyone=everyone, roles=roles, users=users, replied_user=replied_user - ) + replied_user = self.replied_user if other.replied_user is default else other.replied_user + return AllowedMentions(everyone=everyone, roles=roles, users=users, replied_user=replied_user) def __repr__(self) -> str: return ( diff --git a/discord/message.py b/discord/message.py index ac3602f932..2d61736197 100644 --- a/discord/message.py +++ b/discord/message.py @@ -114,9 +114,7 @@ def convert_emoji_reaction(emoji): # No existing emojis have <> in them, so this should be okay. return emoji.strip("<>") - raise InvalidArgument( - f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}." - ) + raise InvalidArgument(f"emoji argument must be str, Emoji, or Reaction not {emoji.__class__.__name__}.") class Attachment(Hashable): @@ -458,9 +456,7 @@ def __init__( self.fail_if_not_exists: bool = fail_if_not_exists @classmethod - def with_state( - cls: Type[MR], state: ConnectionState, data: MessageReferencePayload - ) -> MR: + def with_state(cls: Type[MR], state: ConnectionState, data: MessageReferencePayload) -> MR: self = cls.__new__(cls) self.message_id = utils._get_as_snowflake(data, "message_id") self.channel_id = int(data.pop("channel_id")) @@ -471,9 +467,7 @@ def with_state( return self @classmethod - def from_message( - cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True - ) -> MR: + def from_message(cls: Type[MR], message: Message, *, fail_if_not_exists: bool = True) -> MR: """Creates a :class:`MessageReference` from an existing :class:`~discord.Message`. .. versionadded:: 1.6 @@ -520,9 +514,7 @@ def __repr__(self) -> str: return f"" def to_dict(self) -> MessageReferencePayload: - result: MessageReferencePayload = ( - {"message_id": self.message_id} if self.message_id is not None else {} - ) + result: MessageReferencePayload = {"message_id": self.message_id} if self.message_id is not None else {} result["channel_id"] = self.channel_id if self.guild_id is not None: result["guild_id"] = self.guild_id @@ -720,19 +712,13 @@ def __init__( self._state: ConnectionState = state self.id: int = int(data["id"]) self.webhook_id: Optional[int] = utils._get_as_snowflake(data, "webhook_id") - self.reactions: List[Reaction] = [ - Reaction(message=self, data=d) for d in data.get("reactions", []) - ] - self.attachments: List[Attachment] = [ - Attachment(data=a, state=self._state) for a in data["attachments"] - ] + self.reactions: List[Reaction] = [Reaction(message=self, data=d) for d in data.get("reactions", [])] + self.attachments: List[Attachment] = [Attachment(data=a, state=self._state) for a in data["attachments"]] self.embeds: List[Embed] = [Embed.from_dict(a) for a in data["embeds"]] self.application: Optional[MessageApplicationPayload] = data.get("application") self.activity: Optional[MessageActivityPayload] = data.get("activity") self.channel: MessageableChannel = channel - self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time( - data["edited_timestamp"] - ) + self._edited_timestamp: Optional[datetime.datetime] = utils.parse_time(data["edited_timestamp"]) self.type: MessageType = try_enum(MessageType, data["type"]) self.pinned: bool = data["pinned"] self.flags: MessageFlags = MessageFlags._from_value(data.get("flags", 0)) @@ -740,12 +726,8 @@ def __init__( self.tts: bool = data["tts"] self.content: str = data["content"] self.nonce: Optional[Union[int, str]] = data.get("nonce") - self.stickers: List[StickerItem] = [ - StickerItem(data=d, state=state) for d in data.get("sticker_items", []) - ] - self.components: List[Component] = [ - _component_factory(d) for d in data.get("components", []) - ] + self.stickers: List[StickerItem] = [StickerItem(data=d, state=state) for d in data.get("sticker_items", [])] + self.components: List[Component] = [_component_factory(d) for d in data.get("components", [])] try: # if the channel doesn't have a guild attribute, we handle that @@ -819,9 +801,7 @@ def _add_reaction(self, data, emoji, user_id) -> Reaction: return reaction - def _remove_reaction( - self, data: ReactionPayload, emoji: EmojiInputType, user_id: int - ) -> Reaction: + def _remove_reaction(self, data: ReactionPayload, emoji: EmojiInputType, user_id: int) -> Reaction: reaction = utils.find(lambda r: r.emoji == emoji, self.reactions) if reaction is None: @@ -958,9 +938,7 @@ def _handle_mention_roles(self, role_mentions: List[int]) -> None: def _handle_components(self, components: List[ComponentPayload]): self.components = [_component_factory(d) for d in components] - def _rebind_cached_references( - self, new_guild: Guild, new_channel: Union[TextChannel, Thread] - ) -> None: + def _rebind_cached_references(self, new_guild: Guild, new_channel: Union[TextChannel, Thread]) -> None: self.guild = new_guild self.channel = new_channel @@ -1012,30 +990,20 @@ def clean_content(self) -> str: respectively, along with this function. """ - transformations = { - re.escape(f"<#{channel.id}>"): f"#{channel.name}" - for channel in self.channel_mentions - } + transformations = {re.escape(f"<#{channel.id}>"): f"#{channel.name}" for channel in self.channel_mentions} - mention_transforms = { - re.escape(f"<@{member.id}>"): f"@{member.display_name}" - for member in self.mentions - } + mention_transforms = {re.escape(f"<@{member.id}>"): f"@{member.display_name}" for member in self.mentions} # add the <@!user_id> cases as well.. second_mention_transforms = { - re.escape(f"<@!{member.id}>"): f"@{member.display_name}" - for member in self.mentions + re.escape(f"<@!{member.id}>"): f"@{member.display_name}" for member in self.mentions } transformations.update(mention_transforms) transformations.update(second_mention_transforms) if self.guild is not None: - role_transforms = { - re.escape(f"<@&{role.id}>"): f"@{role.name}" - for role in self.role_mentions - } + role_transforms = {re.escape(f"<@&{role.id}>"): f"@{role.name}" for role in self.role_mentions} transformations.update(role_transforms) def repl(obj): @@ -1093,9 +1061,7 @@ def system_content(self): if self.channel.type is ChannelType.group: return f"{self.author.name} added {self.mentions[0].name} to the group." else: - return ( - f"{self.author.name} added {self.mentions[0].name} to the thread." - ) + return f"{self.author.name} added {self.mentions[0].name} to the thread." if self.type is MessageType.recipient_remove: if self.channel.type is ChannelType.group: @@ -1350,9 +1316,7 @@ async def edit( if content is not MISSING: payload["content"] = str(content) if content is not None else None if embed is not MISSING and embeds is not MISSING: - raise InvalidArgument( - "cannot pass both embed and embeds parameter to edit()" - ) + raise InvalidArgument("cannot pass both embed and embeds parameter to edit()") if embed is not MISSING: payload["embeds"] = [] if embed is None else [embed.to_dict()] @@ -1365,16 +1329,11 @@ async def edit( payload["flags"] = flags.value if allowed_mentions is MISSING: - if ( - self._state.allowed_mentions is not None - and self.author.id == self._state.self_id - ): + if self._state.allowed_mentions is not None and self.author.id == self._state.self_id: payload["allowed_mentions"] = self._state.allowed_mentions.to_dict() elif allowed_mentions is not None: if self._state.allowed_mentions is not None: - payload["allowed_mentions"] = self._state.allowed_mentions.merge( - allowed_mentions - ).to_dict() + payload["allowed_mentions"] = self._state.allowed_mentions.merge(allowed_mentions).to_dict() else: payload["allowed_mentions"] = allowed_mentions.to_dict() @@ -1406,9 +1365,7 @@ async def edit( elif files is not MISSING: if len(files) > 10: - raise InvalidArgument( - "files parameter must be a list of up to 10 elements" - ) + raise InvalidArgument("files parameter must be a list of up to 10 elements") elif not all(isinstance(file, File) for file in files): raise InvalidArgument("files parameter must be a list of File") if "attachments" not in payload: @@ -1426,9 +1383,7 @@ async def edit( for f in files: f.close() else: - data = await self._state.http.edit_message( - self.channel.id, self.id, **payload - ) + data = await self._state.http.edit_message(self.channel.id, self.id, **payload) message = Message(state=self._state, channel=self.channel, data=data) if view and not view.is_finished(): @@ -1547,9 +1502,7 @@ async def add_reaction(self, emoji: EmojiInputType) -> None: emoji = convert_emoji_reaction(emoji) await self._state.http.add_reaction(self.channel.id, self.id, emoji) - async def remove_reaction( - self, emoji: Union[EmojiInputType, Reaction], member: Snowflake - ) -> None: + async def remove_reaction(self, emoji: Union[EmojiInputType, Reaction], member: Snowflake) -> None: """|coro| Remove a reaction by the member from the message. @@ -1586,9 +1539,7 @@ async def remove_reaction( if member.id == self._state.self_id: await self._state.http.remove_own_reaction(self.channel.id, self.id, emoji) else: - await self._state.http.remove_reaction( - self.channel.id, self.id, emoji, member.id - ) + await self._state.http.remove_reaction(self.channel.id, self.id, emoji, member.id) async def clear_reaction(self, emoji: Union[EmojiInputType, Reaction]) -> None: """|coro| @@ -1637,9 +1588,7 @@ async def clear_reactions(self) -> None: """ await self._state.http.clear_reactions(self.channel.id, self.id) - async def create_thread( - self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING - ) -> Thread: + async def create_thread(self, *, name: str, auto_archive_duration: ThreadArchiveDuration = MISSING) -> Thread: """|coro| Creates a public thread from this message. @@ -1683,8 +1632,7 @@ async def create_thread( self.channel.id, self.id, name=name, - auto_archive_duration=auto_archive_duration - or default_auto_archive_duration, + auto_archive_duration=auto_archive_duration or default_auto_archive_duration, ) return Thread(guild=self.guild, state=self._state, data=data) @@ -1733,9 +1681,7 @@ def to_reference(self, *, fail_if_not_exists: bool = True) -> MessageReference: The reference to this message. """ - return MessageReference.from_message( - self, fail_if_not_exists=fail_if_not_exists - ) + return MessageReference.from_message(self, fail_if_not_exists=fail_if_not_exists) def to_message_reference_dict(self) -> MessageReferencePayload: data: MessageReferencePayload = { @@ -1810,9 +1756,7 @@ def __init__(self, *, channel: PartialMessageableChannel, id: int): ChannelType.public_thread, ChannelType.private_thread, ): - raise TypeError( - f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}" - ) + raise TypeError(f"Expected TextChannel, DMChannel or Thread not {type(channel)!r}") self.channel: PartialMessageableChannel = channel self._state: ConnectionState = channel._state @@ -1953,9 +1897,7 @@ async def edit(self, **fields: Any) -> Optional[Message]: else: if allowed_mentions is not None: if self._state.allowed_mentions is not None: - allowed_mentions = self._state.allowed_mentions.merge( - allowed_mentions - ).to_dict() + allowed_mentions = self._state.allowed_mentions.merge(allowed_mentions).to_dict() else: allowed_mentions = allowed_mentions.to_dict() fields["allowed_mentions"] = allowed_mentions @@ -1969,9 +1911,7 @@ async def edit(self, **fields: Any) -> Optional[Message]: self._state.prevent_view_updates_for(self.id) fields["components"] = view.to_components() if view else [] if fields: - data = await self._state.http.edit_message( - self.channel.id, self.id, **fields - ) + data = await self._state.http.edit_message(self.channel.id, self.id, **fields) if delete_after is not None: await self.delete(delay=delete_after) diff --git a/discord/object.py b/discord/object.py index 64cf1398fb..b7a12f7843 100644 --- a/discord/object.py +++ b/discord/object.py @@ -76,9 +76,7 @@ def __init__(self, id: SupportsIntCast): try: id = int(id) except ValueError: - raise TypeError( - f"id parameter must be convertible to int not {id.__class__!r}" - ) from None + raise TypeError(f"id parameter must be convertible to int not {id.__class__!r}") from None else: self.id = id diff --git a/discord/opus.py b/discord/opus.py index 00c69f1a6e..0cbac13c6e 100644 --- a/discord/opus.py +++ b/discord/opus.py @@ -397,9 +397,7 @@ def __del__(self) -> None: def _create_state(self) -> EncoderStruct: ret = ctypes.c_int() - return _lib.opus_encoder_create( - self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret) - ) + return _lib.opus_encoder_create(self.SAMPLING_RATE, self.CHANNELS, self.application, ctypes.byref(ret)) def set_bitrate(self, kbps: int) -> int: kbps = min(512, max(16, int(kbps))) @@ -409,18 +407,14 @@ def set_bitrate(self, kbps: int) -> int: def set_bandwidth(self, req: BAND_CTL) -> None: if req not in band_ctl: - raise KeyError( - f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}' - ) + raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(band_ctl)}') k = band_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_BANDWIDTH, k) def set_signal_type(self, req: SIGNAL_CTL) -> None: if req not in signal_ctl: - raise KeyError( - f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}' - ) + raise KeyError(f'{req!r} is not a valid bandwidth setting. Try one of: {",".join(signal_ctl)}') k = signal_ctl[req] _lib.opus_encoder_ctl(self._state, CTL_SET_SIGNAL, k) @@ -456,9 +450,7 @@ def __del__(self): def _create_state(self): ret = ctypes.c_int() - return _lib.opus_decoder_create( - self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret) - ) + return _lib.opus_decoder_create(self.SAMPLING_RATE, self.CHANNELS, ctypes.byref(ret)) @staticmethod def packet_get_nb_frames(data): @@ -516,15 +508,10 @@ def decode(self, data, *, fec=False): samples_per_frame = self.packet_get_samples_per_frame(data) frame_size = frames * samples_per_frame - pcm = ( - ctypes.c_int16 - * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)) - )() + pcm = (ctypes.c_int16 * (frame_size * channel_count * ctypes.sizeof(ctypes.c_int16)))() pcm_ptr = ctypes.cast(pcm, c_int16_ptr) - ret = _lib.opus_decode( - self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec - ) + ret = _lib.opus_decode(self._state, data, len(data) if data else 0, pcm_ptr, frame_size, fec) return array.array("h", pcm[: ret * channel_count]).tobytes() @@ -556,9 +543,7 @@ def run(self): if data.decrypted_data is None: continue else: - data.decoded_data = self.get_decoder(data.ssrc).decode( - data.decrypted_data - ) + data.decoded_data = self.get_decoder(data.ssrc).decode(data.decrypted_data) except OpusError: print("Error occurred while decoding opus frame.") continue diff --git a/discord/partial_emoji.py b/discord/partial_emoji.py index 8e915c255e..68e224f358 100644 --- a/discord/partial_emoji.py +++ b/discord/partial_emoji.py @@ -93,9 +93,7 @@ class PartialEmoji(_EmojiTag, AssetMixin): __slots__ = ("animated", "name", "id", "_state") - _CUSTOM_EMOJI_RE = re.compile( - r"a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?" - ) + _CUSTOM_EMOJI_RE = re.compile(r"a)?:?(?P[A-Za-z0-9\_]+):(?P[0-9]{13,20})>?") if TYPE_CHECKING: id: Optional[int] @@ -107,9 +105,7 @@ def __init__(self, *, name: str, animated: bool = False, id: Optional[int] = Non self._state: Optional[ConnectionState] = None @classmethod - def from_dict( - cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]] - ) -> PE: + def from_dict(cls: Type[PE], data: Union[PartialEmojiPayload, Dict[str, Any]]) -> PE: return cls( animated=data.get("animated", False), id=utils._get_as_snowflake(data, "id"), diff --git a/discord/permissions.py b/discord/permissions.py index 9eeda3d8a8..4ddb2709fc 100644 --- a/discord/permissions.py +++ b/discord/permissions.py @@ -119,9 +119,7 @@ class Permissions(BaseFlags): def __init__(self, permissions: int = 0, **kwargs: bool): if not isinstance(permissions, int): - raise TypeError( - f"Expected int parameter, received {permissions.__class__.__name__} instead." - ) + raise TypeError(f"Expected int parameter, received {permissions.__class__.__name__} instead.") self.value = permissions for key, value in kwargs.items(): @@ -134,18 +132,14 @@ def is_subset(self, other: Permissions) -> bool: if isinstance(other, Permissions): return (self.value & other.value) == self.value else: - raise TypeError( - f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}" - ) + raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") def is_superset(self, other: Permissions) -> bool: """Returns ``True`` if self has the same or more permissions as other.""" if isinstance(other, Permissions): return (self.value | other.value) == self.value else: - raise TypeError( - f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}" - ) + raise TypeError(f"cannot compare {self.__class__.__name__} with {other.__class__.__name__}") def is_strict_subset(self, other: Permissions) -> bool: """Returns ``True`` if the permissions on other are a strict subset of those on self.""" @@ -731,9 +725,7 @@ def __eq__(self, other: Any) -> bool: def _set(self, key: str, value: Optional[bool]) -> None: if value not in (True, None, False): - raise TypeError( - f"Expected bool or NoneType, received {value.__class__.__name__}" - ) + raise TypeError(f"Expected bool or NoneType, received {value.__class__.__name__}") if value is None: self._values.pop(key, None) diff --git a/discord/player.py b/discord/player.py index 0a5b803f01..63e0fb14a2 100644 --- a/discord/player.py +++ b/discord/player.py @@ -165,9 +165,7 @@ def __init__( ): piping = subprocess_kwargs.get("stdin") == subprocess.PIPE if piping and isinstance(source, str): - raise TypeError( - "parameter conflict: 'source' parameter cannot be a string when piping to stdin" - ) + raise TypeError("parameter conflict: 'source' parameter cannot be a string when piping to stdin") args = [executable, *args] kwargs = {"stdout": subprocess.PIPE} @@ -181,24 +179,18 @@ def __init__( if piping: n = f"popen-stdin-writer:{id(self):#x}" self._stdin = self._process.stdin - self._pipe_thread = threading.Thread( - target=self._pipe_writer, args=(source,), daemon=True, name=n - ) + self._pipe_thread = threading.Thread(target=self._pipe_writer, args=(source,), daemon=True, name=n) self._pipe_thread.start() def _spawn_process(self, args: Any, **subprocess_kwargs: Any) -> subprocess.Popen: process = None try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs - ) + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, **subprocess_kwargs) except FileNotFoundError: executable = args.partition(" ")[0] if isinstance(args, str) else args[0] raise ClientException(f"{executable} was not found.") from None except subprocess.SubprocessError as exc: - raise ClientException( - f"Popen failed: {exc.__class__.__name__}: {exc}" - ) from exc + raise ClientException(f"Popen failed: {exc.__class__.__name__}: {exc}") from exc else: return process @@ -212,9 +204,7 @@ def _kill_process(self) -> None: try: proc.kill() except Exception: - _log.exception( - "Ignoring error attempting to kill ffmpeg process %s", proc.pid - ) + _log.exception("Ignoring error attempting to kill ffmpeg process %s", proc.pid) if proc.poll() is None: _log.info( @@ -453,9 +443,7 @@ async def from_probe( cls: Type[FT], source: str, *, - method: Optional[ - Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]] - ] = None, + method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None, **kwargs: Any, ) -> FT: """|coro| @@ -522,9 +510,7 @@ async def probe( cls, source: str, *, - method: Optional[ - Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]] - ] = None, + method: Optional[Union[str, Callable[[str, str], Tuple[Optional[str], Optional[int]]]]] = None, executable: Optional[str] = None, ) -> Tuple[Optional[str], Optional[int]]: """|coro| @@ -569,10 +555,7 @@ async def probe( probefunc = method fallback = cls._probe_codec_fallback else: - raise TypeError( - "Expected str or callable for parameter 'probe', " - f"not '{method.__class__.__name__}'" - ) + raise TypeError("Expected str or callable for parameter 'probe', " f"not '{method.__class__.__name__}'") codec = bitrate = None loop = asyncio.get_event_loop() @@ -583,9 +566,7 @@ async def probe( _log.exception("Probe '%s' using '%s' failed", method, executable) return # type: ignore - _log.exception( - "Probe '%s' using '%s' failed, trying fallback", method, executable - ) + _log.exception("Probe '%s' using '%s' failed, trying fallback", method, executable) try: codec, bitrate = await loop.run_in_executor(None, lambda: fallback(source, executable)) # type: ignore except Exception: @@ -598,14 +579,8 @@ async def probe( return codec, bitrate @staticmethod - def _probe_codec_native( - source, executable: str = "ffmpeg" - ) -> Tuple[Optional[str], Optional[int]]: - exe = ( - f"{executable[:2]}probe" - if executable in {"ffmpeg", "avconv"} - else executable - ) + def _probe_codec_native(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]: + exe = f"{executable[:2]}probe" if executable in {"ffmpeg", "avconv"} else executable args = [ exe, @@ -632,9 +607,7 @@ def _probe_codec_native( return codec, bitrate @staticmethod - def _probe_codec_fallback( - source, executable: str = "ffmpeg" - ) -> Tuple[Optional[str], Optional[int]]: + def _probe_codec_fallback(source, executable: str = "ffmpeg") -> Tuple[Optional[str], Optional[int]]: args = [executable, "-hide_banner", "-i", source] proc = subprocess.Popen( args, @@ -824,8 +797,6 @@ def _set_source(self, source: AudioSource) -> None: def _speak(self, speaking: bool) -> None: try: - asyncio.run_coroutine_threadsafe( - self.client.ws.speak(speaking), self.client.loop - ) + asyncio.run_coroutine_threadsafe(self.client.ws.speak(speaking), self.client.loop) except Exception as e: _log.info("Speaking call in player failed: %s", e) diff --git a/discord/raw_models.py b/discord/raw_models.py index 7559b47c78..b86ff27d00 100644 --- a/discord/raw_models.py +++ b/discord/raw_models.py @@ -201,9 +201,7 @@ class RawReactionActionEvent(_RawReprMixin): "member", ) - def __init__( - self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str - ) -> None: + def __init__(self, data: ReactionActionEvent, emoji: PartialEmoji, event_type: str) -> None: self.message_id: int = int(data["message_id"]) self.channel_id: int = int(data["channel_id"]) self.user_id: int = int(data["user_id"]) @@ -353,9 +351,7 @@ class RawTypingEvent(_RawReprMixin): def __init__(self, data: TypingEvent) -> None: self.channel_id: int = int(data["channel_id"]) self.user_id: int = int(data["user_id"]) - self.when: datetime.datetime = datetime.datetime.fromtimestamp( - data.get("timestamp"), tz=datetime.timezone.utc - ) + self.when: datetime.datetime = datetime.datetime.fromtimestamp(data.get("timestamp"), tz=datetime.timezone.utc) self.member: Optional[Member] = None try: diff --git a/discord/reaction.py b/discord/reaction.py index 7dc25560d3..ebcfbf3fe7 100644 --- a/discord/reaction.py +++ b/discord/reaction.py @@ -87,9 +87,7 @@ def __init__( emoji: Optional[Union[PartialEmoji, Emoji, str]] = None, ): self.message: Message = message - self.emoji: Union[ - PartialEmoji, Emoji, str - ] = emoji or message._state.get_reaction_emoji(data["emoji"]) + self.emoji: Union[PartialEmoji, Emoji, str] = emoji or message._state.get_reaction_emoji(data["emoji"]) self.count: int = data.get("count", 1) self.me: bool = data.get("me") @@ -165,9 +163,7 @@ async def clear(self) -> None: """ await self.message.clear_reaction(self.emoji) - def users( - self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None - ) -> ReactionIterator: + def users(self, *, limit: Optional[int] = None, after: Optional[Snowflake] = None) -> ReactionIterator: """Returns an :class:`AsyncIterator` representing the users that have reacted to the message. The ``after`` parameter must represent a member diff --git a/discord/role.py b/discord/role.py index 0f764f2c20..96ac0a7084 100644 --- a/discord/role.py +++ b/discord/role.py @@ -82,9 +82,7 @@ def __init__(self, data: RoleTagPayload): # This is different from other fields where "null" means "not there". # So in this case, a value of None is the same as True. # Which means we would need a different sentinel. - self._premium_subscriber: Optional[Any] = data.get( - "premium_subscriber", MISSING - ) + self._premium_subscriber: Optional[Any] = data.get("premium_subscriber", MISSING) def is_bot_managed(self) -> bool: """:class:`bool`: Whether the role is associated with a bot.""" @@ -293,11 +291,7 @@ def is_assignable(self) -> bool: .. versionadded:: 2.0 """ me = self.guild.me - return ( - not self.is_default() - and not self.managed - and (me.top_role > self or me.id == self.guild.owner_id) - ) + return not self.is_default() and not self.managed and (me.top_role > self or me.id == self.guild.owner_id) @property def permissions(self) -> Permissions: @@ -357,23 +351,15 @@ async def _move(self, position: int, reason: Optional[str]) -> None: http = self._state.http - change_range = range( - min(self.position, position), max(self.position, position) + 1 - ) - roles = [ - r.id - for r in self.guild.roles[1:] - if r.position in change_range and r.id != self.id - ] + change_range = range(min(self.position, position), max(self.position, position) + 1) + roles = [r.id for r in self.guild.roles[1:] if r.position in change_range and r.id != self.id] if self.position > position: roles.insert(0, self.id) else: roles.append(self.id) - payload: List[RolePositionUpdate] = [ - {"id": z[0], "position": z[1]} for z in zip(roles, change_range) - ] + payload: List[RolePositionUpdate] = [{"id": z[0], "position": z[1]} for z in zip(roles, change_range)] await http.move_role_position(self.guild.id, payload, reason=reason) async def edit( @@ -477,9 +463,7 @@ async def edit( payload["unicode_emoji"] = unicode_emoji payload["icon"] = None - data = await self._state.http.edit_role( - self.guild.id, self.id, reason=reason, **payload - ) + data = await self._state.http.edit_role(self.guild.id, self.id, reason=reason, **payload) return Role(guild=self.guild, data=data, state=self._state) async def delete(self, *, reason: Optional[str] = None) -> None: diff --git a/discord/scheduled_events.py b/discord/scheduled_events.py index 89127f99b1..31263bdcb3 100644 --- a/discord/scheduled_events.py +++ b/discord/scheduled_events.py @@ -204,16 +204,12 @@ def __init__( self.name: str = data.get("name") self.description: Optional[str] = data.get("description", None) self._cover: Optional[str] = data.get("image", None) - self.start_time: datetime.datetime = datetime.datetime.fromisoformat( - data.get("scheduled_start_time") - ) + self.start_time: datetime.datetime = datetime.datetime.fromisoformat(data.get("scheduled_start_time")) end_time = data.get("scheduled_end_time", None) if end_time != None: end_time = datetime.datetime.fromisoformat(end_time) self.end_time: Optional[datetime.datetime] = end_time - self.status: ScheduledEventStatus = try_enum( - ScheduledEventStatus, data.get("status") - ) + self.status: ScheduledEventStatus = try_enum(ScheduledEventStatus, data.get("status")) self.subscriber_count: Optional[int] = data.get("user_count", None) self.creator_id = data.get("creator_id", None) self.creator: Optional[Member] = creator @@ -221,9 +217,7 @@ def __init__( entity_metadata = data.get("entity_metadata") channel_id = data.get("channel_id", None) if channel_id is None: - self.location = ScheduledEventLocation( - state=state, value=entity_metadata["location"] - ) + self.location = ScheduledEventLocation(state=state, value=entity_metadata["location"]) else: self.location = ScheduledEventLocation(state=state, value=int(channel_id)) @@ -271,9 +265,7 @@ async def edit( name: str = MISSING, description: str = MISSING, status: Union[int, ScheduledEventStatus] = MISSING, - location: Union[ - str, int, VoiceChannel, StageChannel, ScheduledEventLocation - ] = MISSING, + location: Union[str, int, VoiceChannel, StageChannel, ScheduledEventLocation] = MISSING, start_time: datetime.datetime = MISSING, end_time: datetime.datetime = MISSING, cover: Optional[bytes] = MISSING, @@ -348,9 +340,7 @@ async def edit( payload["image"] = utils._bytes_to_base64_data(cover) if location is not MISSING: - if not isinstance( - location, (ScheduledEventLocation, utils._MissingSentinel) - ): + if not isinstance(location, (ScheduledEventLocation, utils._MissingSentinel)): location = ScheduledEventLocation(state=self._state, value=location) if location.type is ScheduledEventLocationType.external: @@ -364,9 +354,7 @@ async def edit( if end_time is MISSING and location.type is ScheduledEventLocationType.external: end_time = self.end_time if end_time is None: - raise ValidationError( - "end_time needs to be passed if location type is external." - ) + raise ValidationError("end_time needs to be passed if location type is external.") if start_time is not MISSING: payload["scheduled_start_time"] = start_time.isoformat() @@ -375,12 +363,8 @@ async def edit( payload["scheduled_end_time"] = end_time.isoformat() if payload != {}: - data = await self._state.http.edit_scheduled_event( - self.guild.id, self.id, **payload, reason=reason - ) - return ScheduledEvent( - data=data, guild=self.guild, creator=self.creator, state=self._state - ) + data = await self._state.http.edit_scheduled_event(self.guild.id, self.id, **payload, reason=reason) + return ScheduledEvent(data=data, guild=self.guild, creator=self.creator, state=self._state) async def delete(self) -> None: """|coro| diff --git a/discord/shard.py b/discord/shard.py index 63611cd5af..766c506dd7 100644 --- a/discord/shard.py +++ b/discord/shard.py @@ -81,9 +81,7 @@ class EventType: class EventItem: __slots__ = ("type", "shard", "error") - def __init__( - self, etype: int, shard: Optional["Shard"], error: Optional[Exception] - ) -> None: + def __init__(self, etype: int, shard: Optional["Shard"], error: Optional[Exception]) -> None: self.type: int = etype self.shard: Optional["Shard"] = shard self.error: Optional[Exception] = error @@ -165,11 +163,7 @@ async def _handle_disconnect(self, e: Exception) -> None: if isinstance(e, ConnectionClosed): if e.code == 4014: - self._queue_put( - EventItem( - EventType.terminate, self, PrivilegedIntentsRequired(self.id) - ) - ) + self._queue_put(EventItem(EventType.terminate, self, PrivilegedIntentsRequired(self.id))) return if e.code != 1000: self._queue_put(EventItem(EventType.close, self, e)) @@ -357,9 +351,7 @@ def __init__( if self.shard_ids is not None: if self.shard_count is None: - raise ClientException( - "When passing manual shard_ids, you must provide a shard_count." - ) + raise ClientException("When passing manual shard_ids, you must provide a shard_count.") elif not isinstance(self.shard_ids, (list, tuple)): raise ClientException("shard_ids parameter must be a list or a tuple.") @@ -370,9 +362,7 @@ def __init__( self._connection._get_client = lambda: self self.__queue = asyncio.PriorityQueue() - def _get_websocket( - self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None - ) -> DiscordWebSocket: + def _get_websocket(self, guild_id: Optional[int] = None, *, shard_id: Optional[int] = None) -> DiscordWebSocket: if shard_id is None: # guild_id won't be None if shard_id is None and shard_count won't be None here shard_id = (guild_id >> 22) % self.shard_count # type: ignore @@ -406,9 +396,7 @@ def latencies(self) -> List[Tuple[int, float]]: This returns a list of tuples with elements ``(shard_id, latency)``. """ - return [ - (shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items() - ] + return [(shard_id, shard.ws.latency) for shard_id, shard in self.__shards.items()] def get_shard(self, shard_id: int) -> Optional[ShardInfo]: """Optional[:class:`ShardInfo`]: Gets the shard information at a given shard ID or ``None`` if not found.""" @@ -422,18 +410,11 @@ def get_shard(self, shard_id: int) -> Optional[ShardInfo]: @property def shards(self) -> Dict[int, ShardInfo]: """Mapping[int, :class:`ShardInfo`]: Returns a mapping of shard IDs to their respective info object.""" - return { - shard_id: ShardInfo(parent, self.shard_count) - for shard_id, parent in self.__shards.items() - } + return {shard_id: ShardInfo(parent, self.shard_count) for shard_id, parent in self.__shards.items()} - async def launch_shard( - self, gateway: str, shard_id: int, *, initial: bool = False - ) -> None: + async def launch_shard(self, gateway: str, shard_id: int, *, initial: bool = False) -> None: try: - coro = DiscordWebSocket.from_client( - self, initial=initial, gateway=gateway, shard_id=shard_id - ) + coro = DiscordWebSocket.from_client(self, initial=initial, gateway=gateway, shard_id=shard_id) ws = await asyncio.wait_for(coro, timeout=180.0) except Exception: _log.exception("Failed to connect for shard_id: %s. Retrying...", shard_id) @@ -498,10 +479,7 @@ async def close(self) -> None: except Exception: pass - to_close = [ - asyncio.ensure_future(shard.close(), loop=self.loop) - for shard in self.__shards.values() - ] + to_close = [asyncio.ensure_future(shard.close(), loop=self.loop) for shard in self.__shards.values()] if to_close: await asyncio.wait(to_close) diff --git a/discord/sinks/core.py b/discord/sinks/core.py index 29b8b3d694..ec6109bf93 100644 --- a/discord/sinks/core.py +++ b/discord/sinks/core.py @@ -109,9 +109,7 @@ def __init__(self, data, client): unpacker = struct.Struct(">xxHII") self.sequence, self.timestamp, self.ssrc = unpacker.unpack_from(self.header) - self.decrypted_data = getattr(self.client, f"_decrypt_{self.client.mode}")( - self.header, self.data - ) + self.decrypted_data = getattr(self.client, f"_decrypt_{self.client.mode}")(self.header, self.data) self.decoded_data = None self.user_id = None diff --git a/discord/sinks/m4a.py b/discord/sinks/m4a.py index 58207008dd..b2e92110c3 100644 --- a/discord/sinks/m4a.py +++ b/discord/sinks/m4a.py @@ -56,9 +56,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise M4ASinkError( - "Audio may only be formatted after recording is finished." - ) + raise M4ASinkError("Audio may only be formatted after recording is finished.") m4a_file = f"{time.time()}.tmp" args = [ "ffmpeg", @@ -75,19 +73,13 @@ def format_audio(self, audio): m4a_file, ] if os.path.exists(m4a_file): - os.remove( - m4a_file - ) # process will get stuck asking whether or not to overwrite, if file already exists. + os.remove(m4a_file) # process will get stuck asking whether or not to overwrite, if file already exists. try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) except FileNotFoundError: raise M4ASinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise M4ASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise M4ASinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc process.communicate(audio.file.read()) diff --git a/discord/sinks/mka.py b/discord/sinks/mka.py index 10ddff1adc..1feb39e46c 100644 --- a/discord/sinks/mka.py +++ b/discord/sinks/mka.py @@ -55,9 +55,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise MKASinkError( - "Audio may only be formatted after recording is finished." - ) + raise MKASinkError("Audio may only be formatted after recording is finished.") args = [ "ffmpeg", "-f", @@ -82,9 +80,7 @@ def format_audio(self, audio): except FileNotFoundError: raise MKASinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise MKASinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise MKASinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc out = process.communicate(audio.file.read())[0] out = io.BytesIO(out) diff --git a/discord/sinks/mkv.py b/discord/sinks/mkv.py index 00958294eb..454cfc0178 100644 --- a/discord/sinks/mkv.py +++ b/discord/sinks/mkv.py @@ -55,9 +55,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise MKVSinkError( - "Audio may only be formatted after recording is finished." - ) + raise MKVSinkError("Audio may only be formatted after recording is finished.") args = [ "ffmpeg", "-f", @@ -81,9 +79,7 @@ def format_audio(self, audio): except FileNotFoundError: raise MKVSinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise MKVSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise MKVSinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc out = process.communicate(audio.file.read())[0] out = io.BytesIO(out) diff --git a/discord/sinks/mp3.py b/discord/sinks/mp3.py index 1ff88d69c3..4cd48b6ba9 100644 --- a/discord/sinks/mp3.py +++ b/discord/sinks/mp3.py @@ -62,9 +62,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise MP3SinkError( - "Audio may only be formatted after recording is finished." - ) + raise MP3SinkError("Audio may only be formatted after recording is finished.") args = [ "ffmpeg", "-f", @@ -89,9 +87,7 @@ def format_audio(self, audio): except FileNotFoundError: raise MP3SinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise MP3SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise MP3SinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc out = process.communicate(audio.file.read())[0] out = io.BytesIO(out) diff --git a/discord/sinks/mp4.py b/discord/sinks/mp4.py index 93f13ff0d2..346b50a334 100644 --- a/discord/sinks/mp4.py +++ b/discord/sinks/mp4.py @@ -56,9 +56,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise MP4SinkError( - "Audio may only be formatted after recording is finished." - ) + raise MP4SinkError("Audio may only be formatted after recording is finished.") mp4_file = f"{time.time()}.tmp" args = [ "ffmpeg", @@ -75,19 +73,13 @@ def format_audio(self, audio): mp4_file, ] if os.path.exists(mp4_file): - os.remove( - mp4_file - ) # process will get stuck asking whether or not to overwrite, if file already exists. + os.remove(mp4_file) # process will get stuck asking whether or not to overwrite, if file already exists. try: - process = subprocess.Popen( - args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE - ) + process = subprocess.Popen(args, creationflags=CREATE_NO_WINDOW, stdin=subprocess.PIPE) except FileNotFoundError: raise MP4SinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise MP4SinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise MP4SinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc process.communicate(audio.file.read()) diff --git a/discord/sinks/ogg.py b/discord/sinks/ogg.py index 0395aaef83..d01e82620b 100644 --- a/discord/sinks/ogg.py +++ b/discord/sinks/ogg.py @@ -55,9 +55,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise OGGSinkError( - "Audio may only be formatted after recording is finished." - ) + raise OGGSinkError("Audio may only be formatted after recording is finished.") args = [ "ffmpeg", "-f", @@ -82,9 +80,7 @@ def format_audio(self, audio): except FileNotFoundError: raise OGGSinkError("ffmpeg was not found.") from None except subprocess.SubprocessError as exc: - raise OGGSinkError( - "Popen failed: {0.__class__.__name__}: {0}".format(exc) - ) from exc + raise OGGSinkError("Popen failed: {0.__class__.__name__}: {0}".format(exc)) from exc out = process.communicate(audio.file.read())[0] out = io.BytesIO(out) diff --git a/discord/sinks/wave.py b/discord/sinks/wave.py index cf8529189d..576313760f 100644 --- a/discord/sinks/wave.py +++ b/discord/sinks/wave.py @@ -55,9 +55,7 @@ def __init__(self, *, filters=None): def format_audio(self, audio): if self.vc.recording: - raise WaveSinkError( - "Audio may only be formatted after recording is finished." - ) + raise WaveSinkError("Audio may only be formatted after recording is finished.") data = audio.file with wave.open(data, "wb") as f: diff --git a/discord/stage_instance.py b/discord/stage_instance.py index 73f8be3771..9f42c8c0ae 100644 --- a/discord/stage_instance.py +++ b/discord/stage_instance.py @@ -87,9 +87,7 @@ class StageInstance(Hashable): "_cs_channel", ) - def __init__( - self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload - ) -> None: + def __init__(self, *, state: ConnectionState, guild: Guild, data: StageInstancePayload) -> None: self._state = state self.guild = guild self._update(data) @@ -98,9 +96,7 @@ def _update(self, data: StageInstancePayload): self.id: int = int(data["id"]) self.channel_id: int = int(data["channel_id"]) self.topic: str = data["topic"] - self.privacy_level: StagePrivacyLevel = try_enum( - StagePrivacyLevel, data["privacy_level"] - ) + self.privacy_level: StagePrivacyLevel = try_enum(StagePrivacyLevel, data["privacy_level"]) self.discoverable_disabled: bool = data.get("discoverable_disabled", False) def __repr__(self) -> str: @@ -155,16 +151,12 @@ async def edit( if privacy_level is not MISSING: if not isinstance(privacy_level, StagePrivacyLevel): - raise InvalidArgument( - "privacy_level field must be of type PrivacyLevel" - ) + raise InvalidArgument("privacy_level field must be of type PrivacyLevel") payload["privacy_level"] = privacy_level.value if payload: - await self._state.http.edit_stage_instance( - self.channel_id, **payload, reason=reason - ) + await self._state.http.edit_stage_instance(self.channel_id, **payload, reason=reason) async def delete(self, *, reason: Optional[str] = None) -> None: """|coro| diff --git a/discord/state.py b/discord/state.py index 1763c65f70..0ca6504095 100644 --- a/discord/state.py +++ b/discord/state.py @@ -145,9 +145,7 @@ def done(self) -> None: _log = logging.getLogger(__name__) -async def logging_coroutine( - coroutine: Coroutine[Any, Any, T], *, info: str -) -> Optional[T]: +async def logging_coroutine(coroutine: Coroutine[Any, Any, T], *, info: str) -> Optional[T]: try: await coroutine except Exception: @@ -181,9 +179,7 @@ def __init__( self.hooks: Dict[str, Callable] = hooks self.shard_count: Optional[int] = None self._ready_task: Optional[asyncio.Task] = None - self.application_id: Optional[int] = utils._get_as_snowflake( - options, "application_id" - ) + self.application_id: Optional[int] = utils._get_as_snowflake(options, "application_id") self.heartbeat_timeout: float = options.get("heartbeat_timeout", 60.0) self.guild_ready_timeout: float = options.get("guild_ready_timeout", 2.0) if self.guild_ready_timeout < 0: @@ -191,9 +187,7 @@ def __init__( allowed_mentions = options.get("allowed_mentions") - if allowed_mentions is not None and not isinstance( - allowed_mentions, AllowedMentions - ): + if allowed_mentions is not None and not isinstance(allowed_mentions, AllowedMentions): raise TypeError("allowed_mentions parameter must be AllowedMentions") self.allowed_mentions: Optional[AllowedMentions] = allowed_mentions @@ -216,27 +210,19 @@ def __init__( elif not isinstance(intents, Intents): raise TypeError(f"intents parameter must be Intent not {type(intents)!r}") if not intents.guilds: - _log.warning( - "Guilds intent seems to be disabled. This may cause state related issues." - ) + _log.warning("Guilds intent seems to be disabled. This may cause state related issues.") - self._chunk_guilds: bool = options.get( - "chunk_guilds_at_startup", intents.members - ) + self._chunk_guilds: bool = options.get("chunk_guilds_at_startup", intents.members) # Ensure these two are set properly if not intents.members and self._chunk_guilds: - raise ValueError( - "Intents.members must be enabled to chunk guilds at startup." - ) + raise ValueError("Intents.members must be enabled to chunk guilds at startup.") cache_flags = options.get("member_cache_flags", None) if cache_flags is None: cache_flags = MemberCacheFlags.from_intents(intents) elif not isinstance(cache_flags, MemberCacheFlags): - raise TypeError( - f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}" - ) + raise TypeError(f"member_cache_flags parameter must be MemberCacheFlags not {type(cache_flags)!r}") else: cache_flags._verify_intents(intents) @@ -438,9 +424,7 @@ def get_sticker(self, sticker_id: Optional[int]) -> Optional[GuildSticker]: def private_channels(self) -> List[PrivateChannel]: return list(self._private_channels.values()) - def _get_private_channel( - self, channel_id: Optional[int] - ) -> Optional[PrivateChannel]: + def _get_private_channel(self, channel_id: Optional[int]) -> Optional[PrivateChannel]: try: # the keys of self._private_channels are ints value = self._private_channels[channel_id] # type: ignore @@ -450,9 +434,7 @@ def _get_private_channel( self._private_channels.move_to_end(channel_id) # type: ignore return value - def _get_private_channel_by_user( - self, user_id: Optional[int] - ) -> Optional[DMChannel]: + def _get_private_channel_by_user(self, user_id: Optional[int]) -> Optional[DMChannel]: # the keys of self._private_channels are ints return self._private_channels_by_user.get(user_id) # type: ignore @@ -482,11 +464,7 @@ def _remove_private_channel(self, channel: PrivateChannel) -> None: self._private_channels_by_user.pop(recipient.id, None) def _get_message(self, msg_id: Optional[int]) -> Optional[Message]: - return ( - utils.find(lambda m: m.id == msg_id, reversed(self._messages)) - if self._messages - else None - ) + return utils.find(lambda m: m.id == msg_id, reversed(self._messages)) if self._messages else None def _add_guild_from_data(self, data: GuildPayload) -> Guild: guild = Guild(data=data, state=self) @@ -495,15 +473,9 @@ def _add_guild_from_data(self, data: GuildPayload) -> Guild: def _guild_needs_chunking(self, guild: Guild) -> bool: # If presences are enabled then we get back the old guild.large behaviour - return ( - self._chunk_guilds - and not guild.chunked - and not (self._intents.presences and not guild.large) - ) - - def _get_guild_channel( - self, data: MessagePayload - ) -> Tuple[Union[Channel, Thread], Optional[Guild]]: + return self._chunk_guilds and not guild.chunked and not (self._intents.presences and not guild.large) + + def _get_guild_channel(self, data: MessagePayload) -> Tuple[Union[Channel, Thread], Optional[Guild]]: channel_id = int(data["channel_id"]) try: guild = self._get_guild(int(data["guild_id"])) @@ -525,9 +497,7 @@ async def chunker( nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id) # This is ignored upstream - await ws.request_chunks( - guild_id, query=query, limit=limit, presences=presences, nonce=nonce - ) + await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) async def query_members( self, @@ -573,9 +543,7 @@ async def _delay_ready(self) -> None: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for( - self._ready_state.get(), timeout=self.guild_ready_timeout - ) + guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) except asyncio.TimeoutError: break else: @@ -668,9 +636,7 @@ def parse_message_delete(self, data) -> None: def parse_message_delete_bulk(self, data) -> None: raw = RawBulkMessageDeleteEvent(data) if self._messages: - found_messages = [ - message for message in self._messages if message.id in raw.message_ids - ] + found_messages = [message for message in self._messages if message.id in raw.message_ids] else: found_messages = [] raw.cached_messages = found_messages @@ -702,9 +668,7 @@ def parse_message_update(self, data) -> None: def parse_message_reaction_add(self, data) -> None: emoji = data["emoji"] emoji_id = utils._get_as_snowflake(emoji, "id") - emoji = PartialEmoji.with_state( - self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"] - ) + emoji = PartialEmoji.with_state(self, id=emoji_id, animated=emoji.get("animated", False), name=emoji["name"]) raw = RawReactionActionEvent(data, emoji, "REACTION_ADD") member_data = data.get("member") @@ -785,9 +749,7 @@ def parse_interaction_create(self, data) -> None: interaction.user.id, interaction.data["custom_id"], ) - asyncio.create_task( - self._modal_store.dispatch(user_id, custom_id, interaction) - ) + asyncio.create_task(self._modal_store.dispatch(user_id, custom_id, interaction)) self.dispatch("interaction", interaction) @@ -914,11 +876,7 @@ def parse_channel_pins_update(self, data) -> None: ) return - last_pin = ( - utils.parse_time(data["last_pin_timestamp"]) - if data["last_pin_timestamp"] - else None - ) + last_pin = utils.parse_time(data["last_pin_timestamp"]) if data["last_pin_timestamp"] else None if guild is None: self.dispatch("private_channel_pins_update", channel, last_pin) @@ -1208,9 +1166,7 @@ async def chunk_guild(self, guild, *, wait=True, cache=None): cache = cache or self.member_cache_flags.joined request = self._chunk_requests.get(guild.id) if request is None: - self._chunk_requests[guild.id] = request = ChunkRequest( - guild.id, self.loop, self._get_guild, cache=cache - ) + self._chunk_requests[guild.id] = request = ChunkRequest(guild.id, self.loop, self._get_guild, cache=cache) await self.chunker(guild.id, nonce=request.nonce) if wait: @@ -1369,9 +1325,7 @@ def parse_guild_members_chunk(self, data) -> None: # the guild won't be None here members = [Member(guild=guild, data=member, state=self) for member in data.get("members", [])] # type: ignore - _log.debug( - "Processed a chunk for %s members in guild ID %s.", len(members), guild_id - ) + _log.debug("Processed a chunk for %s members in guild ID %s.", len(members), guild_id) if presences: member_dict = {str(member.id): member for member in members} @@ -1394,14 +1348,8 @@ def parse_guild_scheduled_event_create(self, data) -> None: ) return - creator = ( - None - if not data.get("creator", None) - else guild.get_member(data.get("creator_id")) - ) - scheduled_event = ScheduledEvent( - state=self, guild=guild, creator=creator, data=data - ) + creator = None if not data.get("creator", None) else guild.get_member(data.get("creator_id")) + scheduled_event = ScheduledEvent(state=self, guild=guild, creator=creator, data=data) guild._add_scheduled_event(scheduled_event) self.dispatch("scheduled_event_create", scheduled_event) @@ -1414,14 +1362,8 @@ def parse_guild_scheduled_event_update(self, data) -> None: ) return - creator = ( - None - if not data.get("creator", None) - else guild.get_member(data.get("creator_id")) - ) - scheduled_event = ScheduledEvent( - state=self, guild=guild, creator=creator, data=data - ) + creator = None if not data.get("creator", None) else guild.get_member(data.get("creator_id")) + scheduled_event = ScheduledEvent(state=self, guild=guild, creator=creator, data=data) old_event = guild.get_scheduled_event(data["id"]) guild._add_scheduled_event(scheduled_event) self.dispatch("scheduled_event_update", old_event, scheduled_event) @@ -1435,14 +1377,8 @@ def parse_guild_scheduled_event_delete(self, data) -> None: ) return - creator = ( - None - if not data.get("creator", None) - else guild.get_member(data.get("creator_id")) - ) - scheduled_event = ScheduledEvent( - state=self, guild=guild, creator=creator, data=data - ) + creator = None if not data.get("creator", None) else guild.get_member(data.get("creator_id")) + scheduled_event = ScheduledEvent(state=self, guild=guild, creator=creator, data=data) scheduled_event.status = ScheduledEventStatus.canceled guild._remove_scheduled_event(scheduled_event) self.dispatch("scheduled_event_delete", scheduled_event) @@ -1574,9 +1510,7 @@ def parse_stage_instance_update(self, data) -> None: if stage_instance is not None: old_stage_instance = copy.copy(stage_instance) stage_instance._update(data) - self.dispatch( - "stage_instance_update", old_stage_instance, stage_instance - ) + self.dispatch("stage_instance_update", old_stage_instance, stage_instance) else: _log.debug( "STAGE_INSTANCE_UPDATE referencing unknown stage instance ID: %s. Discarding.", @@ -1614,20 +1548,12 @@ def parse_voice_state_update(self, data) -> None: voice = self._get_voice_client(guild.id) if voice is not None: coro = voice.on_voice_state_update(data) - asyncio.create_task( - logging_coroutine( - coro, info="Voice Protocol voice state update handler" - ) - ) + asyncio.create_task(logging_coroutine(coro, info="Voice Protocol voice state update handler")) member, before, after = guild._update_voice_state(data, channel_id) # type: ignore if member is not None: if flags.voice: - if ( - channel_id is None - and flags._voice_only - and member.id != self_id - ): + if channel_id is None and flags._voice_only and member.id != self_id: # Only remove from cache if we only have the voice flag enabled # Member doesn't meet the Snowflake protocol currently guild._remove_member(member) # type: ignore @@ -1650,11 +1576,7 @@ def parse_voice_server_update(self, data) -> None: vc = self._get_voice_client(key_id) if vc is not None: coro = vc.on_voice_server_update(data) - asyncio.create_task( - logging_coroutine( - coro, info="Voice Protocol voice server update handler" - ) - ) + asyncio.create_task(logging_coroutine(coro, info="Voice Protocol voice server update handler")) def parse_typing_start(self, data) -> None: raw = RawTypingEvent(data) @@ -1677,9 +1599,7 @@ def parse_typing_start(self, data) -> None: if user is not None: self.dispatch("typing", channel, user, raw.when) - def _get_typing_user( - self, channel: Optional[MessageableChannel], user_id: int - ) -> Optional[Union[User, Member]]: + def _get_typing_user(self, channel: Optional[MessageableChannel], user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, DMChannel): return channel.recipient or self.get_user(user_id) @@ -1691,9 +1611,7 @@ def _get_typing_user( return self.get_user(user_id) - def _get_reaction_user( - self, channel: MessageableChannel, user_id: int - ) -> Optional[Union[User, Member]]: + def _get_reaction_user(self, channel: MessageableChannel, user_id: int) -> Optional[Union[User, Member]]: if isinstance(channel, TextChannel): return channel.guild.get_member(user_id) return self.get_user(user_id) @@ -1714,9 +1632,7 @@ def get_reaction_emoji(self, data) -> Union[Emoji, PartialEmoji]: name=data["name"], ) - def _upgrade_partial_emoji( - self, emoji: PartialEmoji - ) -> Union[Emoji, PartialEmoji, str]: + def _upgrade_partial_emoji(self, emoji: PartialEmoji) -> Union[Emoji, PartialEmoji, str]: emoji_id = emoji.id if not emoji_id: return emoji.name @@ -1741,9 +1657,7 @@ def get_channel(self, id: Optional[int]) -> Optional[Union[Channel, Thread]]: def create_message( self, *, - channel: Union[ - TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable - ], + channel: Union[TextChannel, Thread, DMChannel, GroupChannel, PartialMessageable], data: MessagePayload, ) -> Message: return Message(state=self, channel=channel, data=data) @@ -1764,9 +1678,7 @@ def _update_message_references(self) -> None: new_guild = self._get_guild(msg.guild.id) if new_guild is not None and new_guild is not msg.guild: channel_id = msg.channel.id - channel = new_guild._resolve_channel(channel_id) or Object( - id=channel_id - ) + channel = new_guild._resolve_channel(channel_id) or Object(id=channel_id) # channel will either be a TextChannel, Thread or Object msg._rebind_cached_references(new_guild, channel) # type: ignore @@ -1781,9 +1693,7 @@ async def chunker( nonce: Optional[str] = None, ) -> None: ws = self._get_websocket(guild_id, shard_id=shard_id) - await ws.request_chunks( - guild_id, query=query, limit=limit, presences=presences, nonce=nonce - ) + await ws.request_chunks(guild_id, query=query, limit=limit, presences=presences, nonce=nonce) async def _delay_ready(self) -> None: await self.shards_launched.wait() @@ -1794,9 +1704,7 @@ async def _delay_ready(self) -> None: # this snippet of code is basically waiting N seconds # until the last GUILD_CREATE was sent try: - guild = await asyncio.wait_for( - self._ready_state.get(), timeout=self.guild_ready_timeout - ) + guild = await asyncio.wait_for(self._ready_state.get(), timeout=self.guild_ready_timeout) except asyncio.TimeoutError: break else: @@ -1807,9 +1715,7 @@ async def _delay_ready(self) -> None: ) if len(current_bucket) >= max_concurrency: try: - await utils.sane_wait_for( - current_bucket, timeout=max_concurrency * 70.0 - ) + await utils.sane_wait_for(current_bucket, timeout=max_concurrency * 70.0) except asyncio.TimeoutError: fmt = "Shard ID %s failed to wait for chunks from a sub-bucket with length %d" _log.warning(fmt, guild.shard_id, len(current_bucket)) @@ -1877,9 +1783,7 @@ def parse_ready(self, data) -> None: pass else: self.application_id = utils._get_as_snowflake(application, "id") - self.application_flags = ApplicationFlags._from_value( - application["flags"] - ) + self.application_flags = ApplicationFlags._from_value(application["flags"]) for guild_data in data["guilds"]: self._add_guild_from_data(guild_data) diff --git a/discord/sticker.py b/discord/sticker.py index 20f4dbf01a..70f99ab45a 100644 --- a/discord/sticker.py +++ b/discord/sticker.py @@ -206,9 +206,7 @@ def __init__(self, *, state: ConnectionState, data: StickerItemPayload): self._state: ConnectionState = state self.name: str = data["name"] self.id: int = int(data["id"]) - self.format: StickerFormatType = try_enum( - StickerFormatType, data["format_type"] - ) + self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"]) self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: @@ -282,9 +280,7 @@ def _from_data(self, data: StickerPayload) -> None: self.id: int = int(data["id"]) self.name: str = data["name"] self.description: str = data["description"] - self.format: StickerFormatType = try_enum( - StickerFormatType, data["format_type"] - ) + self.format: StickerFormatType = try_enum(StickerFormatType, data["format_type"]) self.url: str = f"{Asset.BASE}/stickers/{self.id}.{self.format.file_extension}" def __repr__(self) -> str: @@ -350,9 +346,7 @@ def _from_data(self, data: StandardStickerPayload) -> None: self.tags = [] def __repr__(self) -> str: - return ( - f"" - ) + return f"" async def pack(self) -> StickerPack: """|coro| @@ -371,9 +365,7 @@ async def pack(self) -> StickerPack: :class:`StickerPack` The retrieved sticker pack. """ - data: ListPremiumStickerPacksPayload = ( - await self._state.http.list_premium_sticker_packs() - ) + data: ListPremiumStickerPacksPayload = await self._state.http.list_premium_sticker_packs() packs = data["sticker_packs"] pack = find(lambda d: int(d["id"]) == self.pack_id, packs) @@ -498,9 +490,7 @@ async def edit( payload["tags"] = emoji - data: GuildStickerPayload = await self._state.http.modify_guild_sticker( - self.guild_id, self.id, payload, reason - ) + data: GuildStickerPayload = await self._state.http.modify_guild_sticker(self.guild_id, self.id, payload, reason) return GuildSticker(state=self._state, data=data) async def delete(self, *, reason: Optional[str] = None) -> None: diff --git a/discord/team.py b/discord/team.py index 95174948bf..2d264b0137 100644 --- a/discord/team.py +++ b/discord/team.py @@ -69,9 +69,7 @@ def __init__(self, state: ConnectionState, data: TeamPayload): self.name: str = data["name"] self._icon: Optional[str] = data["icon"] self.owner_id: Optional[int] = utils._get_as_snowflake(data, "owner_user_id") - self.members: List[TeamMember] = [ - TeamMember(self, self._state, member) for member in data["members"] - ] + self.members: List[TeamMember] = [TeamMember(self, self._state, member) for member in data["members"]] def __repr__(self) -> str: return f"<{self.__class__.__name__} id={self.id} name={self.name}>" @@ -134,9 +132,7 @@ class TeamMember(BaseUser): def __init__(self, team: Team, state: ConnectionState, data: TeamMemberPayload): self.team: Team = team - self.membership_state: TeamMembershipState = try_enum( - TeamMembershipState, data["membership_state"] - ) + self.membership_state: TeamMembershipState = try_enum(TeamMembershipState, data["membership_state"]) self.permissions: List[str] = data["permissions"] super().__init__(state=state, data=data["user"]) diff --git a/discord/template.py b/discord/template.py index 9895aef01e..e40cc4a277 100644 --- a/discord/template.py +++ b/discord/template.py @@ -141,16 +141,10 @@ def _store(self, data: TemplatePayload) -> None: self.name: str = data["name"] self.description: Optional[str] = data["description"] creator_data = data.get("creator") - self.creator: Optional[User] = ( - None if creator_data is None else self._state.create_user(creator_data) - ) + self.creator: Optional[User] = None if creator_data is None else self._state.create_user(creator_data) - self.created_at: Optional[datetime.datetime] = parse_time( - data.get("created_at") - ) - self.updated_at: Optional[datetime.datetime] = parse_time( - data.get("updated_at") - ) + self.created_at: Optional[datetime.datetime] = parse_time(data.get("created_at")) + self.updated_at: Optional[datetime.datetime] = parse_time(data.get("updated_at")) guild_id = int(data["source_guild_id"]) guild: Optional[Guild] = self._state._get_guild(guild_id) @@ -173,9 +167,7 @@ def __repr__(self) -> str: f" creator={self.creator!r} source_guild={self.source_guild!r} is_dirty={self.is_dirty}>" ) - async def create_guild( - self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None - ) -> Guild: + async def create_guild(self, name: str, region: Optional[VoiceRegion] = None, icon: Any = None) -> Guild: """|coro| Creates a :class:`.Guild` using the template. @@ -212,9 +204,7 @@ async def create_guild( region = region or VoiceRegion.us_west region_value = region.value - data = await self._state.http.create_from_template( - self.code, name, region_value, icon - ) + data = await self._state.http.create_from_template(self.code, name, region_value, icon) return Guild(data=data, state=self._state) async def sync(self) -> Template: @@ -294,9 +284,7 @@ async def edit( if description is not MISSING: payload["description"] = description - data = await self._state.http.edit_template( - self.source_guild.id, self.code, payload - ) + data = await self._state.http.edit_template(self.source_guild.id, self.code, payload) return Template(state=self._state, data=data) async def delete(self) -> None: diff --git a/discord/threads.py b/discord/threads.py index f580435880..4b4381582b 100644 --- a/discord/threads.py +++ b/discord/threads.py @@ -252,11 +252,7 @@ def last_message(self) -> Optional[Message]: Optional[:class:`Message`] The last message in this channel or ``None`` if not found. """ - return ( - self._state._get_message(self.last_message_id) - if self.last_message_id - else None - ) + return self._state._get_message(self.last_message_id) if self.last_message_id else None @property def category(self) -> Optional[CategoryChannel]: @@ -487,9 +483,7 @@ def is_me(m): ret: List[Message] = [] count = 0 - minimum_time = ( - int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 - ) + minimum_time = int((time.time() - 14 * 24 * 60 * 60) * 1000.0 - 1420070400000) << 22 async def _single_delete_strategy(messages: Iterable[Message]): for m in messages: diff --git a/discord/types/interactions.py b/discord/types/interactions.py index 45ffe665d3..b6761bdc8d 100644 --- a/discord/types/interactions.py +++ b/discord/types/interactions.py @@ -104,44 +104,32 @@ class _ApplicationCommandInteractionDataOption(TypedDict): name: str -class _ApplicationCommandInteractionDataOptionSubcommand( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionSubcommand(_ApplicationCommandInteractionDataOption): type: Literal[1, 2] options: List[ApplicationCommandInteractionDataOption] -class _ApplicationCommandInteractionDataOptionString( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionString(_ApplicationCommandInteractionDataOption): type: Literal[3] value: str -class _ApplicationCommandInteractionDataOptionInteger( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionInteger(_ApplicationCommandInteractionDataOption): type: Literal[4] value: int -class _ApplicationCommandInteractionDataOptionBoolean( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionBoolean(_ApplicationCommandInteractionDataOption): type: Literal[5] value: bool -class _ApplicationCommandInteractionDataOptionSnowflake( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionSnowflake(_ApplicationCommandInteractionDataOption): type: Literal[6, 7, 8, 9, 11] value: Snowflake -class _ApplicationCommandInteractionDataOptionNumber( - _ApplicationCommandInteractionDataOption -): +class _ApplicationCommandInteractionDataOptionNumber(_ApplicationCommandInteractionDataOption): type: Literal[10] value: float diff --git a/discord/types/voice.py b/discord/types/voice.py index e4dbe8c4cf..3a4ce86813 100644 --- a/discord/types/voice.py +++ b/discord/types/voice.py @@ -28,9 +28,7 @@ from .member import MemberWithUser from .snowflake import Snowflake -SupportedModes = Literal[ - "xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305" -] +SupportedModes = Literal["xsalsa20_poly1305_lite", "xsalsa20_poly1305_suffix", "xsalsa20_poly1305"] class _PartialVoiceStateOptional(TypedDict, total=False): diff --git a/discord/ui/button.py b/discord/ui/button.py index d19bad5b77..e7bfaca219 100644 --- a/discord/ui/button.py +++ b/discord/ui/button.py @@ -112,9 +112,7 @@ def __init__( elif isinstance(emoji, _EmojiTag): emoji = emoji._to_partial() else: - raise TypeError( - f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}" - ) + raise TypeError(f"expected emoji to be str, Emoji, or PartialEmoji not {emoji.__class__}") self._underlying = ButtonComponent._raw_construct( type=ComponentType.button, @@ -194,9 +192,7 @@ def emoji(self, value: Optional[Union[str, Emoji, PartialEmoji]]): # type: igno elif isinstance(value, _EmojiTag): self._underlying.emoji = value._to_partial() else: - raise TypeError( - f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead" - ) + raise TypeError(f"expected str, Emoji, or PartialEmoji, received {value.__class__} instead") @classmethod def from_component(cls: Type[B], button: ButtonComponent) -> B: diff --git a/discord/ui/input_text.py b/discord/ui/input_text.py index d2890d4388..f4285d30eb 100644 --- a/discord/ui/input_text.py +++ b/discord/ui/input_text.py @@ -85,9 +85,7 @@ def style(self) -> InputTextStyle: @style.setter def style(self, value: InputTextStyle): if not isinstance(value, InputTextStyle): - raise TypeError( - f"style must be of type InputTextStyle not {value.__class__}" - ) + raise TypeError(f"style must be of type InputTextStyle not {value.__class__}") self._underlying.style = value @property diff --git a/discord/ui/item.py b/discord/ui/item.py index 285c1e9e1e..86c2bc9a7d 100644 --- a/discord/ui/item.py +++ b/discord/ui/item.py @@ -101,9 +101,7 @@ def is_persistent(self) -> bool: return self._provided_custom_id def __repr__(self) -> str: - attrs = " ".join( - f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__ - ) + attrs = " ".join(f"{key}={getattr(self, key)!r}" for key in self.__item_repr_attributes__) return f"<{self.__class__.__name__} {attrs}>" @property diff --git a/discord/ui/modal.py b/discord/ui/modal.py index d909488823..a4a3221377 100644 --- a/discord/ui/modal.py +++ b/discord/ui/modal.py @@ -124,9 +124,7 @@ def add_item(self, item: InputText) -> None: if item.row is not None: total = self.weights[item.row] + item.width if total > 5: - raise ValueError( - f"item would not fit at row {item.row} ({total} > 5 width)" - ) + raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)") self.weights[item.row] = total item._rendered_row = item.row else: diff --git a/discord/ui/view.py b/discord/ui/view.py index c2f2ca17b1..dfd3d8893c 100644 --- a/discord/ui/view.py +++ b/discord/ui/view.py @@ -105,9 +105,7 @@ def add_item(self, item: Item) -> None: if item.row is not None: total = self.weights[item.row] + item.width if total > 5: - raise ValueError( - f"item would not fit at row {item.row} ({total} > 5 width)" - ) + raise ValueError(f"item would not fit at row {item.row} ({total} > 5 width)") self.weights[item.row] = total item._rendered_row = item.row else: @@ -167,9 +165,7 @@ def __init__(self, *items: Item, timeout: Optional[float] = 180.0): self.timeout = timeout self.children: List[Item] = [] for func in self.__view_children_items__: - item: Item = func.__discord_ui_model_type__( - **func.__discord_ui_model_kwargs__ - ) + item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) item.callback = partial(func, self, item) item._view = self setattr(self, func.__name__, item) @@ -227,9 +223,7 @@ def key(item: Item) -> int: return components @classmethod - def from_message( - cls, message: Message, /, *, timeout: Optional[float] = 180.0 - ) -> View: + def from_message(cls, message: Message, /, *, timeout: Optional[float] = 180.0) -> View: """Converts a message's components into a :class:`View`. The :attr:`.Message.components` of a message are read-only @@ -345,9 +339,7 @@ async def on_timeout(self) -> None: """ pass - async def on_error( - self, error: Exception, item: Item, interaction: Interaction - ) -> None: + async def on_error(self, error: Exception, item: Item, interaction: Interaction) -> None: """|coro| A callback that is called when an item's callback or :meth:`interaction_check` @@ -365,9 +357,7 @@ async def on_error( The interaction that led to the failure. """ print(f"Ignoring exception in view {self} for item {item}:", file=sys.stderr) - traceback.print_exception( - error.__class__, error, error.__traceback__, file=sys.stderr - ) + traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr) async def _scheduled_task(self, item: Item, interaction: Interaction): try: @@ -399,9 +389,7 @@ def _dispatch_timeout(self): return self.__stopped.set_result(True) - asyncio.create_task( - self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}" - ) + asyncio.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}") def _dispatch_item(self, item: Item, interaction: Interaction): if self.__stopped.done(): @@ -417,9 +405,7 @@ def refresh(self, components: List[Component]): old_state: Dict[Tuple[int, str], Item] = { (item.type.value, item.custom_id): item for item in self.children if item.is_dispatchable() # type: ignore } - children: List[Item] = [ - item for item in self.children if not item.is_dispatchable() - ] + children: List[Item] = [item for item in self.children if not item.is_dispatchable()] for component in _walk_all_components(components): try: older = old_state[(component.type.value, component.custom_id)] # type: ignore @@ -465,9 +451,7 @@ def is_persistent(self) -> bool: A persistent view has all their components with a set ``custom_id`` and a :attr:`timeout` set to ``None``. """ - return self.timeout is None and all( - item.is_persistent() for item in self.children - ) + return self.timeout is None and all(item.is_persistent() for item in self.children) async def wait(self) -> bool: """Waits until the view has finished interacting. @@ -494,11 +478,7 @@ def __init__(self, state: ConnectionState): @property def persistent_views(self) -> Sequence[View]: - views = { - view.id: view - for (_, (view, _)) in self._views.items() - if view.is_persistent() - } + views = {view.id: view for (_, (view, _)) in self._views.items() if view.is_persistent()} return list(views.values()) def __verify_integrity(self): @@ -537,9 +517,7 @@ def dispatch(self, component_type: int, custom_id: str, interaction: Interaction key = (component_type, message_id, custom_id) # Fallback to None message_id searches in case a persistent view # was added without an associated message_id - value = self._views.get(key) or self._views.get( - (component_type, None, custom_id) - ) + value = self._views.get(key) or self._views.get((component_type, None, custom_id)) if value is None: return diff --git a/discord/user.py b/discord/user.py index 2c91ea2dbc..101a3b1e12 100644 --- a/discord/user.py +++ b/discord/user.py @@ -162,9 +162,7 @@ def avatar(self) -> Optional[Asset]: @property def default_avatar(self) -> Asset: """:class:`Asset`: Returns the default avatar for a given user. This is calculated by the user's discriminator.""" - return Asset._from_default_avatar( - self._state, int(self.discriminator) % len(DefaultAvatar) - ) + return Asset._from_default_avatar(self._state, int(self.discriminator) % len(DefaultAvatar)) @property def display_avatar(self) -> Asset: @@ -350,9 +348,7 @@ def _update(self, data: UserPayload) -> None: self._flags = data.get("flags", 0) self.mfa_enabled = data.get("mfa_enabled", False) - async def edit( - self, *, username: str = MISSING, avatar: bytes = MISSING - ) -> ClientUser: + async def edit(self, *, username: str = MISSING, avatar: bytes = MISSING) -> ClientUser: """|coro| Edits the current profile of the client. @@ -480,9 +476,7 @@ def mutual_guilds(self) -> List[Guild]: .. versionadded:: 1.7 """ - return [ - guild for guild in self._state._guilds.values() if guild.get_member(self.id) - ] + return [guild for guild in self._state._guilds.values() if guild.get_member(self.id)] async def create_dm(self) -> DMChannel: """|coro| diff --git a/discord/utils.py b/discord/utils.py index 3646d2f470..0878c78712 100644 --- a/discord/utils.py +++ b/discord/utils.py @@ -277,9 +277,7 @@ def decorated(*args: P.args, **kwargs: P.kwargs) -> T: else: fmt = "{0.__name__} is deprecated." - warnings.warn( - fmt.format(func, instead), stacklevel=3, category=DeprecationWarning - ) + warnings.warn(fmt.format(func, instead), stacklevel=3, category=DeprecationWarning) warnings.simplefilter("default", DeprecationWarning) # reset filter return func(*args, **kwargs) @@ -465,9 +463,7 @@ def get(iterable: Iterable[T], **attrs: Any) -> Optional[T]: return elem return None - converted = [ - (attrget(attr.replace("__", ".")), value) for attr, value in attrs.items() - ] + converted = [(attrget(attr.replace("__", ".")), value) for attr, value in attrs.items()] for elem in iterable: if _all(pred(elem) == value for pred, value in converted): @@ -543,9 +539,7 @@ def _parse_ratelimit_header(request: Any, *, use_clock: bool = False) -> float: return float(reset_after) utc = datetime.timezone.utc now = datetime.datetime.now(utc) - reset = datetime.datetime.fromtimestamp( - float(request.headers["X-Ratelimit-Reset"]), utc - ) + reset = datetime.datetime.fromtimestamp(float(request.headers["X-Ratelimit-Reset"]), utc) return (reset - now).total_seconds() @@ -568,9 +562,7 @@ async def async_all(gen, *, check=_isawaitable): async def sane_wait_for(futures, *, timeout): ensured = [asyncio.ensure_future(fut) for fut in futures] - done, pending = await asyncio.wait( - ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED - ) + done, pending = await asyncio.wait(ensured, timeout=timeout, return_when=asyncio.ALL_COMPLETED) if len(pending) != 0: raise asyncio.TimeoutError() @@ -593,9 +585,7 @@ def compute_timedelta(dt: datetime.datetime): return max((dt - now).total_seconds(), 0) -async def sleep_until( - when: datetime.datetime, result: Optional[T] = None -) -> Optional[T]: +async def sleep_until(when: datetime.datetime, result: Optional[T] = None) -> Optional[T]: """|coro| Sleep until a specified time. @@ -738,9 +728,7 @@ def resolve_template(code: Union[Template, str]) -> str: return code -_MARKDOWN_ESCAPE_SUBREGEX = "|".join( - r"\{0}(?=([\s\S]*((? str: +def escape_markdown(text: str, *, as_needed: bool = False, ignore_links: bool = True) -> str: r"""A helper function that escapes Discord's markdown. Parameters @@ -1002,16 +988,11 @@ def evaluate_annotation( is_literal = True evaluated_args = tuple( - evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) - for arg in args + evaluate_annotation(arg, globals, locals, cache, implicit_str=implicit_str) for arg in args ) - if is_literal and not all( - isinstance(x, (str, int, bool, type(None))) for x in evaluated_args - ): - raise TypeError( - "Literal arguments must be of type str, int, bool, or NoneType." - ) + if is_literal and not all(isinstance(x, (str, int, bool, type(None))) for x in evaluated_args): + raise TypeError("Literal arguments must be of type str, int, bool, or NoneType.") if evaluated_args == args: return tp diff --git a/discord/voice_client.py b/discord/voice_client.py index d474945d18..c4798248e9 100644 --- a/discord/voice_client.py +++ b/discord/voice_client.py @@ -358,9 +358,7 @@ def prepare_handshake(self) -> None: self._voice_state_complete.clear() self._voice_server_complete.clear() self._handshaking = True - _log.info( - "Starting voice handshake... (connection attempt %d)", self._connections + 1 - ) + _log.info("Starting voice handshake... (connection attempt %d)", self._connections + 1) self._connections += 1 def finish_handshake(self) -> None: @@ -423,9 +421,7 @@ async def potential_reconnect(self) -> bool: self._potentially_reconnecting = True try: # We only care about VOICE_SERVER_UPDATE since VOICE_STATE_UPDATE can come before we get disconnected - await asyncio.wait_for( - self._voice_server_complete.wait(), timeout=self.timeout - ) + await asyncio.wait_for(self._voice_server_complete.wait(), timeout=self.timeout) except asyncio.TimeoutError: self._potentially_reconnecting = False await self.disconnect(force=True) @@ -480,16 +476,12 @@ async def poll_voice_ws(self, reconnect: bool) -> None: await self.disconnect() break if exc.code == 4014: - _log.info( - "Disconnected from voice by force... potentially reconnecting." - ) + _log.info("Disconnected from voice by force... potentially reconnecting.") successful = await self.potential_reconnect() if successful: continue - _log.info( - "Reconnect was unsuccessful, disconnecting from voice normally..." - ) + _log.info("Reconnect was unsuccessful, disconnecting from voice normally...") await self.disconnect() break if not reconnect: @@ -497,9 +489,7 @@ async def poll_voice_ws(self, reconnect: bool) -> None: raise retry = backoff.delay() - _log.exception( - "Disconnected from voice... Reconnecting in %.2fs.", retry - ) + _log.exception("Disconnected from voice... Reconnecting in %.2fs.", retry) self._connected.clear() await asyncio.sleep(retry) await self.voice_disconnect() @@ -618,13 +608,9 @@ def strip_header_ext(data): return data def get_ssrc(self, user_id): - return {info["user_id"]: ssrc for ssrc, info in self.ws.ssrc_map.items()}[ - user_id - ] + return {info["user_id"]: ssrc for ssrc, info in self.ws.ssrc_map.items()}[user_id] - def play( - self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None - ) -> None: + def play(self, source: AudioSource, *, after: Callable[[Optional[Exception]], Any] = None) -> None: """Plays an :class:`AudioSource`. The finalizer, ``after`` is called after the source has been exhausted @@ -660,9 +646,7 @@ def play( raise ClientException("Already playing audio.") if not isinstance(source, AudioSource): - raise TypeError( - f"source must be an AudioSource not {source.__class__.__name__}" - ) + raise TypeError(f"source must be an AudioSource not {source.__class__.__name__}") if not self.encoder and not source.is_opus(): self.encoder = opus.Encoder() @@ -813,9 +797,7 @@ def recv_audio(self, sink, callback, *args): self.stopping_time = time.perf_counter() self.sink.cleanup() - callback = asyncio.run_coroutine_threadsafe( - callback(self.sink, *args), self.loop - ) + callback = asyncio.run_coroutine_threadsafe(callback(self.sink, *args), self.loop) result = callback.result() if result is not None: @@ -830,10 +812,7 @@ def recv_decoded_audio(self, data): silence = data.timestamp - self.user_timestamps[data.ssrc] - 960 self.user_timestamps[data.ssrc] = data.timestamp - data.decoded_data = ( - struct.pack(" Response[WebhookPayload]: route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id) - return self.request( - route, session, reason=reason, payload=payload, auth_token=token - ) + return self.request(route, session, reason=reason, payload=payload, auth_token=token) def edit_webhook_with_token( self, @@ -500,9 +493,7 @@ def edit_original_interaction_response( webhook_id=application_id, webhook_token=token, ) - return self.request( - r, session, payload=payload, multipart=multipart, files=files - ) + return self.request(r, session, payload=payload, multipart=multipart, files=files) def delete_original_interaction_response( self, @@ -572,9 +563,7 @@ def handle_message_parameters( if allowed_mentions: if previous_allowed_mentions is not None: - payload["allowed_mentions"] = previous_allowed_mentions.merge( - allowed_mentions - ).to_dict() + payload["allowed_mentions"] = previous_allowed_mentions.merge(allowed_mentions).to_dict() else: payload["allowed_mentions"] = allowed_mentions.to_dict() elif previous_allowed_mentions is not None: @@ -611,9 +600,7 @@ def handle_message_parameters( return ExecuteWebhookParameters(payload=payload, multipart=multipart, files=files) -async_context: ContextVar[AsyncWebhookAdapter] = ContextVar( - "async_webhook_context", default=AsyncWebhookAdapter() -) +async_context: ContextVar[AsyncWebhookAdapter] = ContextVar("async_webhook_context", default=AsyncWebhookAdapter()) class PartialWebhookChannel(Hashable): @@ -685,9 +672,7 @@ def __getattr__(self, attr): class _WebhookState: __slots__ = ("_parent", "_webhook") - def __init__( - self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]] - ): + def __init__(self, webhook: Any, parent: Optional[Union[ConnectionState, _WebhookState]]): self._webhook: Any = webhook self._parent: Optional[ConnectionState] @@ -861,9 +846,7 @@ async def delete(self, *, delay: Optional[float] = None) -> None: async def inner_call(delay: float = delay): await asyncio.sleep(delay) try: - await self._state._webhook.delete_message( - self.id, thread_id=thread_id - ) + await self._state._webhook.delete_message(self.id, thread_id=thread_id) except HTTPException: pass @@ -895,9 +878,7 @@ def __init__( state: Optional[ConnectionState] = None, ): self.auth_token: Optional[str] = token - self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState( - self, parent=state - ) + self._state: Union[ConnectionState, _WebhookState] = state or _WebhookState(self, parent=state) self._update(data) def _update(self, data: WebhookPayload): @@ -1229,17 +1210,11 @@ async def fetch(self, *, prefer_auth: bool = True) -> Webhook: adapter = async_context.get() if prefer_auth and self.auth_token: - data = await adapter.fetch_webhook( - self.id, self.auth_token, session=self.session - ) + data = await adapter.fetch_webhook(self.id, self.auth_token, session=self.session) elif self.token: - data = await adapter.fetch_webhook_with_token( - self.id, self.token, session=self.session - ) + data = await adapter.fetch_webhook_with_token(self.id, self.token, session=self.session) else: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") return Webhook(data, self.session, token=self.auth_token, state=self._state) @@ -1272,20 +1247,14 @@ async def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True This webhook does not have a token associated with it. """ if self.token is None and self.auth_token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter = async_context.get() if prefer_auth and self.auth_token: - await adapter.delete_webhook( - self.id, token=self.auth_token, session=self.session, reason=reason - ) + await adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason) elif self.token: - await adapter.delete_webhook_with_token( - self.id, self.token, session=self.session, reason=reason - ) + await adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason) async def edit( self, @@ -1331,18 +1300,14 @@ async def edit( or it tried editing a channel without authentication. """ if self.token is None and self.auth_token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") payload = {} if name is not MISSING: payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = ( - utils._bytes_to_base64_data(avatar) if avatar is not None else None - ) + payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None adapter = async_context.get() @@ -1381,9 +1346,7 @@ async def edit( if data is None: raise RuntimeError("Unreachable code hit: data was not assigned") - return Webhook( - data=data, session=self.session, token=self.auth_token, state=self._state - ) + return Webhook(data=data, session=self.session, token=self.auth_token, state=self._state) def _create_message(self, data): state = _WebhookState(self, parent=self._state) @@ -1542,30 +1505,22 @@ async def send( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") - previous_mentions: Optional[AllowedMentions] = getattr( - self._state, "allowed_mentions", None - ) + previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None) if content is None: content = MISSING application_webhook = self.type is WebhookType.application if ephemeral and not application_webhook: - raise InvalidArgument( - "ephemeral messages can only be sent from application webhooks" - ) + raise InvalidArgument("ephemeral messages can only be sent from application webhooks") if application_webhook: wait = True if view is not MISSING: if isinstance(self._state, _WebhookState): - raise InvalidArgument( - "Webhook views require an associated state with the webhook" - ) + raise InvalidArgument("Webhook views require an associated state with the webhook") if ephemeral is True and view.timeout is None: view.timeout = 15 * 60.0 @@ -1617,9 +1572,7 @@ async def delete(): return msg - async def fetch_message( - self, id: int, *, thread_id: Optional[int] = None - ) -> WebhookMessage: + async def fetch_message(self, id: int, *, thread_id: Optional[int] = None) -> WebhookMessage: """|coro| Retrieves a single :class:`~discord.WebhookMessage` owned by this webhook. @@ -1651,9 +1604,7 @@ async def fetch_message( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter = async_context.get() data = await adapter.get_webhook_message( @@ -1751,21 +1702,15 @@ async def edit_message( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") if view is not MISSING: if isinstance(self._state, _WebhookState): - raise InvalidArgument( - "This webhook does not have state associated with it" - ) + raise InvalidArgument("This webhook does not have state associated with it") self._state.prevent_view_updates_for(message_id) - previous_mentions: Optional[AllowedMentions] = getattr( - self._state, "allowed_mentions", None - ) + previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None) params = handle_message_parameters( content=content, file=file, @@ -1799,9 +1744,7 @@ async def edit_message( self._state.store_view(view, message_id) return message - async def delete_message( - self, message_id: int, *, thread_id: Optional[int] = None - ) -> None: + async def delete_message(self, message_id: int, *, thread_id: Optional[int] = None) -> None: """|coro| Deletes a message owned by this webhook. @@ -1826,9 +1769,7 @@ async def delete_message( Deleted a message that is not yours. """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter = async_context.get() await adapter.delete_webhook_message( diff --git a/discord/webhook/sync.py b/discord/webhook/sync.py index d158141869..6561c4dd4d 100644 --- a/discord/webhook/sync.py +++ b/discord/webhook/sync.py @@ -189,10 +189,7 @@ def request( response.status = response.status_code # type: ignore data = response.text or None - if ( - data - and response.headers["Content-Type"] == "application/json" - ): + if data and response.headers["Content-Type"] == "application/json": data = json.loads(data) remaining = response.headers.get("X-Ratelimit-Remaining") @@ -282,9 +279,7 @@ def edit_webhook( reason: Optional[str] = None, ): route = Route("PATCH", "/webhooks/{webhook_id}", webhook_id=webhook_id) - return self.request( - route, session, reason=reason, payload=payload, auth_token=token - ) + return self.request(route, session, reason=reason, payload=payload, auth_token=token) def edit_webhook_with_token( self, @@ -682,9 +677,7 @@ def partial( return cls(data, session, token=bot_token) @classmethod - def from_url( - cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None - ) -> SyncWebhook: + def from_url(cls, url: str, *, session: Session = MISSING, bot_token: Optional[str] = None) -> SyncWebhook: """Creates a partial :class:`Webhook` from a webhook URL. Parameters @@ -764,13 +757,9 @@ def fetch(self, *, prefer_auth: bool = True) -> SyncWebhook: if prefer_auth and self.auth_token: data = adapter.fetch_webhook(self.id, self.auth_token, session=self.session) elif self.token: - data = adapter.fetch_webhook_with_token( - self.id, self.token, session=self.session - ) + data = adapter.fetch_webhook_with_token(self.id, self.token, session=self.session) else: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") return SyncWebhook(data, self.session, token=self.auth_token, state=self._state) @@ -799,20 +788,14 @@ def delete(self, *, reason: Optional[str] = None, prefer_auth: bool = True) -> N This webhook does not have a token associated with it. """ if self.token is None and self.auth_token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter: WebhookAdapter = _get_webhook_adapter() if prefer_auth and self.auth_token: - adapter.delete_webhook( - self.id, token=self.auth_token, session=self.session, reason=reason - ) + adapter.delete_webhook(self.id, token=self.auth_token, session=self.session, reason=reason) elif self.token: - adapter.delete_webhook_with_token( - self.id, self.token, session=self.session, reason=reason - ) + adapter.delete_webhook_with_token(self.id, self.token, session=self.session, reason=reason) def edit( self, @@ -857,18 +840,14 @@ def edit( The newly edited webhook. """ if self.token is None and self.auth_token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") payload = {} if name is not MISSING: payload["name"] = str(name) if name is not None else None if avatar is not MISSING: - payload["avatar"] = ( - utils._bytes_to_base64_data(avatar) if avatar is not None else None - ) + payload["avatar"] = utils._bytes_to_base64_data(avatar) if avatar is not None else None adapter: WebhookAdapter = _get_webhook_adapter() @@ -907,9 +886,7 @@ def edit( if data is None: raise RuntimeError("Unreachable code hit: data was not assigned") - return SyncWebhook( - data=data, session=self.session, token=self.auth_token, state=self._state - ) + return SyncWebhook(data=data, session=self.session, token=self.auth_token, state=self._state) def _create_message(self, data): state = _WebhookState(self, parent=self._state) @@ -1039,13 +1016,9 @@ def send( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") - previous_mentions: Optional[AllowedMentions] = getattr( - self._state, "allowed_mentions", None - ) + previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None) if content is None: content = MISSING @@ -1079,9 +1052,7 @@ def send( if wait: return self._create_message(data) - def fetch_message( - self, id: int, *, thread_id: Optional[int] = None - ) -> SyncWebhookMessage: + def fetch_message(self, id: int, *, thread_id: Optional[int] = None) -> SyncWebhookMessage: """Retrieves a single :class:`~discord.SyncWebhookMessage` owned by this webhook. .. versionadded:: 2.0 @@ -1111,9 +1082,7 @@ def fetch_message( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter: WebhookAdapter = _get_webhook_adapter() data = adapter.get_webhook_message( @@ -1185,13 +1154,9 @@ def edit_message( """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") - previous_mentions: Optional[AllowedMentions] = getattr( - self._state, "allowed_mentions", None - ) + previous_mentions: Optional[AllowedMentions] = getattr(self._state, "allowed_mentions", None) params = handle_message_parameters( content=content, file=file, @@ -1219,9 +1184,7 @@ def edit_message( ) return self._create_message(data) - def delete_message( - self, message_id: int, *, thread_id: Optional[int] = None - ) -> None: + def delete_message(self, message_id: int, *, thread_id: Optional[int] = None) -> None: """Deletes a message owned by this webhook. This is a lower level interface to :meth:`WebhookMessage.delete` in case @@ -1244,9 +1207,7 @@ def delete_message( Deleted a message that is not yours. """ if self.token is None: - raise InvalidArgument( - "This webhook does not have a token associated with it" - ) + raise InvalidArgument("This webhook does not have a token associated with it") adapter: WebhookAdapter = _get_webhook_adapter() adapter.delete_webhook_message( diff --git a/discord/welcome_screen.py b/discord/welcome_screen.py index fe87e657dc..a428471094 100644 --- a/discord/welcome_screen.py +++ b/discord/welcome_screen.py @@ -95,9 +95,7 @@ def to_dict(self) -> WelcomeScreenChannelPayload: return dict_ @classmethod - def _from_dict( - cls, data: WelcomeScreenChannelPayload, guild: Guild - ) -> WelcomeScreenChannel: + def _from_dict(cls, data: WelcomeScreenChannelPayload, guild: Guild) -> WelcomeScreenChannel: channel_id = _get_as_snowflake(data, "channel_id") channel = guild.get_channel(channel_id) description = data.get("description") @@ -132,8 +130,7 @@ def __repr__(self): def _update(self, data: WelcomeScreenPayload): self.description: str = data.get("description") self.welcome_channels: List[WelcomeScreenChannel] = [ - WelcomeScreenChannel._from_dict(channel, self._guild) - for channel in data.get("welcome_channels", []) + WelcomeScreenChannel._from_dict(channel, self._guild) for channel in data.get("welcome_channels", []) ] @property @@ -216,9 +213,7 @@ async def edit(self, **options): for channel in welcome_channels: if not isinstance(channel, WelcomeScreenChannel): - raise TypeError( - "welcome_channels parameter must be a list of WelcomeScreenChannel." - ) + raise TypeError("welcome_channels parameter must be a list of WelcomeScreenChannel.") welcome_channels_data.append(channel.to_dict()) diff --git a/discord/widget.py b/discord/widget.py index d9a8edc5dc..d10d666843 100644 --- a/discord/widget.py +++ b/discord/widget.py @@ -179,12 +179,8 @@ def __init__( super().__init__(state=state, data=data) self.nick: Optional[str] = data.get("nick") self.status: Status = try_enum(Status, data.get("status")) - self.deafened: Optional[bool] = data.get("deaf", False) or data.get( - "self_deaf", False - ) - self.muted: Optional[bool] = data.get("mute", False) or data.get( - "self_mute", False - ) + self.deafened: Optional[bool] = data.get("deaf", False) or data.get("self_deaf", False) + self.muted: Optional[bool] = data.get("mute", False) or data.get("self_mute", False) self.suppress: Optional[bool] = data.get("suppress", False) try: @@ -259,11 +255,7 @@ def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: self.channels: List[WidgetChannel] = [] for channel in data.get("channels", []): _id = int(channel["id"]) - self.channels.append( - WidgetChannel( - id=_id, name=channel["name"], position=channel["position"] - ) - ) + self.channels.append(WidgetChannel(id=_id, name=channel["name"], position=channel["position"])) self.members: List[WidgetMember] = [] channels = {channel.id: channel for channel in self.channels} @@ -272,9 +264,7 @@ def __init__(self, *, state: ConnectionState, data: WidgetPayload) -> None: if connected_channel in channels: connected_channel = channels[connected_channel] # type: ignore elif connected_channel: - connected_channel = WidgetChannel( - id=connected_channel, name="", position=0 - ) + connected_channel = WidgetChannel(id=connected_channel, name="", position=0) self.members.append(WidgetMember(state=self._state, data=member, connected_channel=connected_channel)) # type: ignore @@ -287,9 +277,7 @@ def __eq__(self, other: Any) -> bool: return False def __repr__(self) -> str: - return ( - f"" - ) + return f"" @property def created_at(self) -> datetime.datetime: diff --git a/examples/app_commands/info.py b/examples/app_commands/info.py index 36b543f885..415d836941 100644 --- a/examples/app_commands/info.py +++ b/examples/app_commands/info.py @@ -18,9 +18,7 @@ @bot.slash_command(name="userinfo", description="gets the info of a user") async def info(ctx, user: discord.Member = None): - user = ( - user or ctx.author - ) # if no user is provided it'll use the the author of the message + user = user or ctx.author # if no user is provided it'll use the the author of the message e = discord.Embed() e.set_author(name=user.name) e.add_field(name="ID", value=user.id, inline=False) # user ID diff --git a/examples/app_commands/slash_autocomplete.py b/examples/app_commands/slash_autocomplete.py index d5ba3ff65b..1751434aae 100644 --- a/examples/app_commands/slash_autocomplete.py +++ b/examples/app_commands/slash_autocomplete.py @@ -86,9 +86,7 @@ "yellowgreen", ] -BASIC_ALLOWED = [ - ... -] # this would normally be a list of discord user IDs for the purpose of this example +BASIC_ALLOWED = [...] # this would normally be a list of discord user IDs for the purpose of this example async def color_searcher(ctx: discord.AutocompleteContext): @@ -96,9 +94,7 @@ async def color_searcher(ctx: discord.AutocompleteContext): In this example, we've added logic to only display any results in the returned list if the user's ID exists in the BASIC_ALLOWED list. This is to demonstrate passing a callback in the discord.utils.basic_autocomplete function. """ - return [ - color for color in LOTS_OF_COLORS if ctx.interaction.user.id in BASIC_ALLOWED - ] + return [color for color in LOTS_OF_COLORS if ctx.interaction.user.id in BASIC_ALLOWED] async def get_colors(ctx: discord.AutocompleteContext): @@ -120,9 +116,7 @@ async def get_animals(ctx: discord.AutocompleteContext): elif picked_color == "blue": return ["blue jay", "blue whale"] elif picked_color == "indigo": - return [ - "eastern indigo snake" - ] # needs to return an iterable even if only one item + return ["eastern indigo snake"] # needs to return an iterable even if only one item elif picked_color == "violet": return ["purple emperor butterfly", "orchid dottyback"] else: @@ -136,9 +130,7 @@ async def autocomplete_example( animal: Option(str, "Pick an animal!", autocomplete=get_animals), ): """This demonstrates using the ctx.options parameter to create slash command options that are dependent on the values entered for other options.""" - await ctx.respond( - f"You picked {color} for the color, which allowed you to choose {animal} for the animal." - ) + await ctx.respond(f"You picked {color} for the color, which allowed you to choose {animal} for the animal.") @bot.slash_command(name="ac_basic_example") @@ -152,9 +144,7 @@ async def autocomplete_basic_example( animal: Option( str, "Pick an animal from this small list", - autocomplete=discord.utils.basic_autocomplete( - ["snail", "python", "cricket", "orca"] - ), + autocomplete=discord.utils.basic_autocomplete(["snail", "python", "cricket", "orca"]), ), # Demonstrates passing a static iterable discord.utils.basic_autocomplete ): """This demonstrates using the discord.utils.basic_autocomplete helper function. diff --git a/examples/app_commands/slash_basic.py b/examples/app_commands/slash_basic.py index 17b0b512db..1733c938de 100644 --- a/examples/app_commands/slash_basic.py +++ b/examples/app_commands/slash_basic.py @@ -27,13 +27,9 @@ async def global_command(ctx, num: int): # Takes one integer parameter @bot.slash_command(guild_ids=[...]) -async def joined( - ctx, member: discord.Member = None -): # Passing a default value makes the argument optional +async def joined(ctx, member: discord.Member = None): # Passing a default value makes the argument optional user = member or ctx.author - await ctx.respond( - f"{user.name} joined at {discord.utils.format_dt(user.joined_at)}" - ) + await ctx.respond(f"{user.name} joined at {discord.utils.format_dt(user.joined_at)}") # To learn how to add descriptions and choices to options, check slash_options.py diff --git a/examples/app_commands/slash_cog_groups.py b/examples/app_commands/slash_cog_groups.py index f7bab83201..9312b009a8 100644 --- a/examples/app_commands/slash_cog_groups.py +++ b/examples/app_commands/slash_cog_groups.py @@ -11,16 +11,12 @@ def __init__(self, bot): greetings = SlashCommandGroup("greetings", "Various greeting from cogs!") - international_greetings = greetings.create_subgroup( - "international", "International greetings" - ) + international_greetings = greetings.create_subgroup("international", "International greetings") secret_greetings = SlashCommandGroup( "secret_greetings", "Secret greetings", - permissions=[ - CommandPermission("owner", 2, True) - ], # Ensures the owner_id user can access this, and no one else + permissions=[CommandPermission("owner", 2, True)], # Ensures the owner_id user can access this, and no one else ) @greetings.command() diff --git a/examples/app_commands/slash_groups.py b/examples/app_commands/slash_groups.py index ad3a70ba50..5b2b5adc3f 100644 --- a/examples/app_commands/slash_groups.py +++ b/examples/app_commands/slash_groups.py @@ -5,9 +5,7 @@ # If you use commands.Bot, @bot.slash_command should be used for # slash commands. You can use @bot.slash_command with discord.Bot as well -math = bot.create_group( - "math", "Commands related to mathematics." -) # create a slash command group +math = bot.create_group("math", "Commands related to mathematics.") # create a slash command group @math.command(guild_ids=[...]) # create a slash command diff --git a/examples/app_commands/slash_options.py b/examples/app_commands/slash_options.py index 0d77bccf12..43dba427a9 100644 --- a/examples/app_commands/slash_options.py +++ b/examples/app_commands/slash_options.py @@ -16,9 +16,7 @@ async def hello( # you also can create optional argument using: # age: Option(int, "Enter your age") = 18 ): - await ctx.respond( - f"Hello {name}! Your gender is {gender} and you are {age} years old." - ) + await ctx.respond(f"Hello {name}! Your gender is {gender} and you are {age} years old.") @bot.slash_command(guild_ids=[...]) diff --git a/examples/audio_recording.py b/examples/audio_recording.py index 27e65df0d3..cce4056af1 100644 --- a/examples/audio_recording.py +++ b/examples/audio_recording.py @@ -67,13 +67,8 @@ async def start( async def finished_callback(sink, channel: discord.TextChannel, *args): recorded_users = [f"<@{user_id}>" for user_id, audio in sink.audio_data.items()] await sink.vc.disconnect() - files = [ - discord.File(audio.file, f"{user_id}.{sink.encoding}") - for user_id, audio in sink.audio_data.items() - ] - await channel.send( - f"Finished! Recorded audio for {', '.join(recorded_users)}.", files=files - ) + files = [discord.File(audio.file, f"{user_id}.{sink.encoding}") for user_id, audio in sink.audio_data.items()] + await channel.send(f"Finished! Recorded audio for {', '.join(recorded_users)}.", files=files) @bot.command() diff --git a/examples/basic_voice.py b/examples/basic_voice.py index 1cc5f9b863..70ffd30c30 100644 --- a/examples/basic_voice.py +++ b/examples/basic_voice.py @@ -40,9 +40,7 @@ def __init__(self, source, *, data, volume=0.5): @classmethod async def from_url(cls, url, *, loop=None, stream=False): loop = loop or asyncio.get_event_loop() - data = await loop.run_in_executor( - None, lambda: ytdl.extract_info(url, download=not stream) - ) + data = await loop.run_in_executor(None, lambda: ytdl.extract_info(url, download=not stream)) if "entries" in data: # Takes the first item from a playlist @@ -70,9 +68,7 @@ async def play(self, ctx, *, query): """Plays a file from the local filesystem""" source = discord.PCMVolumeTransformer(discord.FFmpegPCMAudio(query)) - ctx.voice_client.play( - source, after=lambda e: print(f"Player error: {e}") if e else None - ) + ctx.voice_client.play(source, after=lambda e: print(f"Player error: {e}") if e else None) await ctx.send(f"Now playing: {query}") @@ -82,9 +78,7 @@ async def yt(self, ctx, *, url): async with ctx.typing(): player = await YTDLSource.from_url(url, loop=self.bot.loop) - ctx.voice_client.play( - player, after=lambda e: print(f"Player error: {e}") if e else None - ) + ctx.voice_client.play(player, after=lambda e: print(f"Player error: {e}") if e else None) await ctx.send(f"Now playing: {player.title}") @@ -94,9 +88,7 @@ async def stream(self, ctx, *, url): async with ctx.typing(): player = await YTDLSource.from_url(url, loop=self.bot.loop, stream=True) - ctx.voice_client.play( - player, after=lambda e: print(f"Player error: {e}") if e else None - ) + ctx.voice_client.play(player, after=lambda e: print(f"Player error: {e}") if e else None) await ctx.send(f"Now playing: {player.title}") diff --git a/examples/converters.py b/examples/converters.py index 6abea1ee5e..fc9922a148 100644 --- a/examples/converters.py +++ b/examples/converters.py @@ -72,9 +72,7 @@ async def convert(self, ctx: commands.Context, argument: str): # If the value could not be converted we can raise an error # So our error handlers can deal with it in one place. # The error has to be CommandError derived, so BadArgument works fine here. - raise commands.BadArgument( - f'No Member or TextChannel could be converted from "{argument}"' - ) + raise commands.BadArgument(f'No Member or TextChannel could be converted from "{argument}"') @bot.command() @@ -88,9 +86,7 @@ async def notify(ctx: commands.Context, target: ChannelOrMemberConverter): @bot.command() -async def ignore( - ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel] -): +async def ignore(ctx: commands.Context, target: typing.Union[discord.Member, discord.TextChannel]): # This command signature utilises the `typing.Union` typehint. # The `commands` framework attempts a conversion of each type in this Union *in order*. # So, it will attempt to convert whatever is passed to `target` to a `discord.Member` instance. @@ -101,15 +97,9 @@ async def ignore( # To check the resulting type, `isinstance` is used if isinstance(target, discord.Member): - await ctx.send( - f"Member found: {target.mention}, adding them to the ignore list." - ) - elif isinstance( - target, discord.TextChannel - ): # This could be an `else` but for completeness' sake. - await ctx.send( - f"Channel found: {target.mention}, adding it to the ignore list." - ) + await ctx.send(f"Member found: {target.mention}, adding them to the ignore list.") + elif isinstance(target, discord.TextChannel): # This could be an `else` but for completeness' sake. + await ctx.send(f"Channel found: {target.mention}, adding it to the ignore list.") # Built-in type converters. diff --git a/examples/cooldown.py b/examples/cooldown.py index 6e94e903aa..7c2a0a1917 100644 --- a/examples/cooldown.py +++ b/examples/cooldown.py @@ -6,9 +6,7 @@ # an application command with cooldown @bot.slash_command() -@commands.cooldown( - 1, 5, commands.BucketType.user -) # the command can only be used once in 5 seconds +@commands.cooldown(1, 5, commands.BucketType.user) # the command can only be used once in 5 seconds async def slash(ctx): await ctx.respond("You can't use this command again in 5 seconds.") diff --git a/examples/custom_context.py b/examples/custom_context.py index e59f3503bd..5d628578a8 100644 --- a/examples/custom_context.py +++ b/examples/custom_context.py @@ -66,9 +66,7 @@ async def slash_guess(ctx, number: int): """Guess a random number from 1 to 6.""" value = random.randint(1, 6) if number == value: - await ctx.success( - "Congratulations! You guessed the number." - ) # use the new helper function + await ctx.success("Congratulations! You guessed the number.") # use the new helper function else: await ctx.respond("You are wrong! Try again.") diff --git a/examples/guessing_game.py b/examples/guessing_game.py index 4ff8d43a72..4953dc6c45 100644 --- a/examples/guessing_game.py +++ b/examples/guessing_game.py @@ -25,9 +25,7 @@ def is_correct(m): try: guess = await self.wait_for("message", check=is_correct, timeout=5.0) except asyncio.TimeoutError: - return await message.channel.send( - f"Sorry, you took too long it was {answer}." - ) + return await message.channel.send(f"Sorry, you took too long it was {answer}.") if int(guess.content) == answer: await message.channel.send("You are right!") diff --git a/examples/modal_dialogs.py b/examples/modal_dialogs.py index a64c284a0e..403e0750e8 100644 --- a/examples/modal_dialogs.py +++ b/examples/modal_dialogs.py @@ -69,12 +69,8 @@ async def button_callback(self, button, interaction): min_values=1, max_values=1, options=[ - discord.SelectOption( - label="First Modal", description="Shows the first modal" - ), - discord.SelectOption( - label="Second Modal", description="Shows the second modal" - ), + discord.SelectOption(label="First Modal", description="Shows the first modal"), + discord.SelectOption(label="Second Modal", description="Shows the second modal"), ], ) async def select_callback(self, select, interaction): diff --git a/examples/reaction_roles.py b/examples/reaction_roles.py index 93c2af80f4..1c4d5f672f 100644 --- a/examples/reaction_roles.py +++ b/examples/reaction_roles.py @@ -7,19 +7,11 @@ class MyClient(discord.Client): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) - self.role_message_id = ( - 0 # ID of the message that can be reacted to to add/remove a role. - ) + self.role_message_id = 0 # ID of the message that can be reacted to to add/remove a role. self.emoji_to_role = { - discord.PartialEmoji( - name="🔴" - ): 0, # ID of the role associated with unicode emoji '🔴'. - discord.PartialEmoji( - name="🟡" - ): 0, # ID of the role associated with unicode emoji '🟡'. - discord.PartialEmoji( - name="green", id=0 - ): 0, # ID of the role associated with a partial emoji's ID. + discord.PartialEmoji(name="🔴"): 0, # ID of the role associated with unicode emoji '🔴'. + discord.PartialEmoji(name="🟡"): 0, # ID of the role associated with unicode emoji '🟡'. + discord.PartialEmoji(name="green", id=0): 0, # ID of the role associated with a partial emoji's ID. } async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent): diff --git a/examples/secret.py b/examples/secret.py index c8b267f685..92b51379ed 100644 --- a/examples/secret.py +++ b/examples/secret.py @@ -3,9 +3,7 @@ import discord from discord.ext import commands -bot = commands.Bot( - command_prefix=commands.when_mentioned, description="Nothing to see here!" -) +bot = commands.Bot(command_prefix=commands.when_mentioned, description="Nothing to see here!") # the `hidden` keyword argument hides it from the help command. @bot.group(hidden=True) @@ -29,15 +27,11 @@ def create_overwrites(ctx, *objects): # A dict comprehension is being utilised here to set the same permission overwrites # For each `discord.Role` or `discord.Member`. - overwrites = { - obj: discord.PermissionOverwrite(view_channel=True) for obj in objects - } + overwrites = {obj: discord.PermissionOverwrite(view_channel=True) for obj in objects} # Prevents the default role (@everyone) from viewing the channel # if it isn't already allowed to view the channel. - overwrites.setdefault( - ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False) - ) + overwrites.setdefault(ctx.guild.default_role, discord.PermissionOverwrite(view_channel=False)) # Makes sure the client is always allowed to view the channel. overwrites[ctx.guild.me] = discord.PermissionOverwrite(view_channel=True) @@ -81,16 +75,12 @@ async def voice( overwrites = create_overwrites(ctx, *objects) - await ctx.guild.create_voice_channel( - name, overwrites=overwrites, reason="Very secret business." - ) + await ctx.guild.create_voice_channel(name, overwrites=overwrites, reason="Very secret business.") @secret.command() @commands.guild_only() -async def emoji( - ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role -): +async def emoji(ctx: commands.Context, emoji: discord.PartialEmoji, *roles: discord.Role): """This clones a specified emoji that only specified roles are allowed to use. """ @@ -100,9 +90,7 @@ async def emoji( # The key parameter here is `roles`, which controls # What roles are able to use the emoji. - await ctx.guild.create_custom_emoji( - name=emoji.name, image=emoji_bytes, roles=roles, reason="Very secret business." - ) + await ctx.guild.create_custom_emoji(name=emoji.name, image=emoji_bytes, roles=roles, reason="Very secret business.") bot.run("token") diff --git a/examples/views/button_roles.py b/examples/views/button_roles.py index 8591fd31e2..4d0bb49516 100644 --- a/examples/views/button_roles.py +++ b/examples/views/button_roles.py @@ -48,9 +48,7 @@ async def callback(self, interaction: discord.Interaction): if role not in user.roles: # Give the user the role if they don't already have it. await user.add_roles(role) - await interaction.response.send_message( - f"🎉 You have been given the role {role.mention}", ephemeral=True - ) + await interaction.response.send_message(f"🎉 You have been given the role {role.mention}", ephemeral=True) else: # Else, Take the role from the user await user.remove_roles(role) diff --git a/examples/views/confirm.py b/examples/views/confirm.py index ef996f0c0d..a85221b4fa 100644 --- a/examples/views/confirm.py +++ b/examples/views/confirm.py @@ -21,9 +21,7 @@ def __init__(self): # Stop the View from listening to more input. # We also send the user an ephemeral message that we're confirming their choice. @discord.ui.button(label="Confirm", style=discord.ButtonStyle.green) - async def confirm( - self, button: discord.ui.Button, interaction: discord.Interaction - ): + async def confirm(self, button: discord.ui.Button, interaction: discord.Interaction): await interaction.response.send_message("Confirming", ephemeral=True) self.value = True self.stop() diff --git a/examples/views/dropdown.py b/examples/views/dropdown.py index 7d43af76f1..e60e4a97cd 100644 --- a/examples/views/dropdown.py +++ b/examples/views/dropdown.py @@ -12,15 +12,9 @@ def __init__(self): # Set the options that will be presented inside the dropdown options = [ - discord.SelectOption( - label="Red", description="Your favourite colour is red", emoji="🟥" - ), - discord.SelectOption( - label="Green", description="Your favourite colour is green", emoji="🟩" - ), - discord.SelectOption( - label="Blue", description="Your favourite colour is blue", emoji="🟦" - ), + discord.SelectOption(label="Red", description="Your favourite colour is red", emoji="🟥"), + discord.SelectOption(label="Green", description="Your favourite colour is green", emoji="🟩"), + discord.SelectOption(label="Blue", description="Your favourite colour is blue", emoji="🟦"), ] # The placeholder is what will be shown when no option is chosen @@ -38,9 +32,7 @@ async def callback(self, interaction: discord.Interaction): # The user's favourite colour or choice. The self object refers to the # Select object, and the values attribute gets a list of the user's # selected options. We only want the first one. - await interaction.response.send_message( - f"Your favourite colour is {self.values[0]}" - ) + await interaction.response.send_message(f"Your favourite colour is {self.values[0]}") class DropdownView(discord.ui.View): diff --git a/examples/views/ephemeral.py b/examples/views/ephemeral.py index 6479d33cdf..6af6bafafa 100644 --- a/examples/views/ephemeral.py +++ b/examples/views/ephemeral.py @@ -35,13 +35,9 @@ class EphemeralCounter(discord.ui.View): # When this button is pressed, it will respond with a Counter view that will # give the button presser their own personal button they can press 5 times. @discord.ui.button(label="Click", style=discord.ButtonStyle.blurple) - async def receive( - self, button: discord.ui.Button, interaction: discord.Interaction - ): + async def receive(self, button: discord.ui.Button, interaction: discord.Interaction): # ephemeral=True makes the message hidden from everyone except the button presser - await interaction.response.send_message( - "Enjoy!", view=Counter(), ephemeral=True - ) + await interaction.response.send_message("Enjoy!", view=Counter(), ephemeral=True) bot = EphemeralCounterBot() diff --git a/examples/views/paginator.py b/examples/views/paginator.py index a01749055e..02fedaf1c8 100644 --- a/examples/views/paginator.py +++ b/examples/views/paginator.py @@ -24,15 +24,9 @@ def __init__(self, bot): discord.Embed(title="Page Seven, Embed 2"), ], ] - self.pages[3].set_image( - url="https://c.tenor.com/pPKOYQpTO8AAAAAM/monkey-developer.gif" - ) - self.pages[4].add_field( - name="Example Field", value="Example Value", inline=False - ) - self.pages[4].add_field( - name="Another Example Field", value="Another Example Value", inline=False - ) + self.pages[3].set_image(url="https://c.tenor.com/pPKOYQpTO8AAAAAM/monkey-developer.gif") + self.pages[4].add_field(name="Example Field", value="Example Value", inline=False) + self.pages[4].add_field(name="Another Example Field", value="Another Example Value", inline=False) self.more_pages = [ "Second Page One", @@ -69,17 +63,13 @@ async def pagetest_loop(self, ctx: discord.ApplicationContext): @pagetest.command(name="strings") async def pagetest_strings(self, ctx: discord.ApplicationContext): """Demonstrates passing a list of strings as pages.""" - paginator = pages.Paginator( - pages=["Page 1", "Page 2", "Page 3"], loop_pages=True - ) + paginator = pages.Paginator(pages=["Page 1", "Page 2", "Page 3"], loop_pages=True) await paginator.respond(ctx.interaction, ephemeral=False) @pagetest.command(name="timeout") async def pagetest_timeout(self, ctx: discord.ApplicationContext): """Demonstrates having the buttons be disabled when the paginator view times out.""" - paginator = pages.Paginator( - pages=self.get_pages(), disable_on_timeout=True, timeout=30 - ) + paginator = pages.Paginator(pages=self.get_pages(), disable_on_timeout=True, timeout=30) await paginator.respond(ctx.interaction, ephemeral=False) @pagetest.command(name="remove_buttons") @@ -94,13 +84,9 @@ async def pagetest_remove(self, ctx: discord.ApplicationContext): async def pagetest_init(self, ctx: discord.ApplicationContext): """Demonstrates how to pass a list of custom buttons when creating the Paginator instance.""" pagelist = [ - pages.PaginatorButton( - "first", label="<<-", style=discord.ButtonStyle.green - ), + pages.PaginatorButton("first", label="<<-", style=discord.ButtonStyle.green), pages.PaginatorButton("prev", label="<-", style=discord.ButtonStyle.green), - pages.PaginatorButton( - "page_indicator", style=discord.ButtonStyle.gray, disabled=True - ), + pages.PaginatorButton("page_indicator", style=discord.ButtonStyle.gray, disabled=True), pages.PaginatorButton("next", label="->", style=discord.ButtonStyle.green), pages.PaginatorButton("last", label="->>", style=discord.ButtonStyle.green), ] @@ -124,20 +110,10 @@ async def pagetest_custom_buttons(self, ctx: discord.ApplicationContext): show_disabled=False, ) paginator.add_button( - pages.PaginatorButton( - "prev", label="<", style=discord.ButtonStyle.green, loop_label="lst" - ) - ) - paginator.add_button( - pages.PaginatorButton( - "page_indicator", style=discord.ButtonStyle.gray, disabled=True - ) - ) - paginator.add_button( - pages.PaginatorButton( - "next", style=discord.ButtonStyle.green, loop_label="fst" - ) + pages.PaginatorButton("prev", label="<", style=discord.ButtonStyle.green, loop_label="lst") ) + paginator.add_button(pages.PaginatorButton("page_indicator", style=discord.ButtonStyle.gray, disabled=True)) + paginator.add_button(pages.PaginatorButton("next", style=discord.ButtonStyle.green, loop_label="fst")) await paginator.respond(ctx.interaction, ephemeral=False) @pagetest.command(name="emoji_buttons") @@ -146,9 +122,7 @@ async def pagetest_emoji_buttons(self, ctx: discord.ApplicationContext): page_buttons = [ pages.PaginatorButton("first", emoji="⏪", style=discord.ButtonStyle.green), pages.PaginatorButton("prev", emoji="⬅", style=discord.ButtonStyle.green), - pages.PaginatorButton( - "page_indicator", style=discord.ButtonStyle.gray, disabled=True - ), + pages.PaginatorButton("page_indicator", style=discord.ButtonStyle.gray, disabled=True), pages.PaginatorButton("next", emoji="➡", style=discord.ButtonStyle.green), pages.PaginatorButton("last", emoji="⏩", style=discord.ButtonStyle.green), ] @@ -212,13 +186,9 @@ async def pagetest_cancel(self, ctx: discord.ApplicationContext): async def pagetest_groups(self, ctx: discord.ApplicationContext): """Demonstrates using page groups to switch between different sets of pages.""" page_buttons = [ - pages.PaginatorButton( - "first", label="<<-", style=discord.ButtonStyle.green - ), + pages.PaginatorButton("first", label="<<-", style=discord.ButtonStyle.green), pages.PaginatorButton("prev", label="<-", style=discord.ButtonStyle.green), - pages.PaginatorButton( - "page_indicator", style=discord.ButtonStyle.gray, disabled=True - ), + pages.PaginatorButton("page_indicator", style=discord.ButtonStyle.gray, disabled=True), pages.PaginatorButton("next", label="->", style=discord.ButtonStyle.green), pages.PaginatorButton("last", label="->>", style=discord.ButtonStyle.green), ] @@ -276,17 +246,9 @@ async def pagetest_target(self, ctx: discord.ApplicationContext): async def pagetest_prefix(self, ctx: commands.Context): """Demonstrates using the paginator with a prefix-based command.""" paginator = pages.Paginator(pages=self.get_pages(), use_default_buttons=False) - paginator.add_button( - pages.PaginatorButton("prev", label="<", style=discord.ButtonStyle.green) - ) - paginator.add_button( - pages.PaginatorButton( - "page_indicator", style=discord.ButtonStyle.gray, disabled=True - ) - ) - paginator.add_button( - pages.PaginatorButton("next", style=discord.ButtonStyle.green) - ) + paginator.add_button(pages.PaginatorButton("prev", label="<", style=discord.ButtonStyle.green)) + paginator.add_button(pages.PaginatorButton("page_indicator", style=discord.ButtonStyle.gray, disabled=True)) + paginator.add_button(pages.PaginatorButton("next", style=discord.ButtonStyle.green)) await paginator.send(ctx) @commands.command() diff --git a/examples/views/persistent.py b/examples/views/persistent.py index 02e2371090..c6e03c3b2b 100644 --- a/examples/views/persistent.py +++ b/examples/views/persistent.py @@ -22,15 +22,11 @@ def __init__(self): async def green(self, button: discord.ui.Button, interaction: discord.Interaction): await interaction.response.send_message("This is green.", ephemeral=True) - @discord.ui.button( - label="Red", style=discord.ButtonStyle.red, custom_id="persistent_view:red" - ) + @discord.ui.button(label="Red", style=discord.ButtonStyle.red, custom_id="persistent_view:red") async def red(self, button: discord.ui.Button, interaction: discord.Interaction): await interaction.response.send_message("This is red.", ephemeral=True) - @discord.ui.button( - label="Grey", style=discord.ButtonStyle.grey, custom_id="persistent_view:grey" - ) + @discord.ui.button(label="Grey", style=discord.ButtonStyle.grey, custom_id="persistent_view:grey") async def grey(self, button: discord.ui.Button, interaction: discord.Interaction): await interaction.response.send_message("This is grey.", ephemeral=True)