Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 22 additions & 28 deletions src/inference_endpoint/endpoint_client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -674,7 +674,12 @@ class HttpRequestTemplate:
cached_headers: Pre-encoded headers from cache_headers(), included in every request.
"""

__slots__ = ("static_prefix", "cached_headers")
__slots__ = (
"static_prefix",
"cached_headers",
"_prefix_streaming",
"_prefix_non_streaming",
)

# Pre-encoded general headers
HEADERS_STREAMING = (
Expand All @@ -687,6 +692,7 @@ class HttpRequestTemplate:
def __init__(self, static_prefix: bytes):
    """Create a template around a fixed, pre-encoded request prefix.

    Args:
        static_prefix: Pre-encoded request line plus Host header bytes
            (see from_url() for the canonical constructor).
    """
    self.static_prefix = static_prefix
    # Extra header lines accumulated by cache_headers(); starts empty.
    self.cached_headers = b""
    # Prime the merged streaming/non-streaming prefixes immediately so
    # build_request() is usable even if cache_headers() is never called.
    self._rebuild_prefixes()

@classmethod
def from_url(cls, host: str, port: int, path: str = "/") -> HttpRequestTemplate:
Expand Down Expand Up @@ -714,15 +720,21 @@ def from_url(cls, host: str, port: int, path: str = "/") -> HttpRequestTemplate:

return cls(static_prefix=(request_line + host_header).encode("ascii"))

def _rebuild_prefixes(self) -> None:
    """Recompute the two ready-to-use request prefixes.

    Each prefix is static_prefix + cached_headers + the matching
    pre-encoded content-type headers, so build_request() only has to
    append Content-Length and the body at call time.
    """
    common = b"".join((self.static_prefix, self.cached_headers))
    self._prefix_streaming = common + self.HEADERS_STREAMING
    self._prefix_non_streaming = common + self.HEADERS_NON_STREAMING

def cache_headers(self, headers: dict[str, str]) -> None:
    """Pre-encode headers that repeat on every request.

    Call this during setup so build_request() only needs body + content_length
    at runtime.

    Args:
        headers: Headers to pre-encode and merge into the request prefix
    """
    header_lines = [f"{name}: {value}\r\n" for name, value in headers.items()]
    encoded = "".join(header_lines).encode("utf-8", "surrogateescape")
    # Substring containment is a safe dedup check here: cached_headers only
    # ever holds full header lines (e.g. "Authorization: Bearer ...\r\n"),
    # never arbitrary fragments.
    if encoded not in self.cached_headers:
        self.cached_headers += encoded
    self._rebuild_prefixes()

def build_request(
self,
Expand All @@ -749,38 +762,19 @@ def build_request(
Returns:
Complete HTTP request in bytes.
"""
content_type_headers = (
self.HEADERS_STREAMING if streaming else self.HEADERS_NON_STREAMING
)
prefix = self._prefix_streaming if streaming else self._prefix_non_streaming
content_length = f"Content-Length: {len(body)}\r\n\r\n".encode("ascii")

# Fast path: no extra headers
# Fast path: only body + content_length vary per request
if not extra_headers:
return b"".join(
[
self.static_prefix,
self.cached_headers,
content_type_headers,
content_length,
body,
]
)
return b"".join([prefix, content_length, body])

# Slow path: extra headers are encoded per-call;
# use cache_headers() at setup time for headers that repeat every request.
extra = "".join(f"{k}: {v}\r\n" for k, v in extra_headers.items()).encode(
"utf-8", "surrogateescape"
)
return b"".join(
[
self.static_prefix,
self.cached_headers,
content_type_headers,
extra,
content_length,
body,
]
)
return b"".join([prefix, extra, content_length, body])


@dataclass(slots=True)
Expand Down