Skip to content

Commit

Permalink
[PR #8736/1b88af2 backport][3.11] Improve performance of WebSocketReader (#8744)
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Aug 19, 2024
1 parent 9a03467 commit 8c19def
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 96 deletions.
1 change: 1 addition & 0 deletions CHANGES/8736.misc.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Improved performance of the WebSocket reader -- by :user:`bdraco`.
200 changes: 104 additions & 96 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,12 @@ class WSMsgType(IntEnum):
error = ERROR


MESSAGE_TYPES_WITH_CONTENT: Final = (
WSMsgType.BINARY,
WSMsgType.TEXT,
WSMsgType.CONTINUATION,
)

WS_KEY: Final[bytes] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"


Expand Down Expand Up @@ -313,17 +319,101 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return True, data

try:
return self._feed_data(data)
self._feed_data(data)
except Exception as exc:
self._exc = exc
set_exception(self.queue, exc)
return True, b""

def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return False, b""

def _feed_data(self, data: bytes) -> None:
for fin, opcode, payload, compressed in self.parse_frame(data):
if compressed and not self._decompressobj:
self._decompressobj = ZLibDecompressor(suppress_deflate_header=True)
if opcode == WSMsgType.CLOSE:
if opcode in MESSAGE_TYPES_WITH_CONTENT:
# load text/binary
is_continuation = opcode == WSMsgType.CONTINUATION
if not fin:
# got partial frame payload
if not is_continuation:
self._opcode = opcode
self._partial += payload
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
continue

has_partial = bool(self._partial)
if is_continuation:
if self._opcode is None:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"Continuation frame for non started message",
)
opcode = self._opcode
self._opcode = None
# previous frame was non finished
# we should get continuation opcode
elif has_partial:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)

if has_partial:
assembled_payload = self._partial + payload
self._partial.clear()
else:
assembled_payload = payload

if self._max_msg_size and len(assembled_payload) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(assembled_payload), self._max_msg_size
),
)

# Decompress process must to be done after all packets
# received.
if compressed:
if not self._decompressobj:
self._decompressobj = ZLibDecompressor(
suppress_deflate_header=True
)
payload_merged = self._decompressobj.decompress_sync(
assembled_payload + _WS_DEFLATE_TRAILING, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(assembled_payload)

if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc

self.queue.feed_data(WSMessage(WSMsgType.TEXT, text, ""), len(text))
continue

self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""), len(payload_merged)
)
elif opcode == WSMsgType.CLOSE:
if len(payload) >= 2:
close_code = UNPACK_CLOSE_CODE(payload[:2])[0]
if close_code < 3000 and close_code not in ALLOWED_CLOSE_CODES:
Expand Down Expand Up @@ -358,90 +448,10 @@ def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
WSMessage(WSMsgType.PONG, payload, ""), len(payload)
)

elif (
opcode not in (WSMsgType.TEXT, WSMsgType.BINARY)
and self._opcode is None
):
else:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR, f"Unexpected opcode={opcode!r}"
)
else:
# load text/binary
if not fin:
# got partial frame payload
if opcode != WSMsgType.CONTINUATION:
self._opcode = opcode
self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)
else:
# previous frame was non finished
# we should get continuation opcode
if self._partial:
if opcode != WSMsgType.CONTINUATION:
raise WebSocketError(
WSCloseCode.PROTOCOL_ERROR,
"The opcode in non-fin frame is expected "
"to be zero, got {!r}".format(opcode),
)

if opcode == WSMsgType.CONTINUATION:
assert self._opcode is not None
opcode = self._opcode
self._opcode = None

self._partial.extend(payload)
if self._max_msg_size and len(self._partial) >= self._max_msg_size:
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Message size {} exceeds limit {}".format(
len(self._partial), self._max_msg_size
),
)

# Decompress process must to be done after all packets
# received.
if compressed:
assert self._decompressobj is not None
self._partial.extend(_WS_DEFLATE_TRAILING)
payload_merged = self._decompressobj.decompress_sync(
self._partial, self._max_msg_size
)
if self._decompressobj.unconsumed_tail:
left = len(self._decompressobj.unconsumed_tail)
raise WebSocketError(
WSCloseCode.MESSAGE_TOO_BIG,
"Decompressed message size {} exceeds limit {}".format(
self._max_msg_size + left, self._max_msg_size
),
)
else:
payload_merged = bytes(self._partial)

self._partial.clear()

if opcode == WSMsgType.TEXT:
try:
text = payload_merged.decode("utf-8")
self.queue.feed_data(
WSMessage(WSMsgType.TEXT, text, ""), len(text)
)
except UnicodeDecodeError as exc:
raise WebSocketError(
WSCloseCode.INVALID_TEXT, "Invalid UTF-8 text message"
) from exc
else:
self.queue.feed_data(
WSMessage(WSMsgType.BINARY, payload_merged, ""),
len(payload_merged),
)

return False, b""

def parse_frame(
self, buf: bytes
Expand Down Expand Up @@ -521,23 +531,21 @@ def parse_frame(

# read payload length
if self._state is WSParserState.READ_PAYLOAD_LENGTH:
length = self._payload_length_flag
if length == 126:
length_flag = self._payload_length_flag
if length_flag == 126:
if buf_length - start_pos < 2:
break
data = buf[start_pos : start_pos + 2]
start_pos += 2
length = UNPACK_LEN2(data)[0]
self._payload_length = length
elif length > 126:
self._payload_length = UNPACK_LEN2(data)[0]
elif length_flag > 126:
if buf_length - start_pos < 8:
break
data = buf[start_pos : start_pos + 8]
start_pos += 8
length = UNPACK_LEN3(data)[0]
self._payload_length = length
self._payload_length = UNPACK_LEN3(data)[0]
else:
self._payload_length = length
self._payload_length = length_flag

self._state = (
WSParserState.READ_PAYLOAD_MASK
Expand All @@ -560,11 +568,11 @@ def parse_frame(
chunk_len = buf_length - start_pos
if length >= chunk_len:
self._payload_length = length - chunk_len
payload.extend(buf[start_pos:])
payload += buf[start_pos:]
start_pos = buf_length
else:
self._payload_length = 0
payload.extend(buf[start_pos : start_pos + length])
payload += buf[start_pos : start_pos + length]
start_pos = start_pos + length

if self._payload_length != 0:
Expand Down

0 comments on commit 8c19def

Please sign in to comment.