Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@ dependencies = [
"websocket-client==1.9.0",
# Data handling
"duckdb==1.4.0",
"orjson==3.11.5",
"msgspec==0.20.0",
"pydantic==2.12.0",
"pydantic_core==2.41.1",
Expand Down
6 changes: 2 additions & 4 deletions src/inference_endpoint/endpoint_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,13 @@

"""
Endpoint Client for the MLPerf Inference Endpoint Benchmarking System.

This module provides HTTP client implementation with multiprocessing and ZMQ.
This module provides HTTP client implementation.
"""

from .config import HTTPClientConfig
from .http_client import AsyncHttpEndpointClient, HTTPEndpointClient
from .http_client import HTTPEndpointClient

__all__ = [
"AsyncHttpEndpointClient",
"HTTPEndpointClient",
"HTTPClientConfig",
]
9 changes: 6 additions & 3 deletions src/inference_endpoint/endpoint_client/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,13 +117,16 @@ class HTTPClientConfig:
worker_gc_mode: Literal["disabled", "relaxed", "system"] = "relaxed"

# Request adapter for Query/Response <-> Payload/Response bytes
adapter: type[HttpRequestAdapter] | None = None # None: use default
# Default in __post_init__ if None
adapter: type[HttpRequestAdapter] = None # type: ignore[assignment]

# SSE accumulator for streaming responses
accumulator: type[SSEAccumulatorProtocol] | None = None # None: use default
# Default in __post_init__ if None
accumulator: type[SSEAccumulatorProtocol] = None # type: ignore[assignment]

# Worker pool transport class for worker IPC
worker_pool_transport: type[WorkerPoolTransport] | None = None # None: use default
# Default in __post_init__ if None
worker_pool_transport: type[WorkerPoolTransport] = None # type: ignore[assignment]

def __post_init__(self):
# set default adapter in __post_init__ to avoid circular dependency
Expand Down
9 changes: 6 additions & 3 deletions src/inference_endpoint/endpoint_client/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -562,7 +562,10 @@ def protocol_factory() -> HttpResponseProtocol:
self._creating -= 1

def release(self, conn: PooledConnection) -> None:
"""Return connection to pool for reuse and notify waiters."""
"""Return connection to pool for reuse and notify waiters (idempotent)."""
if not conn.in_use:
return

# Must close if: dead, server requested close, or error occurred
if not conn.is_alive() or conn.protocol.should_close:
self._close_connection(conn)
Expand Down Expand Up @@ -780,10 +783,10 @@ class InFlightRequest:
query_id: Correlates response back to original Query.
http_bytes: Serialized HTTP request for socket.write().
is_streaming: Whether this is a streaming (SSE) request or not.
connection: PooledConnection if any assigned to this request.
connection: PooledConnection assigned to this request (set once request is fired).
"""

query_id: str
http_bytes: bytes
is_streaming: bool
connection: PooledConnection | None = field(default=None, repr=False)
connection: PooledConnection = field(default=None, repr=False) # type: ignore[assignment]
59 changes: 20 additions & 39 deletions src/inference_endpoint/endpoint_client/http_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@
logger = logging.getLogger(__name__)


class AsyncHttpEndpointClient:
class HTTPEndpointClient:
"""
Async HTTP client for LLM inference.
HTTP client for LLM inference.

Architecture:
- Main process: Accepts requests, distributes to workers, handles responses
Expand All @@ -45,10 +45,12 @@ class AsyncHttpEndpointClient:

Usage:
with ManagedZMQContext.scoped() as zmq_ctx:
client = AsyncHttpEndpointClient(config, zmq_context=zmq_ctx)
client = HTTPEndpointClient(config, zmq_context=zmq_ctx)
client.issue(query)
response = await client.recv()
await client.shutdown()
response = client.poll() # Non-blocking, returns None if nothing ready
responses = client.drain() # Drain all available responses
# response = await client.recv() # Blocking; only if caller provides its own loop
client.shutdown() # Blocks until workers stop
"""

def __init__(
Expand All @@ -60,6 +62,8 @@ def __init__(
self.client_id = uuid.uuid4().hex[:8]
self.config = config
self._worker_cycle = cycle(range(self.config.num_workers))

# TODO(vir): make context setup/teardown part of transport protocol
if config.worker_pool_transport is ZmqWorkerPoolTransport:
if zmq_context is None:
raise ValueError(
Expand All @@ -85,9 +89,6 @@ def __init__(
# Initialize on event loop
asyncio.run_coroutine_threadsafe(self._initialize(), self.loop).result()

assert self.config.adapter is not None
assert self.config.accumulator is not None
assert self.config.worker_pool_transport is not None
logger.info(
f"EndpointClient initialized with num_workers={self.config.num_workers}, "
f"endpoints={self.config.endpoint_urls}, "
Expand Down Expand Up @@ -131,16 +132,20 @@ async def recv(self) -> QueryResult | StreamChunk | None:

def drain(self) -> list[QueryResult | StreamChunk]:
"""Non-blocking. Returns all available responses."""
results: list[QueryResult | StreamChunk] = []
while (r := self.poll()) is not None:
results.append(r)
return results
return list(iter(self.poll, None))

async def shutdown(self) -> None:
"""Gracefully shutdown client."""
logger.info(f"[{self.client_id}] Shutting down...")
def shutdown(self) -> None:
"""Gracefully shutdown client. Synchronous — blocks the caller until complete."""
if self._shutdown: # Already shutdown, no-op
return
asyncio.run_coroutine_threadsafe(self._shutdown_async(), self.loop).result()

async def _shutdown_async(self) -> None:
"""Async shutdown internals - must be called on the event loop."""
self._shutdown = True

logger.info(f"[{self.client_id}] Shutting down...")

# Shutdown workers
await self.worker_manager.shutdown()

Expand All @@ -154,27 +159,3 @@ async def shutdown(self) -> None:
f"[{self.client_id}] Dropped {self._dropped_requests} requests during shutdown"
)
logger.info(f"[{self.client_id}] Shutdown complete.")


class HTTPEndpointClient(AsyncHttpEndpointClient):
"""
Sync HTTP client for LLM inference.
Inherits from AsyncHttpEndpointClient and provides sync interface.

Usage:
client = HTTPEndpointClient(config)
client.issue(query)
"""

def issue(self, query: Query) -> None: # type: ignore[override]
"""Issue query."""
# Schedule on event loop thread
assert self.loop is not None
self.loop.call_soon_threadsafe(
lambda: super(HTTPEndpointClient, self).issue(query)
)

def shutdown(self) -> None: # type: ignore[override]
"""Sync shutdown wrapper - blocks until base class async shutdown completes."""
assert self.loop is not None
asyncio.run_coroutine_threadsafe(super().shutdown(), self.loop).result()
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ def __init__(
self.http_client = http_client

# Start response handler task to route completed responses back to SampleEventHandler
assert self.http_client.loop is not None
self._response_task = asyncio.run_coroutine_threadsafe(
self._handle_responses(), self.http_client.loop
)
Expand Down
70 changes: 24 additions & 46 deletions src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,20 +159,19 @@ def __init__(
if self._scheme == "https":
self._ssl_context = ssl.create_default_context()

# HTTP components
self._pool: ConnectionPool | None = None
self._http_template: HttpRequestTemplate | None = None
self._loop: asyncio.AbstractEventLoop | None = None
# HTTP components (initialized in run())
self._pool: ConnectionPool = None # type: ignore[assignment]
self._http_template: HttpRequestTemplate = None # type: ignore[assignment]
self._loop: asyncio.AbstractEventLoop = None # type: ignore[assignment]

# IPC transports
self._requests: ReceiverTransport | None = None
self._responses: SenderTransport | None = None
# IPC transports (initialized in run())
self._requests: ReceiverTransport = None # type: ignore[assignment]
self._responses: SenderTransport = None # type: ignore[assignment]

# Track active request tasks
self._active_tasks: set[asyncio.Task] = set()

# Use adapter type from config
assert self.http_config.adapter is not None
self._adapter: type[HttpRequestAdapter] = self.http_config.adapter

async def run(self) -> None:
Expand All @@ -184,7 +183,6 @@ async def run(self) -> None:
# Use eager task factory for immediate coroutine execution
# Tasks start executing synchronously until first await
# NOTE(vir): CRITICAL for minimizing TFB/TTFT
assert self._loop is not None
self._loop.set_task_factory(asyncio.eager_task_factory) # type: ignore[arg-type]

# Initialize HTTP template from URL components
Expand Down Expand Up @@ -267,7 +265,9 @@ async def run(self) -> None:
if self.http_config.record_worker_events:
pid = os.getpid()
worker_db_name = f"worker_report_{self.worker_id}_{pid}"
assert self.http_config.event_logs_dir is not None
assert (
self.http_config.event_logs_dir is not None
), "event_logs_dir must be set if record_worker_events is enabled"
report_path = self.http_config.event_logs_dir / f"{worker_db_name}.csv"

with EventRecorder(session_id=worker_db_name) as event_recorder:
Expand Down Expand Up @@ -327,16 +327,13 @@ async def _run_main_loop(self) -> None:
assert_active=True,
)

# Prepare request
prepared = self._prepare_request(query)

# Fire request
if not await self._fire_request(prepared):
# Prepare and fire request
req = self._prepare_request(query)
if not await self._fire_request(req):
continue

# Process response asynchronously
assert self._loop is not None
task = self._loop.create_task(self._process_response(prepared))
task = self._loop.create_task(self._process_response(req))

# Keep task alive to prevent GC
# Cleaned up in _process_response finally block
Expand All @@ -359,7 +356,6 @@ def _prepare_request(self, query: Query) -> InFlightRequest:
is_streaming = query.data.get("stream", False)

# Build complete HTTP request bytes
assert self._http_template is not None
http_bytes = self._http_template.build_request(
body_bytes,
is_streaming,
Expand All @@ -381,23 +377,21 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
Fire HTTP POST request:
1. Acquire TCP connection from pool
2. Send POST request bytes
3. Store connection for process_response task

Returns True on success.
Returns True on success, False on failure (error response sent).
"""
if self._shutdown:
await self._handle_error(req.query_id, "Worker is shutting down")
return False

try:
# Acquire connection from pool
assert self._pool is not None
conn = await self._pool.acquire()

# Write request bytes directly to transport
conn.protocol.write(req.http_bytes)

# Store connection for _process_response to use
# Store connection on req for response processing
req.connection = conn

return True
Expand All @@ -410,18 +404,14 @@ async def _fire_request(self, req: InFlightRequest) -> bool:
@profile
async def _process_response(self, req: InFlightRequest) -> None:
"""Process response for a fired request."""
try:
conn = req.connection
assert conn is not None, "Connection should be set by _fire_request"
conn = req.connection

try:
# Await headers and handle error status
status_code, _ = await conn.protocol.read_headers()
if status_code != 200:
error_body = await conn.protocol.read_body()
# Release connection early - done with socket I/O
assert self._pool is not None
self._pool.release(conn)
req.connection = None
await self._handle_error(
req.query_id,
f"HTTP {status_code}: {error_body.decode('utf-8', errors='replace')}",
Expand All @@ -439,11 +429,8 @@ async def _process_response(self, req: InFlightRequest) -> None:
logger.warning(f"Request {req.query_id} failed: {type(e).__name__}: {e}")

finally:
# Release connection back to pool if not already released
if req.connection:
assert self._pool is not None
self._pool.release(req.connection)
req.connection = None
# Release connection back to pool if not already released (release() is a no-op for a connection that is not in use)
self._pool.release(conn)

# Record completion event
if self.http_config.record_worker_events:
Expand All @@ -462,18 +449,15 @@ async def _process_response(self, req: InFlightRequest) -> None:
@profile
async def _handle_streaming_body(self, req: InFlightRequest) -> None:
"""Handle streaming (SSE) response body."""
conn = req.connection
assert conn is not None
query_id = req.query_id
conn = req.connection

# Create accumulator for streaming response
assert self.http_config.accumulator is not None
accumulator = self.http_config.accumulator(
query_id, self.http_config.stream_all_chunks
)

# Process SSE stream - yields batches of chunks
assert self._responses is not None
async for chunk_batch in self._iter_sse_lines(conn):
for delta in chunk_batch:
if stream_chunk := accumulator.add_chunk(delta):
Expand All @@ -487,10 +471,8 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
assert_active=True,
)

# Release connection early - done with socket I/O
assert self._pool is not None
# Release connection early - done with socket I/O (idempotent)
self._pool.release(conn)
req.connection = None

# Send final complete back to main rank
self._responses.send(accumulator.get_final_output())
Expand All @@ -505,23 +487,19 @@ async def _handle_streaming_body(self, req: InFlightRequest) -> None:
@profile
async def _handle_non_streaming_body(self, req: InFlightRequest) -> None:
"""Handle non-streaming response body."""
conn = req.connection
assert conn is not None
query_id = req.query_id
conn = req.connection

# Read entire response body
response_bytes = await conn.protocol.read_body()

# Release connection early - done with socket I/O
assert self._pool is not None
# Release connection early - done with socket I/O (idempotent)
self._pool.release(conn)
req.connection = None

# Decode using adapter
result = self._adapter.decode_response(response_bytes, query_id)

# Send result back to main rank
assert self._responses is not None
self._responses.send(result)
if self.http_config.record_worker_events:
EventRecorder.record_event(
Expand Down
Loading
Loading