
Improve typings for multipart #3905

Merged
merged 3 commits on Jul 19, 2019
2 changes: 2 additions & 0 deletions CHANGES/3621.bugfix
@@ -0,0 +1,2 @@
Improve typing annotations for multipart.py along with changes required
by mypy in files that reference multipart.py.
79 changes: 53 additions & 26 deletions aiohttp/multipart.py
@@ -19,7 +19,6 @@
Tuple,
Type,
Union,
cast,
)
from urllib.parse import parse_qsl, unquote, urlencode

@@ -195,21 +194,26 @@ def content_disposition_filename(params: Mapping[str, str],


class MultipartResponseWrapper:
"""Wrapper around the MultipartBodyReader.
"""Wrapper around the MultipartReader.
Contributor Author:
There is no such thing as MultipartBodyReader, so I decided to go with MultipartReader as the only thing that matched the interface. I will give it another look, but for now that was what matched.
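For context, a minimal sketch (mine, not code from this PR) of how the wrapper is typically obtained on the client side, assuming standard aiohttp client usage; under the new annotations, iterating it yields MultipartReader or BodyPartReader parts:

import aiohttp
from aiohttp import BodyPartReader, MultipartReader


async def read_multipart(url: str) -> None:
    async with aiohttp.ClientSession() as session:
        async with session.get(url) as resp:
            # from_response() wraps a MultipartReader in a MultipartResponseWrapper,
            # which is why `stream` can be annotated as 'MultipartReader' here.
            wrapper = MultipartReader.from_response(resp)
            async for part in wrapper:
                if isinstance(part, BodyPartReader):
                    data = await part.read(decode=True)  # bytes under the new annotations
                    print(part.name, len(data))
                # any other part is a nested MultipartReader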


It takes care about
underlying connection and close it when it needs in.
"""

def __init__(self, resp: 'ClientResponse', stream: Any) -> None:
# TODO: add strong annotation to stream
def __init__(
self,
resp: 'ClientResponse',
stream: 'MultipartReader',
) -> None:
self.resp = resp
self.stream = stream

def __aiter__(self) -> 'MultipartResponseWrapper':
return self

async def __anext__(self) -> Any:
async def __anext__(
self,
) -> Union['MultipartReader', 'BodyPartReader']:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
@@ -219,7 +223,9 @@ def at_eof(self) -> bool:
"""Returns True when all response data had been read."""
return self.resp.content.at_eof()

async def next(self) -> Any:
async def next(
self,
) -> Optional[Union['MultipartReader', 'BodyPartReader']]:
"""Emits next multipart reader object."""
item = await self.stream.next()
if self.stream.at_eof():
@@ -240,7 +246,7 @@ class BodyPartReader:
def __init__(
self,
boundary: bytes,
headers: Mapping[str, Optional[str]],
headers: 'CIMultiDictProxy[str]',
content: StreamReader,
*,
_newline: bytes = b'\r\n'
@@ -262,19 +268,19 @@ def __init__(
def __aiter__(self) -> 'BodyPartReader':
return self

async def __anext__(self) -> Any:
async def __anext__(self) -> bytes:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

async def next(self) -> Any:
async def next(self) -> Optional[bytes]:
item = await self.read()
if not item:
return None
return item

async def read(self, *, decode: bool=False) -> Any:
async def read(self, *, decode: bool=False) -> bytes:
"""Reads body part data.

decode: Decodes data following by encoding
@@ -429,7 +435,11 @@ async def text(self, *, encoding: Optional[str]=None) -> str:
encoding = encoding or self.get_charset(default='utf-8')
return data.decode(encoding)

async def json(self, *, encoding: Optional[str]=None) -> Any:
async def json(
self,
*,
encoding: Optional[str]=None,
) -> Optional[Dict[str, Any]]:
"""Like read(), but assumes that body parts contains JSON data."""
data = await self.read(decode=True)
if not data:
@@ -468,7 +478,7 @@ def decode(self, data: bytes) -> bytes:
return data

def _decode_content(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_ENCODING]).lower()
encoding = self.headers.get(CONTENT_ENCODING, '').lower()

if encoding == 'deflate':
return zlib.decompress(data, -zlib.MAX_WBITS)
@@ -480,7 +490,7 @@ def _decode_content(self, data: bytes) -> bytes:
raise RuntimeError('unknown content encoding: {}'.format(encoding))

def _decode_content_transfer(self, data: bytes) -> bytes:
encoding = cast(str, self.headers[CONTENT_TRANSFER_ENCODING]).lower()
encoding = self.headers.get(CONTENT_TRANSFER_ENCODING, '').lower()

if encoding == 'base64':
return base64.b64decode(data)
@@ -564,22 +574,27 @@ def __init__(
self._boundary = ('--' + self._get_boundary()).encode()
self._newline = _newline
self._content = content
self._last_part = None
self._last_part = None # type: Optional[Union['MultipartReader', BodyPartReader]] # noqa
self._at_eof = False
self._at_bof = True
self._unread = [] # type: List[bytes]

def __aiter__(self) -> 'MultipartReader':
return self

async def __anext__(self) -> Any:
async def __anext__(
self,
) -> Union['MultipartReader', BodyPartReader]:
part = await self.next()
if part is None:
raise StopAsyncIteration # NOQA
return part

@classmethod
def from_response(cls, response: 'ClientResponse') -> Any:
def from_response(
cls,
response: 'ClientResponse',
) -> MultipartResponseWrapper:
"""Constructs reader instance from HTTP response.

:param response: :class:`~aiohttp.client.ClientResponse` instance
@@ -594,19 +609,21 @@ def at_eof(self) -> bool:
"""
return self._at_eof

async def next(self) -> Any:
async def next(
self,
) -> Optional[Union['MultipartReader', BodyPartReader]]:
"""Emits the next multipart body part."""
# So, if we're at BOF, we need to skip till the boundary.
if self._at_eof:
return
return None
await self._maybe_release_last_part()
if self._at_bof:
await self._read_until_first_boundary()
self._at_bof = False
else:
await self._read_boundary()
if self._at_eof: # we just read the last boundary, nothing to do there
return
return None
self._last_part = await self.fetch_next_part()
return self._last_part

@@ -618,12 +635,17 @@ async def release(self) -> None:
break
await item.release()

async def fetch_next_part(self) -> Any:
async def fetch_next_part(
self,
) -> Union['MultipartReader', BodyPartReader]:
"""Returns the next body part reader."""
headers = await self._read_headers()
return self._get_part_reader(headers)

def _get_part_reader(self, headers: 'CIMultiDictProxy[str]') -> Any:
def _get_part_reader(
self,
headers: 'CIMultiDictProxy[str]',
) -> Union['MultipartReader', BodyPartReader]:
"""Dispatches the response by the `Content-Type` header, returning
suitable reader instance.

@@ -822,7 +844,7 @@ def boundary(self) -> str:
def append(
self,
obj: Any,
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
if headers is None:
headers = CIMultiDict()
@@ -841,15 +863,20 @@ def append(
def append_payload(self, payload: Payload) -> Payload:
"""Adds a new body part to multipart writer."""
# compression
encoding = payload.headers.get(CONTENT_ENCODING, '').lower() # type: Optional[str] # noqa
encoding = payload.headers.get(
CONTENT_ENCODING,
'',
).lower() # type: Optional[str]
if encoding and encoding not in ('deflate', 'gzip', 'identity'):
raise RuntimeError('unknown content encoding: {}'.format(encoding))
if encoding == 'identity':
encoding = None

# te encoding
te_encoding = payload.headers.get(
CONTENT_TRANSFER_ENCODING, '').lower() # type: Optional[str] # noqa
CONTENT_TRANSFER_ENCODING,
'',
).lower() # type: Optional[str]
if te_encoding not in ('', 'base64', 'quoted-printable', 'binary'):
raise RuntimeError('unknown content transfer encoding: {}'
''.format(te_encoding))
@@ -867,7 +894,7 @@ def append_payload(self, payload: Payload) -> Payload:
def append_json(
self,
obj: Any,
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
"""Helper to append JSON part."""
if headers is None:
@@ -879,7 +906,7 @@ def append_form(
self,
obj: Union[Sequence[Tuple[str, str]],
Mapping[str, str]],
headers: Optional['MultiMapping[str]']=None
headers: Optional[MultiMapping[str]]=None
) -> Payload:
"""Helper to append form urlencoded part."""
assert isinstance(obj, (Sequence, Mapping))
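A side note on the cast() removals above (my reading, not text from the PR): since headers is a CIMultiDictProxy[str], .get(key, '') keeps mypy happy without a cast and also tolerates a missing header instead of raising KeyError. A small standalone sketch of the difference, using a hypothetical header set:

from typing import cast

from multidict import CIMultiDict, CIMultiDictProxy

CONTENT_ENCODING = 'Content-Encoding'

headers = CIMultiDictProxy(CIMultiDict({'Content-Type': 'text/plain'}))

# Old style: an unchecked cast, and a KeyError when the header is absent.
try:
    old = cast(str, headers[CONTENT_ENCODING]).lower()
except KeyError:
    old = ''

# New style: no cast needed, and a missing header collapses to the default.
new = headers.get(CONTENT_ENCODING, '').lower()
assert old == new == ''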
68 changes: 38 additions & 30 deletions aiohttp/web_request.py
@@ -39,7 +39,7 @@
sentinel,
)
from .http_parser import RawRequestMessage
from .multipart import MultipartReader
from .multipart import BodyPartReader, MultipartReader
from .streams import EmptyStreamReader, StreamReader
from .typedefs import (
DEFAULT_JSON_DECODER,
@@ -633,41 +633,49 @@ async def post(self) -> 'MultiDictProxy[Union[str, bytes, FileField]]':
field = await multipart.next()
while field is not None:
size = 0
content_type = field.headers.get(hdrs.CONTENT_TYPE)

if field.filename:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
field_ct = field.headers.get(hdrs.CONTENT_TYPE)

if isinstance(field, BodyPartReader):
if field.filename and field_ct:
# store file in temp file
tmp = tempfile.TemporaryFile()
chunk = await field.read_chunk(size=2**16)
while chunk:
chunk = field.decode(chunk)
tmp.write(chunk)
size += len(chunk)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
field_ct, field.headers)
out.add(field.name, ff)
else:
# deal with ordinary data
value = await field.read(decode=True)
if field_ct is None or \
field_ct.startswith('text/'):
charset = field.get_charset(default='utf-8')
out.add(field.name, value.decode(charset))
Contributor Author:
That accounts for producing str from post().

else:
out.add(field.name, value)
Contributor Author:
That is for bytes.
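Taken together, the two branches are what make the annotated return type MultiDictProxy[Union[str, bytes, FileField]] accurate: text parts come back as str, other parts as bytes, and uploads as FileField. A minimal server-side sketch (a hypothetical handler of mine, assuming standard aiohttp.web usage):

from aiohttp import web


async def handle_form(request: web.Request) -> web.Response:
    data = await request.post()  # MultiDictProxy[Union[str, bytes, FileField]]
    lines = []
    for name, value in data.items():
        if isinstance(value, web.FileField):
            lines.append('{}: uploaded file {!r}'.format(name, value.filename))
        elif isinstance(value, bytes):
            lines.append('{}: {} raw bytes'.format(name, len(value)))
        else:  # str, already decoded using the part's charset
            lines.append('{}: {!r}'.format(name, value))
    return web.Response(text='\n'.join(lines))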

size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
chunk = await field.read_chunk(size=2**16)
tmp.seek(0)

ff = FileField(field.name, field.filename,
cast(io.BufferedReader, tmp),
content_type, field.headers)
out.add(field.name, ff)
else:
value = await field.read(decode=True)
if content_type is None or \
content_type.startswith('text/'):
charset = field.get_charset(default='utf-8')
value = value.decode(charset)
out.add(field.name, value)
size += len(value)
if 0 < max_size < size:
raise HTTPRequestEntityTooLarge(
max_size=max_size,
actual_size=size
)
raise ValueError(
Contributor Author:
Given the documentation, I assume that behavior here is correct and decoding nested multipart is a custom job.
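A rough sketch of what such a custom reader can look like (my illustration, assuming standard aiohttp.web usage, not code from this PR): iterate request.multipart() directly and recurse whenever a part is another MultipartReader instead of a BodyPartReader:

from typing import List, Union

from aiohttp import BodyPartReader, MultipartReader, web


async def collect_parts(
    reader: Union[MultipartReader, BodyPartReader],
    out: List[bytes],
) -> None:
    if isinstance(reader, BodyPartReader):
        out.append(await reader.read(decode=True))
        return
    async for part in reader:  # Union[MultipartReader, BodyPartReader]
        await collect_parts(part, out)


async def handle_nested(request: web.Request) -> web.Response:
    bodies = []  # type: List[bytes]
    await collect_parts(await request.multipart(), bodies)
    return web.Response(text='read {} part(s)'.format(len(bodies)))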

'To decode nested multipart you need '
'to use custom reader',
)

field = await multipart.next()
else:
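As a closing illustration of the payoff (a hypothetical snippet, not part of the diff): with the Any return types gone, mypy can now flag attribute access that is only valid on one of the two part types unless it is narrowed first:

from aiohttp import BodyPartReader, MultipartReader


async def first_part_text(reader: MultipartReader) -> str:
    part = await reader.next()  # Optional[Union[MultipartReader, BodyPartReader]]
    assert part is not None
    # Without this narrowing, mypy reports that MultipartReader has no
    # attribute "text"; the old Any annotation silently allowed it.
    assert isinstance(part, BodyPartReader)
    return await part.text()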