534 changes: 534 additions & 0 deletions .cursor/rules/msgspec-patterns.mdc

Large diffs are not rendered by default.

59 changes: 56 additions & 3 deletions src/inference_endpoint/core/types.py
@@ -52,7 +52,14 @@ class QueryStatus(Enum):
_OUTPUT_RESULT_TYPE = str | tuple[str, ...] | _OUTPUT_DICT_TYPE | None


-class Query(msgspec.Struct, kw_only=True): # type: ignore[call-arg]
+class Query(
+    msgspec.Struct,
+    frozen=True,
+    kw_only=True,
+    array_like=True,
+    omit_defaults=True,
+    gc=False,
+): # type: ignore[call-arg]
"""Represents a single inference query to be sent to an endpoint.

A Query encapsulates all information needed to make an HTTP request to
@@ -72,6 +79,17 @@ class Query(msgspec.Struct, kw_only=True): # type: ignore[call-arg]
... data={"prompt": "Hello", "model": "Qwen/Qwen3-8B", "max_tokens": 100},
... headers={"Authorization": "Bearer token123"},
... )

+    Note:
+        gc=False: Safe because data/headers are simple key-value pairs without cycles.
+        Do NOT store self-referential or cyclic structures in the data/headers fields.
+
+        array_like=True: Encodes as an array instead of an object (e.g., ["id", {...}, {...}, 0.0]
+        instead of {"id": ..., "data": ..., ...}). Provides a ~6-50% size reduction and a
+        ~6-29% serialization/deserialization speedup for ZMQ transport, depending on payload size.
+
+        omit_defaults=True: Fields with default values are omitted during encoding,
+        further reducing message size for queries with empty headers.
"""

id: str = msgspec.field(default_factory=lambda: str(uuid.uuid4()))
@@ -80,7 +98,15 @@ class Query(msgspec.Struct, kw_only=True): # type: ignore[call-arg]
created_at: float = msgspec.field(default_factory=time.time)


-class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True): # type: ignore[call-arg]
+class QueryResult(
+    msgspec.Struct,
+    tag="query_result",
+    kw_only=True,
+    frozen=True,
+    array_like=True,
+    omit_defaults=True,
+    gc=False,
+): # type: ignore[call-arg]
"""Result of a completed inference query.

Represents the outcome of processing a Query, including the response text,
@@ -106,6 +132,15 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True): # type: ignore[call-arg]
Note:
The completed_at field is intentionally set internally to prevent
benchmark result manipulation. Users must not override this timestamp.

+        gc=False: Safe because metadata contains only scalar key-value pairs.
+        Do NOT store cyclic references in the metadata or response_output fields.
+
+        omit_defaults=True: Fields with static defaults (i.e., those NOT using
+        default_factory) are omitted if the value equals the default.
+
+        array_like=True: Encodes as an array instead of an object (e.g., ["id", "chunk", false, {}]
+        instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size.
"""

id: str = ""
@@ -154,7 +189,15 @@ def get_response_output_string(self) -> str:
return "<EMPTY>"


-class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True): # type: ignore[call-arg]
+class StreamChunk(
+    msgspec.Struct,
+    tag="stream_chunk",
+    frozen=True,
+    kw_only=True,
+    array_like=True,
+    omit_defaults=True,
+    gc=False,
+): # type: ignore[call-arg]
"""A single chunk from a streaming inference response.

Streaming responses are sent incrementally as the model generates text.
@@ -174,6 +217,16 @@ class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True): # type: ignore[call-arg]
Streaming "Hello World" might produce:
>>> StreamChunk(id="q1", response_chunk="Hello", is_complete=False)
>>> StreamChunk(id="q1", response_chunk=" World", is_complete=True)

+    Note:
+        gc=False: Safe because metadata contains only scalar key-value pairs.
+        Do NOT store cyclic references in the metadata field.
+
+        omit_defaults=True: Fields with static defaults (i.e., those NOT using
+        default_factory) are omitted if the value equals the default.
+
+        array_like=True: Encodes as an array instead of an object (e.g., ["id", "chunk", false, {}]
+        instead of {"id": ..., "response_chunk": ..., ...}). Reduces payload size.
"""

id: str = ""
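The docstring notes above describe the wire-format effects of `array_like=True` and `omit_defaults=True`. Here is a minimal, self-contained sketch of both behaviors, plus the `tag` handling — `Chunk` below is a hypothetical stand-in that only mirrors `StreamChunk`'s shape, not the class from this diff:

```python
import msgspec


class Chunk(
    msgspec.Struct,
    tag="stream_chunk",
    frozen=True,
    kw_only=True,
    array_like=True,
    omit_defaults=True,
    gc=False,
):
    id: str = ""
    response_chunk: str = ""
    is_complete: bool = False


# array_like=True encodes fields positionally, with the tag as the first element:
print(msgspec.json.encode(Chunk(id="q1", response_chunk="Hi", is_complete=True)))
# b'["stream_chunk","q1","Hi",true]'

# omit_defaults=True drops trailing fields that equal their defaults:
print(msgspec.json.encode(Chunk(id="q1")))
# b'["stream_chunk","q1"]'

# Decoding restores the omitted defaults:
print(msgspec.json.decode(b'["stream_chunk","q1"]', type=Chunk))
# Chunk(id='q1', response_chunk='', is_complete=False)
```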
53 changes: 40 additions & 13 deletions src/inference_endpoint/openai/types.py
@@ -24,40 +24,55 @@
# ============================================================================


-class SSEDelta(msgspec.Struct):
+# NOTE(vir): msgspec usage
+# omit_defaults=True: Fields with static defaults (i.e., those NOT using default_factory)
+#   are omitted if the value equals the default.
+# gc=False: Safe for request/response structs with scalar and nested struct fields only.
+# frozen=True: Makes structs immutable and hashable; also enables faster struct decoding
+#   (direct attribute access via fixed memory offset vs hash table lookup).
+
+
+class SSEDelta(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False): # type: ignore[call-arg]
"""SSE delta object containing content."""

content: str = ""
reasoning: str = ""


-class SSEChoice(msgspec.Struct):
+class SSEChoice(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""SSE choice object containing delta."""

-    delta: SSEDelta = msgspec.field(default_factory=SSEDelta)
+    delta: SSEDelta | None = None
finish_reason: str | None = None


-class SSEMessage(msgspec.Struct):
+class SSEMessage(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""SSE message structure for OpenAI streaming responses."""

-    choices: list[SSEChoice] = msgspec.field(default_factory=list)
+    choices: tuple[SSEChoice, ...] = ()


# ============================================================================
-# OpenAI Chat Completion Types (msgspec-based)
+# OpenAI Chat Completion Types
# ============================================================================


-class ChatMessage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class ChatMessage(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""Chat message in OpenAI format."""

role: str
content: str
name: str | None = None


-class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class ChatCompletionRequest(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""OpenAI chat completion request."""

model: str
@@ -77,32 +92,44 @@ class ChatCompletionRequest(msgspec.Struct, kw_only=True, omit_defaults=True):
chat_template: str | None = None


-class ChatCompletionResponseMessage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class ChatCompletionResponseMessage(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""Response message from OpenAI."""

role: str
content: str | None
refusal: str | None


-class ChatCompletionChoice(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class ChatCompletionChoice(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""A single choice in the completion response."""

index: int
message: ChatCompletionResponseMessage
finish_reason: str | None


-class CompletionUsage(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class CompletionUsage(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
"""Token usage statistics."""

prompt_tokens: int
completion_tokens: int
total_tokens: int


-class ChatCompletionResponse(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
-    """OpenAI chat completion response (msgspec version)."""
+class ChatCompletionResponse(
+    msgspec.Struct,
+    frozen=True,
+    kw_only=True,
+    omit_defaults=False,
+    gc=False,
+): # type: ignore[call-arg]
+    """OpenAI chat completion response."""

id: str
object: str = "chat.completion"
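A short sketch of what `frozen=True` adds for these request/response structs: immutability and hashability. `Msg` below is a hypothetical stand-in, not one of the classes in this diff:

```python
import msgspec


class Msg(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False):
    role: str
    content: str = ""


m = msgspec.json.decode(b'{"role":"user","content":"hi"}', type=Msg)

# frozen=True: mutation raises instead of silently changing shared state.
try:
    m.content = "bye"
except AttributeError as exc:
    print(exc)

# Frozen structs are hashable, so they work as dict keys and set members.
cache = {m: "cached"}

# omit_defaults=True: fields equal to their static defaults are skipped on encode.
print(msgspec.json.encode(Msg(role="user")))  # b'{"role":"user"}'
```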
20 changes: 14 additions & 6 deletions src/inference_endpoint/sglang/types.py
@@ -26,7 +26,9 @@
# ============================================================================


-class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class SamplingParams(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
max_new_tokens: int = 32768
"""int: Maximum number of tokens to generate per request (1-32768)"""

@@ -37,16 +39,18 @@ class SamplingParams(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
"""int: Top-k sampling (number of highest probability tokens to consider). -1 = disable"""

top_p: float = 1.0
"""float: Top-p/nucleus sampling (cumulative probability threshold). 0.0-1.0, typically 1.0 for no filterin"""
"""float: Top-p/nucleus sampling (cumulative probability threshold). 0.0-1.0, typically 1.0 for no filtering"""


-class SGLangGenerateRequest(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class SGLangGenerateRequest(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
input_ids: list[int]
sampling_params: SamplingParams
stream: bool


-class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class MetaInfo(msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False): # type: ignore[call-arg]
id: str
finish_reason: dict[str, Any]
prompt_tokens: int
@@ -57,13 +61,17 @@ class MetaInfo(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
e2e_latency: float


-class SGLangGenerateResponse(msgspec.Struct, kw_only=True, omit_defaults=True): # type: ignore[call-arg]
+class SGLangGenerateResponse(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
text: str
output_ids: list[int]
meta_info: MetaInfo


-class SGLangSSEDelta(msgspec.Struct):
+class SGLangSSEDelta(
+    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
+): # type: ignore[call-arg]
text: str = ""
token_delta: list[int] = msgspec.field(default_factory=list)
total_completion_tokens: int = 0
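For the SGLang request types, `omit_defaults=True` means only the sampling parameters a caller actually overrides reach the wire. A sketch using the defaults shown in the diff (`top_k` omitted here for brevity):

```python
import msgspec


class SamplingParams(
    msgspec.Struct, frozen=True, kw_only=True, omit_defaults=True, gc=False
):
    max_new_tokens: int = 32768
    temperature: float = 1.0
    top_p: float = 1.0


# Only the overridden field is serialized:
print(msgspec.json.encode(SamplingParams(temperature=0.7)))
# b'{"temperature":0.7}'

# All defaults collapse to an empty object:
print(msgspec.json.encode(SamplingParams()))
# b'{}'
```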