Skip to content

Commit

Permalink
Limit websocket message size on reading (#3045)
Browse files Browse the repository at this point in the history
  • Loading branch information
asvetlov authored Jun 1, 2018
1 parent f426da7 commit e021a01
Show file tree
Hide file tree
Showing 7 changed files with 89 additions and 22 deletions.
1 change: 1 addition & 0 deletions CHANGES/3045.feature
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Limit websocket message size on reading to 4 MB by default.
11 changes: 7 additions & 4 deletions aiohttp/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -519,7 +519,8 @@ def ws_connect(self, url, *,
fingerprint=None,
ssl_context=None,
proxy_headers=None,
compress=0):
compress=0,
max_msg_size=4*1024*1024):
"""Initiate websocket connection."""
return _WSRequestContextManager(
self._ws_connect(url,
Expand All @@ -539,7 +540,8 @@ def ws_connect(self, url, *,
fingerprint=fingerprint,
ssl_context=ssl_context,
proxy_headers=proxy_headers,
compress=compress))
compress=compress,
max_msg_size=max_msg_size))

async def _ws_connect(self, url, *,
protocols=(),
Expand All @@ -558,7 +560,8 @@ async def _ws_connect(self, url, *,
fingerprint=None,
ssl_context=None,
proxy_headers=None,
compress=0):
compress=0,
max_msg_size=4*1024*1024):

if headers is None:
headers = CIMultiDict()
Expand Down Expand Up @@ -667,7 +670,7 @@ async def _ws_connect(self, url, *,
transport = resp.connection.transport
reader = FlowControlDataQueue(
proto, limit=2 ** 16, loop=self._loop)
proto.set_parser(WebSocketReader(reader), reader)
proto.set_parser(WebSocketReader(reader, max_msg_size), reader)
tcp_nodelay(transport, True)
writer = WebSocketWriter(
proto, transport, use_mask=True,
Expand Down
24 changes: 22 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,8 +236,9 @@ class WSParserState(IntEnum):

class WebSocketReader:

def __init__(self, queue, compress=True):
def __init__(self, queue, max_msg_size, compress=True):
self.queue = queue
self._max_msg_size = max_msg_size

self._exc = None
self._partial = bytearray()
Expand Down Expand Up @@ -320,6 +321,12 @@ def _feed_data(self, data):
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode
self._partial.extend(payload)
if (self._max_msg_size and
len(self._partial) >= self._max_msg_size):
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size))
else:
# previous frame was non finished
# we should get continuation opcode
Expand All @@ -335,13 +342,26 @@ def _feed_data(self, data):
self._opcode = None

self._partial.extend(payload)
if (self._max_msg_size and
len(self._partial) >= self._max_msg_size):
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size))

# Decompress process must to be done after all packets
# received.
if compressed:
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress(
self._partial)
self._partial, self._max_msg_size)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size exceeds limit {}".
format(self._max_msg_size + left,
self._max_msg_size))
else:
payload_merged = bytes(self._partial)

Expand Down
5 changes: 3 additions & 2 deletions aiohttp/web_ws.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class WebSocketResponse(StreamResponse):
def __init__(self, *,
timeout=10.0, receive_timeout=None,
autoclose=True, autoping=True, heartbeat=None,
protocols=(), compress=True):
protocols=(), compress=True, max_msg_size=4*1024*1024):
super().__init__(status=101)
self._protocols = protocols
self._ws_protocol = None
Expand All @@ -61,6 +61,7 @@ def __init__(self, *,
self._pong_heartbeat = heartbeat / 2.0
self._pong_response_cb = None
self._compress = compress
self._max_msg_size = max_msg_size

def _cancel_heartbeat(self):
if self._pong_response_cb is not None:
Expand Down Expand Up @@ -203,7 +204,7 @@ def _post_start(self, request, protocol, writer):
self._reader = FlowControlDataQueue(
request._protocol, limit=2 ** 16, loop=self._loop)
request.protocol.set_parser(WebSocketReader(
self._reader, compress=self._compress))
self._reader, self._max_msg_size, compress=self._compress))
# disable HTTP keepalive for WebSocket
request.protocol.keep_alive(False)

Expand Down
8 changes: 7 additions & 1 deletion docs/client_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -498,7 +498,7 @@ The client session supports the context manager protocol for self closing.
proxy=None, proxy_auth=None, ssl=None, \
verify_ssl=None, fingerprint=None, \
ssl_context=None, proxy_headers=None, \
compress=0)
compress=0, max_msg_size=4194304)
:async-with:
:coroutine:

Expand Down Expand Up @@ -601,6 +601,12 @@ The client session supports the context manager protocol for self closing.

.. versionadded:: 2.3

:param int max_msg_size: maximum size of read websocket message,
4 MB by default. To disable the size
limit use ``0``.

.. versionadded:: 3.3


.. comethod:: close()

Expand Down
8 changes: 7 additions & 1 deletion docs/web_reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -865,7 +865,7 @@ WebSocketResponse

.. class:: WebSocketResponse(*, timeout=10.0, receive_timeout=None, \
autoclose=True, autoping=True, heartbeat=None, \
protocols=(), compress=True)
protocols=(), compress=True, max_msg_size=4194304)

Class for handling server-side websockets, inherited from
:class:`StreamResponse`.
Expand Down Expand Up @@ -903,6 +903,12 @@ WebSocketResponse
:param bool compress: Enable per-message deflate extension support.
False for disabled, default value is True.

:param int max_msg_size: maximum size of read websocket message, 4
MB by default. To disable the size limit use ``0``.

.. versionadded:: 3.3


The class supports ``async for`` statement for iterating over
incoming messages::

Expand Down
54 changes: 42 additions & 12 deletions tests/test_websocket_parser.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import random
import struct
import zlib
from unittest import mock

import pytest

import aiohttp
from aiohttp import http_websocket
from aiohttp.http import WebSocketError, WSCloseCode, WSMessage, WSMsgType
from aiohttp.http_websocket import (PACK_CLOSE_CODE, PACK_LEN1, PACK_LEN2,
PACK_LEN3, WebSocketReader,
_websocket_mask)
from aiohttp.http_websocket import (_WS_DEFLATE_TRAILING, PACK_CLOSE_CODE,
PACK_LEN1, PACK_LEN2, PACK_LEN3,
WebSocketReader, _websocket_mask)


def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True):
def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True,
compress=False):
"""Send a frame over the websocket with message as its payload."""
if compress:
compressobj = zlib.compressobj(wbits=-9)
message = compressobj.compress(message)
message = message + compressobj.flush(zlib.Z_SYNC_FLUSH)
if message.endswith(_WS_DEFLATE_TRAILING):
message = message[:-4]
msg_length = len(message)
if use_mask: # pragma: no cover
mask_bit = 0x80
Expand All @@ -25,6 +33,9 @@ def build_frame(message, opcode, use_mask=False, noheader=False, is_fin=True):
else:
header_first_byte = opcode

if compress:
header_first_byte |= 0x40

if msg_length < 126:
header = PACK_LEN1(
header_first_byte, msg_length | mask_bit)
Expand Down Expand Up @@ -67,7 +78,7 @@ def out(loop):

@pytest.fixture()
def parser(out):
return WebSocketReader(out)
return WebSocketReader(out, 4*1024*1024)


def test_parse_frame(parser):
Expand Down Expand Up @@ -444,16 +455,35 @@ def test_parse_compress_error_frame(parser):
assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR


@pytest.fixture()
def parser_no_compress(out):
return WebSocketReader(out, compress=False)


def test_parse_no_compress_frame_single(parser_no_compress):

def test_parse_no_compress_frame_single():
parser_no_compress = WebSocketReader(out, 0, compress=False)
with pytest.raises(WebSocketError) as ctx:
parser_no_compress.parse_frame(struct.pack(
'!BB', 0b11000001, 0b00000001))
parser_no_compress.parse_frame(b'1')

assert ctx.value.code == WSCloseCode.PROTOCOL_ERROR


def test_msg_too_large(out):
parser = WebSocketReader(out, 256, compress=False)
data = build_frame(b'text'*256, WSMsgType.TEXT)
with pytest.raises(WebSocketError) as ctx:
parser._feed_data(data)
assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG


def test_msg_too_large_not_fin(out):
parser = WebSocketReader(out, 256, compress=False)
data = build_frame(b'text'*256, WSMsgType.TEXT, is_fin=False)
with pytest.raises(WebSocketError) as ctx:
parser._feed_data(data)
assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG


def test_compressed_msg_too_large(out):
parser = WebSocketReader(out, 256, compress=True)
data = build_frame(b'aaa'*256, WSMsgType.TEXT, compress=True)
with pytest.raises(WebSocketError) as ctx:
parser._feed_data(data)
assert ctx.value.code == WSCloseCode.MESSAGE_TOO_BIG

0 comments on commit e021a01

Please sign in to comment.