Skip to content

Commit

Permalink
Refactor payload registry to avoid linear searches for common types (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
bdraco authored Nov 10, 2024
1 parent 50cccb3 commit c5d6b84
Showing 1 changed file with 22 additions and 8 deletions.
30 changes: 22 additions & 8 deletions aiohttp/payload.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def __init__(self) -> None:
self._first: List[_PayloadRegistryItem] = []
self._normal: List[_PayloadRegistryItem] = []
self._last: List[_PayloadRegistryItem] = []
self._normal_lookup: Dict[Any, PayloadType] = {}

def get(
self,
Expand All @@ -109,12 +110,20 @@ def get(
_CHAIN: "Type[chain[_PayloadRegistryItem]]" = chain,
**kwargs: Any,
) -> "Payload":
if self._first:
for factory, type_ in self._first:
if isinstance(data, type_):
return factory(data, *args, **kwargs)
# Try the fast lookup first
if lookup_factory := self._normal_lookup.get(type(data)):
return lookup_factory(data, *args, **kwargs)
# Bail early if its already a Payload
if isinstance(data, Payload):
return data
for factory, type in _CHAIN(self._first, self._normal, self._last):
if isinstance(data, type):
# Fallback to the slower linear search
for factory, type_ in _CHAIN(self._normal, self._last):
if isinstance(data, type_):
return factory(data, *args, **kwargs)

raise LookupError()

def register(
Expand All @@ -124,6 +133,11 @@ def register(
self._first.append((factory, type))
elif order is Order.normal:
self._normal.append((factory, type))
if isinstance(type, Iterable):
for t in type:
self._normal_lookup[t] = factory
else:
self._normal_lookup[type] = factory
elif order is Order.try_last:
self._last.append((factory, type))
else:
Expand Down Expand Up @@ -159,7 +173,8 @@ def __init__(
self._headers[hdrs.CONTENT_TYPE] = content_type
else:
self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
self._headers.update(headers or {})
if headers:
self._headers.update(headers)

@property
def size(self) -> Optional[int]:
Expand Down Expand Up @@ -228,18 +243,17 @@ class BytesPayload(Payload):
def __init__(
self, value: Union[bytes, bytearray, memoryview], *args: Any, **kwargs: Any
) -> None:
if not isinstance(value, (bytes, bytearray, memoryview)):
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

if "content_type" not in kwargs:
kwargs["content_type"] = "application/octet-stream"

super().__init__(value, *args, **kwargs)

if isinstance(value, memoryview):
self._size = value.nbytes
else:
elif isinstance(value, (bytes, bytearray)):
self._size = len(value)
else:
raise TypeError(f"value argument must be byte-ish, not {type(value)!r}")

if self._size > TOO_LARGE_BYTES_BODY:
warnings.warn(
Expand Down

0 comments on commit c5d6b84

Please sign in to comment.