Skip to content

Commit 8fdf6f8

Browse files
authored
Add TypedDict callbacks (#98)
1 parent 9c936b7 commit 8fdf6f8

File tree

3 files changed

+60
-33
lines changed

3 files changed

+60
-33
lines changed

.gitignore

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ lib64
2828
pip-log.txt
2929

3030
# Unit test / coverage reports
31-
.coverage.*
31+
.coverage*
3232
.tox
3333
nosetests.xml
3434

multipart/multipart.py

+52-23
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,38 @@
99
from enum import IntEnum
1010
from io import BytesIO
1111
from numbers import Number
12-
from typing import Dict, Tuple, Union
12+
from typing import TYPE_CHECKING
1313

1414
from .decoders import Base64Decoder, QuotedPrintableDecoder
1515
from .exceptions import FileError, FormParserError, MultipartParseError, QuerystringParseError
1616

17+
if TYPE_CHECKING: # pragma: no cover
18+
from typing import Callable, TypedDict
19+
20+
class QuerystringCallbacks(TypedDict, total=False):
21+
on_field_start: Callable[[], None]
22+
on_field_name: Callable[[bytes, int, int], None]
23+
on_field_data: Callable[[bytes, int, int], None]
24+
on_field_end: Callable[[], None]
25+
on_end: Callable[[], None]
26+
27+
class OctetStreamCallbacks(TypedDict, total=False):
28+
on_start: Callable[[], None]
29+
on_data: Callable[[bytes, int, int], None]
30+
on_end: Callable[[], None]
31+
32+
class MultipartCallbacks(TypedDict, total=False):
33+
on_part_begin: Callable[[], None]
34+
on_part_data: Callable[[bytes, int, int], None]
35+
on_part_end: Callable[[], None]
36+
on_headers_begin: Callable[[], None]
37+
on_header_field: Callable[[bytes, int, int], None]
38+
on_header_value: Callable[[bytes, int, int], None]
39+
on_header_end: Callable[[], None]
40+
on_headers_finished: Callable[[], None]
41+
on_end: Callable[[], None]
42+
43+
1744
# Unique missing object.
1845
_missing = object()
1946

@@ -86,7 +113,7 @@ def join_bytes(b):
86113
return bytes(list(b))
87114

88115

89-
def parse_options_header(value: Union[str, bytes]) -> Tuple[bytes, Dict[bytes, bytes]]:
116+
def parse_options_header(value: str | bytes) -> tuple[bytes, dict[bytes, bytes]]:
90117
"""
91118
Parses a Content-Type header into a value in the following format:
92119
(content_type, {parameters})
@@ -148,15 +175,15 @@ class Field:
148175
:param name: the name of the form field
149176
"""
150177

151-
def __init__(self, name):
178+
def __init__(self, name: str):
152179
self._name = name
153-
self._value = []
180+
self._value: list[bytes] = []
154181

155182
# We cache the joined version of _value for speed.
156183
self._cache = _missing
157184

158185
@classmethod
159-
def from_value(klass, name, value):
186+
def from_value(cls, name: str, value: bytes | None) -> Field:
160187
"""Create an instance of a :class:`Field`, and set the corresponding
161188
value - either None or an actual value. This method will also
162189
finalize the Field itself.
@@ -166,22 +193,22 @@ def from_value(klass, name, value):
166193
None
167194
"""
168195

169-
f = klass(name)
196+
f = cls(name)
170197
if value is None:
171198
f.set_none()
172199
else:
173200
f.write(value)
174201
f.finalize()
175202
return f
176203

177-
def write(self, data):
204+
def write(self, data: bytes) -> int:
178205
"""Write some data into the form field.
179206
180207
:param data: a bytestring
181208
"""
182209
return self.on_data(data)
183210

184-
def on_data(self, data):
211+
def on_data(self, data: bytes) -> int:
185212
"""This method is a callback that will be called whenever data is
186213
written to the Field.
187214
@@ -191,24 +218,24 @@ def on_data(self, data):
191218
self._cache = _missing
192219
return len(data)
193220

194-
def on_end(self):
221+
def on_end(self) -> None:
195222
"""This method is called whenever the Field is finalized."""
196223
if self._cache is _missing:
197224
self._cache = b"".join(self._value)
198225

199-
def finalize(self):
226+
def finalize(self) -> None:
200227
"""Finalize the form field."""
201228
self.on_end()
202229

203-
def close(self):
230+
def close(self) -> None:
204231
"""Close the Field object. This will free any underlying cache."""
205232
# Free our value array.
206233
if self._cache is _missing:
207234
self._cache = b"".join(self._value)
208235

209236
del self._value
210237

211-
def set_none(self):
238+
def set_none(self) -> None:
212239
"""Some fields in a querystring can possibly have a value of None - for
213240
example, the string "foo&bar=&baz=asdf" will have a field with the
214241
name "foo" and value None, one with name "bar" and value "", and one
@@ -218,7 +245,7 @@ def set_none(self):
218245
self._cache = None
219246

220247
@property
221-
def field_name(self):
248+
def field_name(self) -> str:
222249
"""This property returns the name of the field."""
223250
return self._name
224251

@@ -230,13 +257,13 @@ def value(self):
230257

231258
return self._cache
232259

233-
def __eq__(self, other):
260+
def __eq__(self, other: object) -> bool:
234261
if isinstance(other, Field):
235262
return self.field_name == other.field_name and self.value == other.value
236263
else:
237264
return NotImplemented
238265

239-
def __repr__(self):
266+
def __repr__(self) -> str:
240267
if len(self.value) > 97:
241268
# We get the repr, and then insert three dots before the final
242269
# quote.
@@ -553,7 +580,7 @@ class BaseParser:
553580
def __init__(self):
554581
self.logger = logging.getLogger(__name__)
555582

556-
def callback(self, name, data=None, start=None, end=None):
583+
def callback(self, name: str, data=None, start=None, end=None):
557584
"""This function calls a provided callback with some data. If the
558585
callback is not set, will do nothing.
559586
@@ -584,7 +611,7 @@ def callback(self, name, data=None, start=None, end=None):
584611
self.logger.debug("Calling %s with no data", name)
585612
func()
586613

587-
def set_callback(self, name, new_func):
614+
def set_callback(self, name: str, new_func):
588615
"""Update the function for a callback. Removes from the callbacks dict
589616
if new_func is None.
590617
@@ -637,7 +664,7 @@ class OctetStreamParser(BaseParser):
637664
i.e. unbounded.
638665
"""
639666

640-
def __init__(self, callbacks={}, max_size=float("inf")):
667+
def __init__(self, callbacks: OctetStreamCallbacks = {}, max_size=float("inf")):
641668
super().__init__()
642669
self.callbacks = callbacks
643670
self._started = False
@@ -647,7 +674,7 @@ def __init__(self, callbacks={}, max_size=float("inf")):
647674
self.max_size = max_size
648675
self._current_size = 0
649676

650-
def write(self, data):
677+
def write(self, data: bytes):
651678
"""Write some data to the parser, which will perform size verification,
652679
and then pass the data to the underlying callback.
653680
@@ -732,7 +759,9 @@ class QuerystringParser(BaseParser):
732759
i.e. unbounded.
733760
"""
734761

735-
def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
762+
state: QuerystringState
763+
764+
def __init__(self, callbacks: QuerystringCallbacks = {}, strict_parsing=False, max_size=float("inf")):
736765
super().__init__()
737766
self.state = QuerystringState.BEFORE_FIELD
738767
self._found_sep = False
@@ -748,7 +777,7 @@ def __init__(self, callbacks={}, strict_parsing=False, max_size=float("inf")):
748777
# Should parsing be strict?
749778
self.strict_parsing = strict_parsing
750779

751-
def write(self, data):
780+
def write(self, data: bytes):
752781
"""Write some data to the parser, which will perform size verification,
753782
parse into either a field name or value, and then pass the
754783
corresponding data to the underlying callback. If an error is
@@ -780,7 +809,7 @@ def write(self, data):
780809

781810
return l
782811

783-
def _internal_write(self, data, length):
812+
def _internal_write(self, data: bytes, length: int):
784813
state = self.state
785814
strict_parsing = self.strict_parsing
786815
found_sep = self._found_sep
@@ -989,7 +1018,7 @@ class MultipartParser(BaseParser):
9891018
i.e. unbounded.
9901019
"""
9911020

992-
def __init__(self, boundary, callbacks={}, max_size=float("inf")):
1021+
def __init__(self, boundary, callbacks: MultipartCallbacks = {}, max_size=float("inf")):
9931022
# Initialize parser state.
9941023
super().__init__()
9951024
self.state = MultipartState.START

tests/test_multipart.py

+7-9
Original file line numberDiff line numberDiff line change
@@ -333,9 +333,9 @@ def on_field_end():
333333
del name_buffer[:]
334334
del data_buffer[:]
335335

336-
callbacks = {"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
337-
338-
self.p = QuerystringParser(callbacks)
336+
self.p = QuerystringParser(
337+
callbacks={"on_field_name": on_field_name, "on_field_data": on_field_data, "on_field_end": on_field_end}
338+
)
339339

340340
def test_simple_querystring(self):
341341
self.p.write(b"foo=bar")
@@ -464,18 +464,16 @@ def setUp(self):
464464
self.started = 0
465465
self.finished = 0
466466

467-
def on_start():
467+
def on_start() -> None:
468468
self.started += 1
469469

470-
def on_data(data, start, end):
470+
def on_data(data: bytes, start: int, end: int) -> None:
471471
self.d.append(data[start:end])
472472

473-
def on_end():
473+
def on_end() -> None:
474474
self.finished += 1
475475

476-
callbacks = {"on_start": on_start, "on_data": on_data, "on_end": on_end}
477-
478-
self.p = OctetStreamParser(callbacks)
476+
self.p = OctetStreamParser(callbacks={"on_start": on_start, "on_data": on_data, "on_end": on_end})
479477

480478
def assert_data(self, data, finalize=True):
481479
self.assertEqual(b"".join(self.d), data)

0 commit comments

Comments
 (0)