Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve typings in multipart #3622

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGES/3621.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Improve typing annotations for multipart.

Use `async for` instead of a `while` loop for
reading the full multipart data in `web_request.Request#post`.
19 changes: 11 additions & 8 deletions aiohttp/multipart.py
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ class BodyPartReader:
def __init__(
self,
boundary: bytes,
headers: Mapping[str, Optional[str]],
headers: Mapping[str, str],
content: StreamReader,
*,
_newline: bytes = b'\r\n'
Expand Down Expand Up @@ -443,7 +443,7 @@ def decode(self, data: bytes) -> bytes:
return data

def _decode_content(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_ENCODING]).lower()
encoding = self.headers[CONTENT_ENCODING].lower()

if encoding == 'deflate':
return zlib.decompress(data, -zlib.MAX_WBITS)
Expand All @@ -455,7 +455,7 @@ def _decode_content(self, data: bytes) -> bytes:
raise RuntimeError('unknown content encoding: {}'.format(encoding))

def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_TRANSFER_ENCODING]).lower()
encoding = self.headers[CONTENT_TRANSFER_ENCODING].lower()

if encoding == 'base64':
return base64.b64decode(data)
Expand Down Expand Up @@ -547,7 +547,7 @@ def __init__(
def __aiter__(self) -> 'MultipartReader':
    """Return self so the reader can be iterated with ``async for``.

    Part of the asynchronous-iterator protocol; the companion
    ``__anext__`` yields successive body parts via ``next()``.
    """
    return self

async def __anext__(self) -> Any:
async def __anext__(self) -> Union['MultipartReader', BodyPartReader]:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
Expand All @@ -569,19 +569,19 @@ def at_eof(self) -> bool:
"""
return self._at_eof

async def next(self) -> Optional[Union['MultipartReader', BodyPartReader]]:
    """Emits the next multipart body part.

    Returns a reader for the next part — either a nested
    ``MultipartReader`` or a ``BodyPartReader``, dispatched by the
    part's ``Content-Type`` header (see ``_get_part_reader``) — or
    ``None`` once the final boundary has been read.
    """
    # So, if we're at BOF, we need to skip till the boundary.
    if self._at_eof:
        return None
    # NOTE(review): presumably finishes consuming the previous part,
    # if any, so the stream is positioned at the next boundary —
    # confirm against the helper's definition.
    await self._maybe_release_last_part()
    if self._at_bof:
        # First call: skip everything up to the opening boundary,
        # then clear the beginning-of-form flag.
        await self._read_until_first_boundary()
        self._at_bof = False
    else:
        await self._read_boundary()
    if self._at_eof:  # we just read the last boundary, nothing to do there
        return None
    # Cache the part so a later call can release it before advancing.
    self._last_part = await self.fetch_next_part()
    return self._last_part

Expand All @@ -598,7 +598,10 @@ async def fetch_next_part(self) -> Any:
headers = await self._read_headers()
return self._get_part_reader(headers)

def _get_part_reader(self, headers: 'CIMultiDictProxy[str]') -> Any:
def _get_part_reader(
self,
headers: 'CIMultiDictProxy[str]',
) -> Union['MultipartReader', BodyPartReader]:
"""Dispatches the response by the `Content-Type` header, returning
suitable reader instance.

Expand Down
74 changes: 39 additions & 35 deletions aiohttp/web_request.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
from .abc import AbstractStreamWriter
from .helpers import DEBUG, ChainMapProxy, HeadersMixin, reify, sentinel
from .http_parser import RawRequestMessage
from .multipart import MultipartReader
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
Expand Down Expand Up @@ -608,46 +608,50 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
multipart = await self.multipart()
max_size = self._client_max_size

field = await multipart.next()
while field is not None:
async for field in multipart:
size = 0
content_type = field.headers.get(hdrs.CONTENT_TYPE)

if field.filename:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
field_content_type = field.headers.get(hdrs.CONTENT_TYPE)

if isinstance(field, BodyPartReader):
if field.filename:
assert field_content_type is not None, \
'Cannot read file without knowing what it is'
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(
field.name,
field.filename,
cast(io.BufferedReader, tmp),
field_content_type,
CIMultiDictProxy(CIMultiDict(**field.headers)),
)
out.add(field.name, ff)
else:
value = await field.read(decode=True)
if content_type is None or \
content_type.startswith('text/'):
charset = field.get_charset(default='utf-8')
value = value.decode(charset)
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
content_type, field.headers)
out.add(field.name, ff)
else:
value = await field.read(decode=True)
if content_type is None or \
content_type.startswith('text/'):
charset = field.get_charset(default='utf-8')
value = value.decode(charset)
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)

field = await multipart.next()
else:
data = await self.read()
if data:
Expand Down