Skip to content

Commit

Permalink
Recreate aiohttp.FormData objects during request retries
Browse files Browse the repository at this point in the history
  • Loading branch information
Rapptz committed Mar 24, 2021
1 parent 09e2e39 commit aae6f49
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
31 changes: 24 additions & 7 deletions discord/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ async def ws_connect(self, url, *, compress=0):

return await self.__session.ws_connect(url, **kwargs)

async def request(self, route, *, files=None, **kwargs):
async def request(self, route, *, files=None, form=None, **kwargs):
bucket = route.bucket
method = route.method
url = route.url
Expand Down Expand Up @@ -181,6 +181,13 @@ async def request(self, route, *, files=None, **kwargs):
if files:
for f in files:
f.reset(seek=tries)

if form:
form_data = aiohttp.FormData()
for params in form:
form_data.add_field(**params)
kwargs['data'] = form_data

try:
async with self.__session.request(method, url, **kwargs) as r:
log.debug('%s %s with %s has returned %s', method, url, kwargs.get('data'), r.status)
Expand Down Expand Up @@ -371,7 +378,7 @@ def send_typing(self, channel_id):

def send_files(self, channel_id, *, files, content=None, tts=False, embed=None, nonce=None, allowed_mentions=None, message_reference=None):
r = Route('POST', '/channels/{channel_id}/messages', channel_id=channel_id)
form = aiohttp.FormData()
form = []

payload = {'tts': tts}
if content:
Expand All @@ -385,15 +392,25 @@ def send_files(self, channel_id, *, files, content=None, tts=False, embed=None,
if message_reference:
payload['message_reference'] = message_reference

form.add_field('payload_json', utils.to_json(payload))
form.append({'name': 'payload_json', 'value': utils.to_json(payload)})
if len(files) == 1:
file = files[0]
form.add_field('file', file.fp, filename=file.filename, content_type='application/octet-stream')
form.append({
'name': 'file',
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream'
})
else:
for index, file in enumerate(files):
form.add_field('file%s' % index, file.fp, filename=file.filename, content_type='application/octet-stream')

return self.request(r, data=form, files=files)
form.append({
'name': 'file%s' % index,
'value': file.fp,
'filename': file.filename,
'content_type': 'application/octet-stream'
})

return self.request(r, form=form, files=files)

async def ack_message(self, channel_id, message_id):
r = Route('POST', '/channels/{channel_id}/messages/{message_id}/ack', channel_id=channel_id, message_id=message_id)
Expand Down
15 changes: 8 additions & 7 deletions discord/webhook.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,20 +203,21 @@ async def request(self, verb, url, payload=None, multipart=None, *, files=None,
if reason:
headers['X-Audit-Log-Reason'] = _uriquote(reason, safe='/ ')

if multipart:
data = aiohttp.FormData()
for key, value in multipart.items():
if key.startswith('file'):
data.add_field(key, value[1], filename=value[0], content_type=value[2])
else:
data.add_field(key, value)

base_url = url.replace(self._request_url, '/') or '/'
_id = self._webhook_id
for tries in range(5):
for file in files:
file.reset(seek=tries)

if multipart:
data = aiohttp.FormData()
for key, value in multipart.items():
if key.startswith('file'):
data.add_field(key, value[1], filename=value[0], content_type=value[2])
else:
data.add_field(key, value)

async with self.session.request(verb, url, headers=headers, data=data) as r:
log.debug('Webhook ID %s with %s %s has returned status code %s', _id, verb, base_url, r.status)
# Coerce empty strings to return None for hygiene purposes
Expand Down

0 comments on commit aae6f49

Please sign in to comment.