Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,9 @@ filterwarnings = [

[tool.coverage.run]
source = ["src"]
concurrency = ["multiprocessing", "thread"]
parallel = true
sigterm = true
omit = [
"*/tests/*",
"*/test_*",
Expand Down
14 changes: 8 additions & 6 deletions src/inference_endpoint/endpoint_client/accumulator_protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,17 +32,19 @@ class SSEAccumulatorProtocol(Protocol):
is disabled, only the first chunk is emitted via add_chunk().
"""

def __init__(self, query_id: str, stream_all_chunks: bool) -> None:
def __init__(
self, query_id: str, stream_all_chunks: bool
) -> None: # pragma: no cover
"""
Initialize the accumulator.

Args:
query_id: Unique identifier for the request being accumulated.
stream_all_chunks: If True, emit all chunks; if False, only first chunk.
"""
pass
...

def add_chunk(self, delta: Any) -> StreamChunk | None:
def add_chunk(self, delta: Any) -> StreamChunk | None: # pragma: no cover
"""
Process an SSE delta and optionally emit a StreamChunk.

Expand All @@ -54,9 +56,9 @@ def add_chunk(self, delta: Any) -> StreamChunk | None:
Returns None for empty deltas, or after first chunk when
stream_all_chunks=False (TTFT-only mode).
"""
pass
...

def get_final_output(self) -> QueryResult:
def get_final_output(self) -> QueryResult: # pragma: no cover
"""
Return the final accumulated result after stream completion.

Expand All @@ -66,4 +68,4 @@ def get_final_output(self) -> QueryResult:
Returns:
QueryResult with the complete response output.
"""
pass
...
74 changes: 41 additions & 33 deletions src/inference_endpoint/endpoint_client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,6 +370,8 @@ class PooledConnection:
"last_used",
"in_use",
"idle_time_on_acquire",
"_fd",
"_stale_poller",
)

def __init__(
Expand All @@ -385,6 +387,11 @@ def __init__(
self.in_use = True
self.idle_time_on_acquire = 0.0

# Cache fd for stale checks — stable for the lifetime of the connection
sock = transport.get_extra_info("socket")
self._fd: int = sock.fileno() if sock is not None else -1
self._stale_poller: select.poll | None = None

def is_alive(self) -> bool:
"""Check if the connection is still usable.

Expand All @@ -405,31 +412,34 @@ def is_stale(self) -> bool:
For idle HTTP keep-alive connections, there should be no pending data.
If the socket is readable, it means the server sent FIN (EOF).

Optimization: Skip check for recently-used connections (< 1 second).
Uses poll() instead of select() to avoid FD_SETSIZE limit on high fds.
Poller is created lazily on first call and reused (fd is stable per connection).
"""
# Skip stale check for recently-used connections
# Server unlikely to close within 1 second of last use
if time.monotonic() - self.last_used < 1.0:
return False

if self.transport is None:
return True
# Fast path: poller already registered from a previous call
if self._stale_poller is not None:
try:
return bool(self._stale_poller.poll(0))
except (OSError, ValueError):
# fd closed or invalid — connection is dead, treat as stale
return True

# Get the socket file descriptor
sock = self.transport.get_extra_info("socket")
if sock is None:
# Slow path: first call — create poller and register fd
if self._fd < 0:
return True

try:
fd = sock.fileno()
if fd < 0:
return True

# Use select with zero timeout - avoids poll() object creation overhead
readable, _, exceptional = select.select([fd], [], [fd], 0)
return bool(readable or exceptional)
poller = select.poll()
poller.register(self._fd, select.POLLIN | select.POLLERR | select.POLLHUP)
self._stale_poller = poller
return bool(poller.poll(0))

except (OSError, ValueError):
# fd closed or invalid — connection is dead, treat as stale
return True


Expand Down Expand Up @@ -660,10 +670,11 @@ class HttpRequestTemplate:
that remain constant across requests to a given endpoint.

Attributes:
static_prefix: Pre-merged request line + host header bytes
static_prefix: Pre-merged request line + host header bytes.
cached_headers: Pre-encoded headers from cache_headers(), included in every request.
"""

__slots__ = ("static_prefix", "_extra_headers_cache", "extra_cached_headers")
__slots__ = ("static_prefix", "cached_headers")

# Pre-encoded general headers
HEADERS_STREAMING = (
Expand All @@ -675,8 +686,7 @@ class HttpRequestTemplate:

def __init__(self, static_prefix: bytes):
self.static_prefix = static_prefix
self._extra_headers_cache: dict[frozenset, bytes] = {}
self.extra_cached_headers = b""
self.cached_headers = b""

@classmethod
def from_url(cls, host: str, port: int, path: str = "/") -> HttpRequestTemplate:
Expand Down Expand Up @@ -714,12 +724,13 @@ def cache_headers(self, headers: dict[str, str]) -> None:
Args:
headers: Headers to pre-encode and cache
"""
cache_key = frozenset(headers.items())
if cache_key not in self._extra_headers_cache:
self._extra_headers_cache[cache_key] = "".join(
f"{k}: {v}\r\n" for k, v in headers.items()
).encode("utf-8", "surrogateescape")
self.extra_cached_headers = b"".join(self._extra_headers_cache.values())
encoded = "".join(f"{k}: {v}\r\n" for k, v in headers.items()).encode(
"utf-8", "surrogateescape"
)
# Substring dedup: safe because this is called once at setup with
# full header lines (e.g. "Authorization: Bearer ...\r\n"), not arbitrary fragments.
if encoded not in self.cached_headers:
self.cached_headers += encoded

def build_request(
self,
Expand Down Expand Up @@ -748,25 +759,22 @@ def build_request(
return b"".join(
[
self.static_prefix,
self.extra_cached_headers,
self.cached_headers,
content_type_headers,
content_length,
body,
]
)

# Slow path: extra headers (~1us uncached, ~50ns per extra-header cached)
cache_key = frozenset(extra_headers.items())
if (extra := self._extra_headers_cache.get(cache_key)) is None:
extra = "".join(f"{k}: {v}\r\n" for k, v in extra_headers.items()).encode(
"utf-8", "surrogateescape"
)
self._extra_headers_cache[cache_key] = extra

# Slow path: extra headers are encoded per-call;
# use cache_headers() at setup time for headers that repeat every request.
extra = "".join(f"{k}: {v}\r\n" for k, v in extra_headers.items()).encode(
"utf-8", "surrogateescape"
)
return b"".join(
[
self.static_prefix,
self.extra_cached_headers,
self.cached_headers,
content_type_headers,
extra,
content_length,
Expand Down
2 changes: 1 addition & 1 deletion src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
# - uvloop requires use of 'spawn'
try:
multiprocessing.set_start_method("spawn", force=False)
except RuntimeError:
except RuntimeError: # pragma: no cover
# Already set, which is fine (likely in tests or when importing multiple times)
pass

Expand Down
12 changes: 4 additions & 8 deletions src/inference_endpoint/utils/benchmark_httpclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,7 +636,7 @@ async def receiver():
# ---------------------------------------------------------------------------


_FULL_WORKERS = [1, 2, 4, 8, 10, 12, 14, 16]
_FULL_WORKERS = [1, 2, 4, 6, 8, 10, 12, 14, 16]
_FULL_PROMPT_LENGTHS = [
1,
32,
Expand Down Expand Up @@ -905,13 +905,9 @@ def run_sweep(
print(f"{'='*70}")

# Restart server when prompt_length or stream_interval changes
if (
server
and args.streaming
and (
prompt_length != last_prompt_length
or stream_interval != last_stream_interval
)
if server and (
prompt_length != last_prompt_length
or stream_interval != last_stream_interval
):
_restart_server(server, prompt_length, args.streaming, stream_interval)
last_prompt_length = prompt_length
Expand Down
Loading