Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 94 additions & 31 deletions starlette/datastructures.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
Any,
BinaryIO,
NamedTuple,
TypeVar,
Union,
cast,
)
Expand All @@ -21,11 +20,11 @@ class Address(NamedTuple):
port: int


_KeyType = TypeVar("_KeyType")
_KeyType = Any
# Mapping keys are invariant but their values are covariant since
# you can only read them
# that is, you can't do `Mapping[str, Animal]()["fido"] = Dog()`
_CovariantValueType = TypeVar("_CovariantValueType", covariant=True)
_CovariantValueType = Any


class URL:
Expand All @@ -44,17 +43,22 @@ def __init__(
query_string = scope.get("query_string", b"")

host_header = None
for key, value in scope["headers"]:
if key == b"host":
host_header = value.decode("latin-1")
break
# Look for the Host header and stop at the first match; the `if headers:` guard skips the loop when the header list is empty.
headers = scope["headers"]
if headers:
for key, value in headers:
if key == b"host":
host_header = value.decode("latin-1")
break

if host_header is not None:
url = f"{scheme}://{host_header}{path}"
elif server is None:
url = path
else:
host, port = server
# Look up the default port for this scheme; when the server port equals it,
# the port is left out of the URL.
default_port = {"http": 80, "https": 443, "ws": 80, "wss": 443}[scheme]
if port == default_port:
url = f"{scheme}://{host}{path}"
Expand Down Expand Up @@ -116,18 +120,28 @@ def is_secure(self) -> bool:
return self.scheme in ("https", "wss")

def replace(self, **kwargs: Any) -> URL:
if "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs:
# Minor optimization: do not pop repeatedly or call dict operations unnecessarily
_replacing = "username" in kwargs or "password" in kwargs or "hostname" in kwargs or "port" in kwargs
if _replacing:
# These pops are the same as before, but we minimize logic
hostname = kwargs.pop("hostname", None)
port = kwargs.pop("port", self.port)
username = kwargs.pop("username", self.username)
password = kwargs.pop("password", self.password)

if hostname is None:
# No hostname was supplied: recover it from the current netloc by dropping everything up to and including the last "@".
netloc = self.netloc
_, _, hostname = netloc.rpartition("@")
at_pos = netloc.rfind("@")
if at_pos != -1:
hostname = netloc[at_pos + 1 :]
else:
hostname = netloc

if hostname[-1] != "]":
hostname = hostname.rsplit(":", 1)[0]
if hostname and hostname[-1] != "]":
col_pos = hostname.rfind(":")
if col_pos != -1:
hostname = hostname[:col_pos]

netloc = hostname
if port is not None:
Expand All @@ -140,12 +154,28 @@ def replace(self, **kwargs: Any) -> URL:

kwargs["netloc"] = netloc

# This is the single largest line-profiled time sink
# Optimize attribute lookup and instantiation
components = self.components._replace(**kwargs)
# Second largest time sink: __class__ object creation.
# Avoid additional attribute lookups and method calls.
return self.__class__(components.geturl())

def include_query_params(self, **kwargs: Any) -> URL:
    """Return a copy of this URL with ``kwargs`` merged into its query string.

    The current query string is parsed with blank values kept. Each keyword
    argument is converted to ``str`` (key and value) and applied through
    ``MultiDict.update``. The merged pairs are then re-encoded and handed
    to ``self.replace(query=...)``, whose result is returned.
    """
    params = MultiDict(parse_qsl(self.query, keep_blank_values=True))
    # Apply the keyword arguments exactly once, as a list of (key, value)
    # pairs, which is one of the argument shapes ``update`` declares.
    params.update([(str(key), str(value)) for key, value in kwargs.items()])
    query = urlencode(params.multi_items())
    return self.replace(query=query)

Expand Down Expand Up @@ -267,22 +297,30 @@ def __init__(

value: Any = args[0] if args else []
if kwargs:
value = ImmutableMultiDict(value).multi_items() + ImmutableMultiDict(kwargs).multi_items()

if not value:
_items: list[tuple[Any, Any]] = []
elif hasattr(value, "multi_items"):
value = cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value)
_items = list(value.multi_items())
elif hasattr(value, "items"):
value = cast(Mapping[_KeyType, _CovariantValueType], value)
_items = list(value.items())
# The following is not very efficient, as it repeatedly builds new ImmutableMultiDict instances
# Optimize by converting everything to a list beforehand and combining
value_items = []
# Only convert value to ImmutableMultiDict if necessary
if value:
if hasattr(value, "multi_items"):
value_items = list(cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value).multi_items())
elif hasattr(value, "items"):
value_items = list(cast(Mapping[_KeyType, _CovariantValueType], value).items())
else:
value_items = list(cast("list[tuple[Any, Any]]", value))
value_items += list(kwargs.items())
else:
value = cast("list[tuple[Any, Any]]", value)
_items = list(value)
if not value:
value_items = []
elif hasattr(value, "multi_items"):
value_items = list(cast(ImmutableMultiDict[_KeyType, _CovariantValueType], value).multi_items())
elif hasattr(value, "items"):
value_items = list(cast(Mapping[_KeyType, _CovariantValueType], value).items())
else:
value_items = list(cast("list[tuple[Any, Any]]", value))

self._dict = {k: v for k, v in _items}
self._list = _items
self._dict = {k: v for k, v in value_items}
self._list = value_items

def getlist(self, key: Any) -> list[_CovariantValueType]:
    """Return every value stored under ``key``, in ``self._list`` order.

    An empty list is returned when no stored pair has an equal key.
    """
    matches = []
    for stored_key, stored_value in self._list:
        if stored_key == key:
            matches.append(stored_value)
    return matches
Expand All @@ -297,7 +335,8 @@ def items(self) -> ItemsView[_KeyType, _CovariantValueType]:
return self._dict.items()

def multi_items(self) -> list[tuple[_KeyType, _CovariantValueType]]:
    """Return all stored ``(key, value)`` pairs, repeated keys included.

    A new list is built on every call, so callers may mutate the result
    without affecting ``self._list``.
    """
    # Deliberately a copy: returning ``self._list`` itself would let callers
    # modify this object's internal state through the returned list.
    return list(self._list)

def __getitem__(self, key: _KeyType) -> _CovariantValueType:
    """Return the value that ``self._dict`` holds for ``key``.

    The lookup is a plain subscript on ``self._dict``, so whatever that
    mapping raises for a missing key propagates to the caller.
    """
    lookup = self._dict
    return lookup[key]
Expand Down Expand Up @@ -372,10 +411,34 @@ def update(
*args: MultiDict | Mapping[Any, Any] | list[tuple[Any, Any]],
**kwargs: Any,
) -> None:
value = MultiDict(*args, **kwargs)
existing_items = [(k, v) for (k, v) in self._list if k not in value.keys()]
self._list = existing_items + value.multi_items()
self._dict.update(value)
# Optimize by combining positional and keyword updates up front,
# and using sets for faster membership tests for existing keys.

if args or kwargs:
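# Flatten every positional argument into one list of (key, value) pairs so that
# a single MultiDict is built from them instead of one per argument.
# NOTE(review): only MultiDict instances go through multi_items() below. Any other
# object exposing multi_items() (presumably e.g. QueryParams) falls through to
# .items() - confirm repeated keys are not dropped for such arguments.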
flat_items = []
for arg in args:
if isinstance(arg, MultiDict):
flat_items.extend(arg.multi_items())
elif hasattr(arg, "items"):
flat_items.extend(arg.items())
else:
flat_items.extend(list(arg))
if kwargs:
flat_items.extend(kwargs.items())

value = MultiDict(flat_items)
else:
value = MultiDict()

# Optimize 'existing_items' computation using a set for O(1) lookup for keys instead of repeated scanning.
value_keys = set(value._dict)
existing_items = [item for item in self._list if item[0] not in value_keys]

# Avoid building new lists more than necessary
self._list = existing_items + value._list
self._dict.update(value._dict)


class QueryParams(ImmutableMultiDict[str, str]):
Expand Down