Skip to content

Commit

Permalink
fix: allow creating forum threads with files
Browse files Browse the repository at this point in the history
  • Loading branch information
elliotcubit committed May 14, 2023
1 parent 6a69f66 commit 2c828c4
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 38 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ These changes are available on the `master` branch, but have not yet been releas
([#2048](https://github.com/Pycord-Development/pycord/pull/2048))
- Fixed the Slash command syncronization method `indiviual`.
([#1925](https://github.com/Pycord-Development/pycord/pull/1925))
- Fixed `HttpException` when trying to create a Forum thread with files.
([#2075](https://github.com/Pycord-Development/pycord/pull/2075))

## [2.4.1] - 2023-03-20

Expand Down
49 changes: 13 additions & 36 deletions discord/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,53 +1274,25 @@ async def create_thread(
if file is not None and files is not None:
raise InvalidArgument("cannot pass both file and files parameter to send()")

if file is not None:
if not isinstance(file, File):
raise InvalidArgument("file parameter must be File")

try:
data = await state.http.send_files(
self.id,
files=[file],
allowed_mentions=allowed_mentions,
content=message_content,
embed=embed,
embeds=embeds,
nonce=nonce,
stickers=stickers,
components=components,
)
finally:
file.close()

elif files is not None:
if files is not None:
if len(files) > 10:
raise InvalidArgument(
"files parameter must be a list of up to 10 elements"
)
elif not all(isinstance(file, File) for file in files):
raise InvalidArgument("files parameter must be a list of File")

try:
data = await state.http.send_files(
self.id,
files=files,
content=message_content,
embed=embed,
embeds=embeds,
nonce=nonce,
allowed_mentions=allowed_mentions,
stickers=stickers,
components=components,
)
finally:
for f in files:
f.close()
else:
if file is not None:
if not isinstance(file, File):
raise InvalidArgument("file parameter must be File")
files = [file]

try:
data = await state.http.start_forum_thread(
self.id,
content=message_content,
name=name,
files=files,
embed=embed,
embeds=embeds,
nonce=nonce,
Expand All @@ -1333,6 +1305,11 @@ async def create_thread(
applied_tags=applied_tags,
reason=reason,
)
finally:
if files is not None:
for f in files:
f.close()

ret = Thread(guild=self.guild, state=self._state, data=data)
msg = ret.get_partial_message(data["last_message_id"])
if view:
Expand Down
29 changes: 27 additions & 2 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -1170,14 +1170,15 @@ def start_forum_thread(
invitable: bool = True,
applied_tags: SnowflakeList | None = None,
reason: str | None = None,
files: Sequence[File] | None = None,
embed: embed.Embed | None = None,
embeds: list[embed.Embed] | None = None,
nonce: str | None = None,
allowed_mentions: message.AllowedMentions | None = None,
stickers: list[sticker.StickerItem] | None = None,
components: list[components.Component] | None = None,
) -> Response[threads.Thread]:
payload = {
payload: dict[str, Any] = {
"name": name,
"auto_archive_duration": auto_archive_duration,
"invitable": invitable,
Expand Down Expand Up @@ -1208,13 +1209,37 @@ def start_forum_thread(

if rate_limit_per_user:
payload["rate_limit_per_user"] = rate_limit_per_user

form = [{"name": "payload_json"}]
if files:
attachments = []
for index, file in enumerate(files):
attachments.append(
{
"id": index,
"filename": file.filename,
"description": file.description,
}
)
form.append(
{
"name": f"files[{index}]",
"value": file.fp,
"filename": file.filename,
"content_type": "application/octet-stream",
}
)
payload["attachments"] = attachments

form[0]["value"] = utils._to_json(payload)

# TODO: Once supported by API, remove has_message=true query parameter
route = Route(
"POST",
"/channels/{channel_id}/threads?has_message=true",
channel_id=channel_id,
)
return self.request(route, json=payload, reason=reason)
return self.request(route, form=form, reason=reason)

def join_thread(self, channel_id: Snowflake) -> Response[None]:
return self.request(
Expand Down

0 comments on commit 2c828c4

Please sign in to comment.