Skip to content

Commit

Permalink
Type up **kwargs of various methods
Browse files Browse the repository at this point in the history
  • Loading branch information
NCPlayz authored May 11, 2021
1 parent 8bc489d commit 757cfad
Show file tree
Hide file tree
Showing 14 changed files with 453 additions and 72 deletions.
68 changes: 62 additions & 6 deletions discord/abc.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

import copy
import asyncio
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, TypeVar, Union, overload, runtime_checkable
from typing import Any, Dict, List, Mapping, Optional, TYPE_CHECKING, Protocol, Type, TypeVar, Union, overload, runtime_checkable

from .iterators import HistoryIterator
from .context_managers import Typing
Expand All @@ -49,6 +49,8 @@
'Connectable',
)

T = TypeVar('T', bound=VoiceProtocol)

if TYPE_CHECKING:
from datetime import datetime

Expand All @@ -58,7 +60,8 @@
from .guild import Guild
from .member import Member
from .channel import CategoryChannel

from .embeds import Embed
from .message import Message, MessageReference

MISSING = utils.MISSING

Expand Down Expand Up @@ -95,6 +98,7 @@ def created_at(self) -> datetime:
""":class:`datetime.datetime`: Returns the model's creation time as an aware datetime in UTC."""
raise NotImplementedError

SnowflakeTime = Union[Snowflake, datetime]

@runtime_checkable
class User(Snowflake, Protocol):
Expand Down Expand Up @@ -653,14 +657,34 @@ async def delete(self, *, reason: Optional[str] = None) -> None:
"""
await self._state.http.delete_channel(self.id, reason=reason)

@overload
async def set_permissions(
self,
target: Union[Member, Role],
*,
overwrite: Optional[PermissionOverwrite] = _undefined,
reason: Optional[str] = None,
overwrite: Optional[Union[PermissionOverwrite, _Undefined]] = ...,
reason: Optional[str] = ...,
) -> None:
...

@overload
async def set_permissions(
self,
target: Union[Member, Role],
*,
reason: Optional[str] = ...,
**permissions: bool,
) -> None:
...

async def set_permissions(
self,
target,
*,
overwrite=_undefined,
reason=None,
**permissions
):
r"""|coro|
Sets the channel specific permission overwrites for a target in the
Expand Down Expand Up @@ -815,7 +839,7 @@ async def move(
offset: int = MISSING,
category: Optional[Snowflake] = MISSING,
sync_permissions: bool = MISSING,
reason: str = MISSING,
reason: Optional[str] = MISSING,
) -> None:
...

Expand Down Expand Up @@ -1091,6 +1115,38 @@ class Messageable(Protocol):
async def _get_channel(self):
raise NotImplementedError

@overload
async def send(
self,
content: Optional[str] =...,
*,
tts: bool = ...,
embed: Embed = ...,
file: File = ...,
delete_after: int = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
) -> Message:
...

@overload
async def send(
self,
content: Optional[str] = ...,
*,
tts: bool = ...,
embed: Embed = ...,
files: List[File] = ...,
delete_after: int = ...,
nonce: Union[str, int] = ...,
allowed_mentions: AllowedMentions = ...,
reference: Union[Message, MessageReference] = ...,
mention_author: bool = ...,
) -> Message:
...

async def send(self, content=None, *, tts=False, embed=None, file=None,
files=None, delete_after=None, nonce=None,
allowed_mentions=None, reference=None,
Expand Down Expand Up @@ -1402,7 +1458,7 @@ def _get_voice_client_key(self):
def _get_voice_state_pair(self):
raise NotImplementedError

async def connect(self, *, timeout=60.0, reconnect=True, cls=VoiceClient):
async def connect(self, *, timeout: float = 60.0, reconnect: bool = True, cls: Type[T] = VoiceClient) -> T:
"""|coro|
Connects to voice and creates a :class:`VoiceClient` to establish
Expand Down
138 changes: 127 additions & 11 deletions discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,14 @@
DEALINGS IN THE SOFTWARE.
"""

from __future__ import annotations

import time
import asyncio
from typing import Callable, Dict, List, Optional, TYPE_CHECKING, Union, overload

import discord.abc
from .permissions import Permissions
from .permissions import PermissionOverwrite, Permissions
from .enums import ChannelType, try_enum, VoiceRegion, VideoQualityMode
from .mixins import Hashable
from . import utils
Expand All @@ -44,6 +47,14 @@
'_channel_factory',
)

if TYPE_CHECKING:
from .role import Role
from .member import Member
from .abc import Snowflake
from .message import Message
from .webhook import Webhook
from .abc import SnowflakeTime

async def _single_delete_strategy(messages):
for m in messages:
await m.delete()
Expand Down Expand Up @@ -190,6 +201,27 @@ def last_message(self):
"""
return self._state._get_message(self.last_message_id) if self.last_message_id else None

@overload
async def edit(
self,
*,
reason: Optional[str] = ...,
name: str = ...,
topic: str = ...,
position: int = ...,
nsfw: bool = ...,
sync_permissions: bool = ...,
category: Optional[CategoryChannel] = ...,
slowmode_delay: int = ...,
type: ChannelType = ...,
overwrites: Dict[Union[Role, Member, Snowflake], PermissionOverwrite] = ...,
) -> None:
...

@overload
async def edit(self) -> None:
...

async def edit(self, *, reason=None, **options):
"""|coro|
Expand Down Expand Up @@ -246,7 +278,7 @@ async def edit(self, *, reason=None, **options):
await self._edit(options, reason=reason)

@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name=None, reason=None):
async def clone(self, *, name: str = None, reason: str = None) -> TextChannel:
return await self._clone_impl({
'topic': self.topic,
'nsfw': self.nsfw,
Expand Down Expand Up @@ -302,7 +334,17 @@ async def delete_messages(self, messages):
message_ids = [m.id for m in messages]
await self._state.http.delete_messages(self.id, message_ids)

async def purge(self, *, limit=100, check=None, before=None, after=None, around=None, oldest_first=False, bulk=True):
async def purge(
self,
*,
limit: int = 100,
check: Callable[[Message], bool] = None,
before: Optional[SnowflakeTime] = None,
after: Optional[SnowflakeTime] = None,
around: Optional[SnowflakeTime] = None,
oldest_first: Optional[bool] = False,
bulk: bool = True,
) -> List[Message]:
"""|coro|
Purges a list of messages that meet the criteria given by the predicate
Expand Down Expand Up @@ -428,7 +470,7 @@ async def webhooks(self):
data = await self._state.http.channel_webhooks(self.id)
return [Webhook.from_state(d, state=self._state) for d in data]

async def create_webhook(self, *, name, avatar=None, reason=None):
async def create_webhook(self, *, name: str, avatar: bytes = None, reason: str = None) -> Webhook:
"""|coro|
Creates a webhook for this channel.
Expand Down Expand Up @@ -468,7 +510,7 @@ async def create_webhook(self, *, name, avatar=None, reason=None):
data = await self._state.http.create_webhook(self.id, name=str(name), avatar=avatar, reason=reason)
return Webhook.from_state(data, state=self._state)

async def follow(self, *, destination, reason=None):
async def follow(self, *, destination: TextChannel, reason: Optional[str] = None) -> Webhook:
"""
Follows a channel using a webhook.
Expand Down Expand Up @@ -680,12 +722,33 @@ def type(self):
return ChannelType.voice

@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name=None, reason=None):
async def clone(self, *, name: str = None, reason: str = None) -> VoiceChannel:
return await self._clone_impl({
'bitrate': self.bitrate,
'user_limit': self.user_limit
}, name=name, reason=reason)

@overload
async def edit(
self,
*,
reason: Optional[str] = ...,
name: str = ...,
bitrate: int = ...,
user_limit: int = ...,
position: int = ...,
sync_permissions: int = ...,
category: Optional[CategoryChannel] = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
) -> None:
...

@overload
async def edit(self) -> None:
...

async def edit(self, *, reason=None, **options):
"""|coro|
Expand Down Expand Up @@ -822,11 +885,31 @@ def type(self):
return ChannelType.stage_voice

@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name=None, reason=None):
async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StageChannel:
return await self._clone_impl({
'topic': self.topic,
}, name=name, reason=reason)

@overload
async def edit(
self,
*,
reason: Optional[str] = ...,
name: str = ...,
topic: Optional[str] = ...,
position: int = ...,
sync_permissions: int = ...,
category: Optional[CategoryChannel] = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
rtc_region: Optional[VoiceRegion] = ...,
video_quality_mode: VideoQualityMode = ...,
) -> None:
...

@overload
async def edit(self) -> None:
...

async def edit(self, *, reason=None, **options):
"""|coro|
Expand All @@ -839,7 +922,7 @@ async def edit(self, *, reason=None, **options):
----------
name: :class:`str`
The new channel's name.
topic: :class:`str`
topic: Optional[:class:`str`]
The new channel's topic.
position: :class:`int`
The new channel's position.
Expand Down Expand Up @@ -873,7 +956,6 @@ async def edit(self, *, reason=None, **options):
"""

await self._edit(options, reason=reason)

class CategoryChannel(discord.abc.GuildChannel, Hashable):
"""Represents a Discord channel category.
Expand Down Expand Up @@ -948,11 +1030,27 @@ def is_nsfw(self):
return self.nsfw or self.guild.nsfw

@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name=None, reason=None):
async def clone(self, *, name: str = None, reason: Optional[str] = None) -> CategoryChannel:
return await self._clone_impl({
'nsfw': self.nsfw
}, name=name, reason=reason)

@overload
async def edit(
self,
*,
reason: Optional[str] = ...,
name: str = ...,
position: int = ...,
nsfw: bool = ...,
overwrites: Dict[Union[Role, Member], PermissionOverwrite] = ...,
) -> None:
...

@overload
async def edit(self) -> None:
...

async def edit(self, *, reason=None, **options):
"""|coro|
Expand Down Expand Up @@ -1159,11 +1257,29 @@ def is_nsfw(self):
return self.nsfw or self.guild.nsfw

@utils.copy_doc(discord.abc.GuildChannel.clone)
async def clone(self, *, name=None, reason=None):
async def clone(self, *, name: str = None, reason: Optional[str] = None) -> StoreChannel:
return await self._clone_impl({
'nsfw': self.nsfw
}, name=name, reason=reason)

@overload
async def edit(
self,
*,
name: str = ...,
position: int = ...,
nsfw: bool = ...,
sync_permissions: bool = ...,
category: Optional[CategoryChannel],
reason: Optional[str],
overwrites: Dict[Union[Role, Member], PermissionOverwrite]
) -> None:
...

@overload
async def edit(self) -> None:
...

async def edit(self, *, reason=None, **options):
"""|coro|
Expand Down
Loading

0 comments on commit 757cfad

Please sign in to comment.