Skip to content

Commit

Permalink
Refactor sendfile (#3383)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Nov 8, 2018
1 parent 412349d commit 690a3d4
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 95 deletions.
1 change: 1 addition & 0 deletions CHANGES/3383.bugfix
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix task cancellation when ``sendfile()`` syscall is used by static file handling.
66 changes: 39 additions & 27 deletions aiohttp/web_fileresponse.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import mimetypes
import os
import pathlib
from functools import partial
from typing import (IO, TYPE_CHECKING, Any, Awaitable, Callable, List, # noqa
Optional, Union, cast)

Expand Down Expand Up @@ -35,9 +36,15 @@ class SendfileStreamWriter(StreamWriter):
def __init__(self,
protocol: BaseProtocol,
loop: asyncio.AbstractEventLoop,
fobj: IO[Any],
count: int,
on_chunk_sent: _T_OnChunkSent=None) -> None:
super().__init__(protocol, loop, on_chunk_sent)
self._sendfile_buffer = [] # type: List[bytes]
self._fobj = fobj
self._count = count
self._offset = fobj.tell()
self._in_fd = fobj.fileno()

def _write(self, chunk: bytes) -> None:
# we overwrite StreamWriter._write, so nothing can be appended to
Expand All @@ -46,54 +53,57 @@ def _write(self, chunk: bytes) -> None:
self.output_size += len(chunk)
self._sendfile_buffer.append(chunk)

def _sendfile_cb(self, fut: 'asyncio.Future[None]',
out_fd: int, in_fd: int,
offset: int, count: int,
loop: asyncio.AbstractEventLoop,
registered: bool) -> None:
if registered:
loop.remove_writer(out_fd)
def _sendfile_cb(self, fut: 'asyncio.Future[None]', out_fd: int) -> None:
if fut.cancelled():
return
try:
if self._do_sendfile(out_fd):
set_result(fut, None)
except Exception as exc:
set_exception(fut, exc)

def _do_sendfile(self, out_fd: int) -> bool:
try:
n = os.sendfile(out_fd, in_fd, offset, count)
if n == 0: # EOF reached
n = count
n = os.sendfile(out_fd,
self._in_fd,
self._offset,
self._count)
if n == 0: # in_fd EOF reached
n = self._count
except (BlockingIOError, InterruptedError):
n = 0
except Exception as exc:
set_exception(fut, exc)
return
self.output_size += n
self._offset += n
self._count -= n
assert self._count >= 0
return self._count == 0

if n < count:
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd, in_fd,
offset + n, count - n, loop, True)
else:
set_result(fut, None)
def _done_fut(self, out_fd: int, fut: 'asyncio.Future[None]') -> None:
self.loop.remove_writer(out_fd)

async def sendfile(self, fobj: IO[Any], count: int) -> None:
async def sendfile(self) -> None:
assert self.transport is not None
out_socket = self.transport.get_extra_info('socket').dup()
out_socket.setblocking(False)
out_fd = out_socket.fileno()
in_fd = fobj.fileno()
offset = fobj.tell()

loop = self.loop
data = b''.join(self._sendfile_buffer)
try:
await loop.sock_sendall(out_socket, data)
fut = loop.create_future()
self._sendfile_cb(fut, out_fd, in_fd, offset, count, loop, False)
await fut
if not self._do_sendfile(out_fd):
fut = loop.create_future()
fut.add_done_callback(partial(self._done_fut, out_fd))
loop.add_writer(out_fd, self._sendfile_cb, fut, out_fd)
await fut
except asyncio.CancelledError:
raise
except Exception:
server_logger.debug('Socket error')
self.transport.close()
finally:
out_socket.close()

self.output_size += count
await super().write_eof()

async def write_eof(self, chunk: bytes=b'') -> None:
Expand Down Expand Up @@ -139,12 +149,14 @@ async def _sendfile_system(self, request: 'BaseRequest',
else:
writer = SendfileStreamWriter(
request.protocol,
request._loop
request._loop,
fobj,
count
)
request._payload_writer = writer

await super().prepare(request)
await writer.sendfile(fobj, count)
await writer.sendfile()

return writer

Expand Down
68 changes: 1 addition & 67 deletions tests/test_web_sendfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,73 +2,7 @@

from aiohttp import hdrs
from aiohttp.test_utils import make_mocked_coro, make_mocked_request
from aiohttp.web_fileresponse import FileResponse, SendfileStreamWriter


def test_static_handle_eof(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
m_os.sendfile.return_value = 0
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert fut.done()
assert fut.result() is None
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called


def test_static_handle_again(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
m_os.sendfile.side_effect = BlockingIOError()
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert not fut.done()
fake_loop.add_writer.assert_called_with(out_fd,
writer._sendfile_cb,
fut, out_fd, in_fd, 0, 100,
fake_loop, True)
assert not fake_loop.remove_writer.called


def test_static_handle_exception(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
exc = OSError()
m_os.sendfile.side_effect = exc
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
m_os.sendfile.assert_called_with(out_fd, in_fd, 0, 100)
assert fut.done()
assert exc is fut.exception()
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called


def test__sendfile_cb_return_on_cancelling(loop) -> None:
fake_loop = mock.Mock()
with mock.patch('aiohttp.web_fileresponse.os') as m_os:
out_fd = 30
in_fd = 31
fut = loop.create_future()
fut.cancel()
writer = SendfileStreamWriter(mock.Mock(), mock.Mock(), fake_loop)
writer._sendfile_cb(fut, out_fd, in_fd, 0, 100, fake_loop, False)
assert fut.done()
assert not fake_loop.add_writer.called
assert not fake_loop.remove_writer.called
assert not m_os.sendfile.called
from aiohttp.web_fileresponse import FileResponse


def test_using_gzip_if_header_present_and_file_available(loop) -> None:
Expand Down
68 changes: 67 additions & 1 deletion tests/test_web_sendfile_functional.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import os
import pathlib
import socket
import zlib

import pytest
Expand Down Expand Up @@ -324,7 +325,7 @@ def test_static_route_path_existence_check() -> None:
async def test_static_file_huge(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 100MB file
# fill 20MB file
with tmpdir.join(filename).open('w') as f:
for i in range(1024*20):
f.write(chr(i % 64 + 0x20) * 1024)
Expand Down Expand Up @@ -751,3 +752,68 @@ async def handler(request):
assert 'application/octet-stream' == resp.headers['Content-Type']
assert resp.headers.get('Content-Encoding') == 'deflate'
await resp.release()


async def test_static_file_huge_cancel(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 100MB file
with tmpdir.join(filename).open('w') as f:
for i in range(1024*20):
f.write(chr(i % 64 + 0x20) * 1024)

task = None

async def handler(request):
nonlocal task
task = request.task
# reduce send buffer size
tr = request.transport
sock = tr.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename))))
return ret

app = web.Application()

app.router.add_get('/', handler)
client = await aiohttp_client(app)

resp = await client.get('/')
assert resp.status == 200
task.cancel()
await asyncio.sleep(0)
data = b''
while True:
try:
data += await resp.content.read(1024)
except aiohttp.ClientPayloadError:
break
assert len(data) < 1024 * 1024 * 20


async def test_static_file_huge_error(aiohttp_client, tmpdir) -> None:
filename = 'huge_data.unknown_mime_type'

# fill 20MB file
with tmpdir.join(filename).open('wb') as f:
f.seek(20*1024*1024)
f.write(b'1')

async def handler(request):
# reduce send buffer size
tr = request.transport
sock = tr.get_extra_info('socket')
sock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 1024)
ret = web.FileResponse(pathlib.Path(str(tmpdir.join(filename))))
return ret

app = web.Application()

app.router.add_get('/', handler)
client = await aiohttp_client(app)

resp = await client.get('/')
assert resp.status == 200
# raise an exception on server side
resp.close()

0 comments on commit 690a3d4

Please sign in to comment.