Skip to content

Commit 093b407

Browse files
committed
Switch event classes to dataclasses
The _EventBundle used previously dynamically creates the Event classes, which makes typing hard. As the Event classes are now stable the dynamic creation isn't required, and as h11 supports Python3.6+ dataclasses can be used instead. This now makes the Events frozen (and almost imutable) which better matches the intended API. In turn it requires the `object.__setattr__` usage and the alteration to the `_clean_up_response_headers_for_sending` method. This change also improves the performance a little, using the benchmark, Before: 6.7k requests/sec After: 6.9k requests/sec Notes: The test response-header changes are required as the previous version would mutate the response object. The init for the Data event is required as slots and defaults aren't possible with dataclasses.
1 parent 889a564 commit 093b407

File tree

6 files changed

+184
-143
lines changed

6 files changed

+184
-143
lines changed

h11/_connection.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -483,7 +483,7 @@ def send_with_data_passthrough(self, event):
483483
raise LocalProtocolError("Can't send data when our state is ERROR")
484484
try:
485485
if type(event) is Response:
486-
self._clean_up_response_headers_for_sending(event)
486+
event = self._clean_up_response_headers_for_sending(event)
487487
# We want to call _process_event before calling the writer,
488488
# because if someone tries to do something invalid then this will
489489
# give a sensible error message, while our writers all just assume
@@ -528,8 +528,7 @@ def send_failed(self):
528528
#
529529
# This function's *only* responsibility is making sure headers are set up
530530
# right -- everything downstream just looks at the headers. There are no
531-
# side channels. It mutates the response event in-place (but not the
532-
# response.headers list object).
531+
# side channels.
533532
def _clean_up_response_headers_for_sending(self, response):
534533
assert type(response) is Response
535534

@@ -582,4 +581,9 @@ def _clean_up_response_headers_for_sending(self, response):
582581
connection.add(b"close")
583582
headers = set_comma_header(headers, b"connection", sorted(connection))
584583

585-
response.headers = headers
584+
return Response(
585+
headers=headers,
586+
status_code=response.status_code,
587+
http_version=response.http_version,
588+
reason=response.reason,
589+
)

h11/_events.py

Lines changed: 144 additions & 79 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,17 @@
66
# Don't subclass these. Stuff will break.
77

88
import re
9+
from abc import ABC
10+
from dataclasses import dataclass, field
11+
from typing import Any, cast, Dict, List, Tuple, Union
912

10-
from . import _headers
1113
from ._abnf import request_target
14+
from ._headers import Headers, normalize_and_validate
1215
from ._util import bytesify, LocalProtocolError, validate
1316

1417
# Everything in __all__ gets re-exported as part of the h11 public API.
1518
__all__ = [
19+
"Event",
1620
"Request",
1721
"InformationalResponse",
1822
"Response",
@@ -24,72 +28,16 @@
2428
request_target_re = re.compile(request_target.encode("ascii"))
2529

2630

27-
class _EventBundle:
28-
_fields = []
29-
_defaults = {}
30-
31-
def __init__(self, **kwargs):
32-
_parsed = kwargs.pop("_parsed", False)
33-
allowed = set(self._fields)
34-
for kwarg in kwargs:
35-
if kwarg not in allowed:
36-
raise TypeError(
37-
"unrecognized kwarg {} for {}".format(
38-
kwarg, self.__class__.__name__
39-
)
40-
)
41-
required = allowed.difference(self._defaults)
42-
for field in required:
43-
if field not in kwargs:
44-
raise TypeError(
45-
"missing required kwarg {} for {}".format(
46-
field, self.__class__.__name__
47-
)
48-
)
49-
self.__dict__.update(self._defaults)
50-
self.__dict__.update(kwargs)
51-
52-
# Special handling for some fields
53-
54-
if "headers" in self.__dict__:
55-
self.headers = _headers.normalize_and_validate(
56-
self.headers, _parsed=_parsed
57-
)
58-
59-
if not _parsed:
60-
for field in ["method", "target", "http_version", "reason"]:
61-
if field in self.__dict__:
62-
self.__dict__[field] = bytesify(self.__dict__[field])
63-
64-
if "status_code" in self.__dict__:
65-
if not isinstance(self.status_code, int):
66-
raise LocalProtocolError("status code must be integer")
67-
# Because IntEnum objects are instances of int, but aren't
68-
# duck-compatible (sigh), see gh-72.
69-
self.status_code = int(self.status_code)
70-
71-
self._validate()
72-
73-
def _validate(self):
74-
pass
75-
76-
def __repr__(self):
77-
name = self.__class__.__name__
78-
kwarg_strs = [
79-
"{}={}".format(field, self.__dict__[field]) for field in self._fields
80-
]
81-
kwarg_str = ", ".join(kwarg_strs)
82-
return "{}({})".format(name, kwarg_str)
83-
84-
# Useful for tests
85-
def __eq__(self, other):
86-
return self.__class__ == other.__class__ and self.__dict__ == other.__dict__
31+
class Event(ABC):
32+
"""
33+
Base class for h11 events.
34+
"""
8735

88-
# This is an unhashable type.
89-
__hash__ = None
36+
__slots__ = ()
9037

9138

92-
class Request(_EventBundle):
39+
@dataclass(init=False, frozen=True)
40+
class Request(Event):
9341
"""The beginning of an HTTP request.
9442
9543
Fields:
@@ -123,10 +71,38 @@ class Request(_EventBundle):
12371
12472
"""
12573

126-
_fields = ["method", "target", "headers", "http_version"]
127-
_defaults = {"http_version": b"1.1"}
74+
__slots__ = ("method", "headers", "target", "http_version")
75+
76+
method: bytes
77+
headers: Headers
78+
target: bytes
79+
http_version: bytes
80+
81+
def __init__(
82+
self,
83+
*,
84+
method: Union[bytes, str],
85+
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
86+
target: Union[bytes, str],
87+
http_version: Union[bytes, str] = b"1.1",
88+
_parsed: bool = False,
89+
) -> None:
90+
super().__init__()
91+
if isinstance(headers, Headers):
92+
object.__setattr__(self, "headers", headers)
93+
else:
94+
object.__setattr__(
95+
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
96+
)
97+
if not _parsed:
98+
object.__setattr__(self, "method", bytesify(method))
99+
object.__setattr__(self, "target", bytesify(target))
100+
object.__setattr__(self, "http_version", bytesify(http_version))
101+
else:
102+
object.__setattr__(self, "method", method)
103+
object.__setattr__(self, "target", target)
104+
object.__setattr__(self, "http_version", http_version)
128105

129-
def _validate(self):
130106
# "A server MUST respond with a 400 (Bad Request) status code to any
131107
# HTTP/1.1 request message that lacks a Host header field and to any
132108
# request message that contains more than one Host header field or a
@@ -143,12 +119,58 @@ def _validate(self):
143119

144120
validate(request_target_re, self.target, "Illegal target characters")
145121

122+
# This is an unhashable type.
123+
__hash__ = None # type: ignore
124+
125+
126+
@dataclass(init=False, frozen=True)
127+
class _ResponseBase(Event):
128+
__slots__ = ("headers", "http_version", "reason", "status_code")
129+
130+
headers: Headers
131+
http_version: bytes
132+
reason: bytes
133+
status_code: int
134+
135+
def __init__(
136+
self,
137+
*,
138+
headers: Union[Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]]],
139+
status_code: int,
140+
http_version: Union[bytes, str] = b"1.1",
141+
reason: Union[bytes, str] = b"",
142+
_parsed: bool = False,
143+
) -> None:
144+
super().__init__()
145+
if isinstance(headers, Headers):
146+
object.__setattr__(self, "headers", headers)
147+
else:
148+
object.__setattr__(
149+
self, "headers", normalize_and_validate(headers, _parsed=_parsed)
150+
)
151+
if not _parsed:
152+
object.__setattr__(self, "reason", bytesify(reason))
153+
object.__setattr__(self, "http_version", bytesify(http_version))
154+
if not isinstance(status_code, int):
155+
raise LocalProtocolError("status code must be integer")
156+
# Because IntEnum objects are instances of int, but aren't
157+
# duck-compatible (sigh), see gh-72.
158+
object.__setattr__(self, "status_code", int(status_code))
159+
else:
160+
object.__setattr__(self, "reason", reason)
161+
object.__setattr__(self, "http_version", http_version)
162+
object.__setattr__(self, "status_code", status_code)
163+
164+
self.__post_init__()
165+
166+
def __post_init__(self) -> None:
167+
pass
146168

147-
class _ResponseBase(_EventBundle):
148-
_fields = ["status_code", "headers", "http_version", "reason"]
149-
_defaults = {"http_version": b"1.1", "reason": b""}
169+
# This is an unhashable type.
170+
__hash__ = None # type: ignore
150171

151172

173+
@dataclass(init=False, frozen=True)
152174
class InformationalResponse(_ResponseBase):
153175
"""An HTTP informational response.
154176
@@ -179,14 +201,18 @@ class InformationalResponse(_ResponseBase):
179201
180202
"""
181203

182-
def _validate(self):
204+
def __post_init__(self) -> None:
183205
if not (100 <= self.status_code < 200):
184206
raise LocalProtocolError(
185207
"InformationalResponse status_code should be in range "
186208
"[100, 200), not {}".format(self.status_code)
187209
)
188210

211+
# This is an unhashable type.
212+
__hash__ = None # type: ignore
213+
189214

215+
@dataclass(init=False, frozen=True)
190216
class Response(_ResponseBase):
191217
"""The beginning of an HTTP response.
192218
@@ -216,16 +242,20 @@ class Response(_ResponseBase):
216242
217243
"""
218244

219-
def _validate(self):
245+
def __post_init__(self) -> None:
220246
if not (200 <= self.status_code < 600):
221247
raise LocalProtocolError(
222248
"Response status_code should be in range [200, 600), not {}".format(
223249
self.status_code
224250
)
225251
)
226252

253+
# This is an unhashable type.
254+
__hash__ = None # type: ignore
255+
227256

228-
class Data(_EventBundle):
257+
@dataclass(init=False, frozen=True)
258+
class Data(Event):
229259
"""Part of an HTTP message body.
230260
231261
Fields:
@@ -258,16 +288,30 @@ class Data(_EventBundle):
258288
259289
"""
260290

261-
_fields = ["data", "chunk_start", "chunk_end"]
262-
_defaults = {"chunk_start": False, "chunk_end": False}
291+
__slots__ = ("data", "chunk_start", "chunk_end")
292+
293+
data: bytes
294+
chunk_start: bool
295+
chunk_end: bool
296+
297+
def __init__(
298+
self, data: bytes, chunk_start: bool = False, chunk_end: bool = False
299+
) -> None:
300+
object.__setattr__(self, "data", data)
301+
object.__setattr__(self, "chunk_start", chunk_start)
302+
object.__setattr__(self, "chunk_end", chunk_end)
303+
304+
# This is an unhashable type.
305+
__hash__ = None # type: ignore
263306

264307

265308
# XX FIXME: "A recipient MUST ignore (or consider as an error) any fields that
266309
# are forbidden to be sent in a trailer, since processing them as if they were
267310
# present in the header section might bypass external security filters."
268311
# https://svn.tools.ietf.org/svn/wg/httpbis/specs/rfc7230.html#chunked.trailer.part
269312
# Unfortunately, the list of forbidden fields is long and vague :-/
270-
class EndOfMessage(_EventBundle):
313+
@dataclass(init=False, frozen=True)
314+
class EndOfMessage(Event):
271315
"""The end of an HTTP message.
272316
273317
Fields:
@@ -284,11 +328,32 @@ class EndOfMessage(_EventBundle):
284328
285329
"""
286330

287-
_fields = ["headers"]
288-
_defaults = {"headers": []}
331+
__slots__ = ("headers",)
332+
333+
headers: Headers
334+
335+
def __init__(
336+
self,
337+
*,
338+
headers: Union[
339+
Headers, List[Tuple[bytes, bytes]], List[Tuple[str, str]], None
340+
] = None,
341+
_parsed: bool = False,
342+
) -> None:
343+
super().__init__()
344+
if headers is None:
345+
headers = Headers([])
346+
elif not isinstance(headers, Headers):
347+
headers = normalize_and_validate(headers, _parsed=_parsed)
348+
349+
object.__setattr__(self, "headers", headers)
350+
351+
# This is an unhashable type.
352+
__hash__ = None # type: ignore
289353

290354

291-
class ConnectionClosed(_EventBundle):
355+
@dataclass(frozen=True)
356+
class ConnectionClosed(Event):
292357
"""This event indicates that the sender has closed their outgoing
293358
connection.
294359

h11/tests/helpers.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ def normalize_data_events(in_events):
2626
out_events = []
2727
for event in in_events:
2828
if type(event) is Data:
29-
event.data = bytes(event.data)
30-
event.chunk_start = False
31-
event.chunk_end = False
29+
event = Data(data=bytes(event.data), chunk_start=False, chunk_end=False)
3230
if out_events and type(out_events[-1]) is type(event) is Data:
33-
out_events[-1].data += event.data
31+
out_events[-1] = Data(
32+
data=out_events[-1].data + event.data,
33+
chunk_start=out_events[-1].chunk_start,
34+
chunk_end=out_events[-1].chunk_end,
35+
)
3436
else:
3537
out_events.append(event)
3638
return out_events

0 commit comments

Comments
 (0)