Skip to content

Commit

Permalink
Merge pull request Pycord-Development#417 from GodusOV/master
Browse files Browse the repository at this point in the history
Adding file and files keyword arguments to discord.Message.edit
  • Loading branch information
BobDotCom authored Nov 11, 2021
2 parents c5eef49 + 2a322cb commit 4d8b3d9
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 2 deletions.
60 changes: 60 additions & 0 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@
sticker,
)
from .types.snowflake import Snowflake, SnowflakeList
from .types.message import Attachment

from types import TracebackType

Expand Down Expand Up @@ -548,6 +549,65 @@ def send_files(
stickers=stickers,
components=components,
)

def edit_multipart_helper(
self,
route: Route,
files: Sequence[File],
**payload,
) -> Response[message.Message]:
form = []

form.append({'name': 'payload_json', 'value': utils._to_json(payload)})
if len(files) == 1:
file = files[0]
form.append(
{
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)
else:
for index, file in enumerate(files):
form.append(
{
'name': f'file{index}',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream',
}
)

return self.request(route, form=form, files=files)

def edit_files(
self,
channel_id: Snowflake,
message_id: Snowflake,
files: Sequence[File],
**fields,
) -> Response[message.Message]:
r = Route('PATCH', f'/channels/{channel_id}/messages/{message_id}', channel_id=channel_id, message_id=message_id)
payload: Dict[str, Any] = {}
if 'attachments' in fields:
payload['attachments'] = fields['attachments']
if 'flags' in fields:
payload['flags'] = fields['flags']
if 'content' in fields:
payload['content'] = fields['content']
if 'embeds' in fields:
payload['embeds'] = fields['embeds']
if 'allowed_mentions' in fields:
payload['allowed_mentions'] = fields['allowed_mentions']
if 'components' in fields:
payload['components'] = fields['components']
return self.edit_multipart_helper(
r,
files=files,
**payload,
)

def delete_message(
self, channel_id: Snowflake, message_id: Snowflake, *, reason: Optional[str] = None
Expand Down
63 changes: 61 additions & 2 deletions discord/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -1156,6 +1156,22 @@ async def edit(
*,
content: Optional[str] = ...,
embed: Optional[Embed] = ...,
file: Optional[File] = ...,
attachments: List[Attachment] = ...,
suppress: bool = ...,
delete_after: Optional[float] = ...,
allowed_mentions: Optional[AllowedMentions] = ...,
view: Optional[View] = ...,
) -> Message:
...

@overload
async def edit(
self,
*,
content: Optional[str] = ...,
embed: Optional[Embed] = ...,
files: Optional[List[File]] = ...,
attachments: List[Attachment] = ...,
suppress: bool = ...,
delete_after: Optional[float] = ...,
Expand All @@ -1170,6 +1186,7 @@ async def edit(
*,
content: Optional[str] = ...,
embeds: List[Embed] = ...,
file: File = ...,
attachments: List[Attachment] = ...,
suppress: bool = ...,
delete_after: Optional[float] = ...,
Expand All @@ -1183,6 +1200,8 @@ async def edit(
content: Optional[str] = MISSING,
embed: Optional[Embed] = MISSING,
embeds: List[Embed] = MISSING,
file: Sequence[File] = MISSING,
files: List[Sequence[File]] = MISSING,
attachments: List[Attachment] = MISSING,
suppress: bool = MISSING,
delete_after: Optional[float] = None,
Expand Down Expand Up @@ -1211,6 +1230,10 @@ async def edit(
To remove all embeds ``[]`` should be passed.
.. versionadded:: 2.0
file: Sequence[:class:`File`]
A new file to add to the message.
files: List[Sequence[:class:`File`]]
New files to add to the message.
attachments: List[:class:`Attachment`]
A list of attachments to keep in the message. If ``[]`` is passed
then all attachments are removed.
Expand Down Expand Up @@ -1244,7 +1267,9 @@ async def edit(
Tried to suppress a message without permissions or
edited a message's content or embed that isn't yours.
~discord.InvalidArgument
You specified both ``embed`` and ``embeds``
You specified both ``embed`` and ``embeds``,
specified both ``file`` and ``files``, or either``file``
or ``files`` were of the wrong type.
"""

payload: Dict[str, Any] = {}
Expand Down Expand Up @@ -1289,8 +1314,42 @@ async def edit(
payload['components'] = view.to_components()
else:
payload['components'] = []

if file is not MISSING and files is not MISSING:
raise InvalidArgument('cannot pass both file and files parameter to edit()')

if file is not MISSING:
if not isinstance(file, File):
raise InvalidArgument('file parameter must be File')

data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
try:
data = await self._state.http.edit_files(
self.channel.id,
self.id,
files=[file],
**payload,
)
finally:
file.close()

elif files is not MISSING:
if len(files) > 10:
raise InvalidArgument('files parameter must be a list of up to 10 elements')
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument('files parameter must be a list of File')

try:
data = await self._state.http.edit_files(
self.channel.id,
self.id,
files=files,
**payload,
)
finally:
for f in files:
f.close()
else:
data = await self._state.http.edit_message(self.channel.id, self.id, **payload)
message = Message(state=self._state, channel=self.channel, data=data)

if view and not view.is_finished():
Expand Down

0 comments on commit 4d8b3d9

Please sign in to comment.