Skip to content
6 changes: 3 additions & 3 deletions examples/01_LocalBenchmark/run_tinyllm.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@
import argparse
import threading

import inference_endpoint.rulesets.mlcommons.models as mlcommons_models
import inference_endpoint.config.rulesets.mlcommons.models as mlcommons_models
from inference_endpoint.config.rulesets.mlcommons.rules import CURRENT
from inference_endpoint.config.user_config import UserConfig
from inference_endpoint.core.types import QueryResult, StreamChunk
from inference_endpoint.dataset_manager.dataloader import DataLoader
Expand All @@ -30,7 +31,6 @@
SampleIssuer,
WithoutReplacementSampleOrder,
)
from inference_endpoint.rulesets.mlcommons.rules import CURRENT
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
from transformers.utils import logging
Expand Down Expand Up @@ -167,7 +167,7 @@ def issue(self, sample):
)
SampleEventHandler.stream_chunk_complete(stream_chunk)
first = False
query_result = QueryResult(id=sample.uuid, response_output="".join(chunks))
query_result = QueryResult(id=sample.uuid, response_output=chunks)
else:
response = self.compute_func(sample.data)
query_result = QueryResult(id=sample.uuid, response_output=response)
Expand Down
16 changes: 14 additions & 2 deletions src/inference_endpoint/commands/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from pathlib import Path
from urllib.parse import urljoin

from tqdm import tqdm
from transformers import AutoTokenizer
from transformers.utils import logging as transformers_logging

Expand Down Expand Up @@ -100,7 +101,7 @@ class ResponseCollector:
count: Total number of completed queries (success + failure).
"""

def __init__(self, collect_responses: bool = False):
def __init__(self, collect_responses: bool = False, pbar: tqdm | None = None):
"""Initialize response collector.

Args:
Expand All @@ -112,6 +113,8 @@ def __init__(self, collect_responses: bool = False):
self.errors: list[str] = []
self.count = 0

self.pbar = pbar

def on_complete_hook(self, result: QueryResult):
"""Callback invoked when a query completes (success or failure).

Expand All @@ -128,6 +131,9 @@ def on_complete_hook(self, result: QueryResult):
elif self.collect_responses:
self.responses[result.id] = result.response_output

if self.pbar:
self.pbar.update(1)


async def run_benchmark_command(args: argparse.Namespace) -> None:
"""Run performance benchmark in offline, online, or YAML-configured mode.
Expand Down Expand Up @@ -531,7 +537,12 @@ def _run_benchmark(
raise SetupError(str(e)) from e

# Setup response collector
response_collector = ResponseCollector(collect_responses=collect_responses)
pbar = tqdm(
desc=f"{model_name} (Streaming: {enable_streaming})", total=total_samples
)
response_collector = ResponseCollector(
collect_responses=collect_responses, pbar=pbar
)
SampleEventHandler.register_hook(
SampleEvent.COMPLETE, response_collector.on_complete_hook
)
Expand Down Expand Up @@ -669,6 +680,7 @@ def signal_handler(signum, frame):
# Cleanup - always execute
logger.info("Cleaning up...")
try:
pbar.close()
sample_issuer.shutdown()
http_client.shutdown()
shutil.rmtree(tmp_dir, ignore_errors=True)
Expand Down
16 changes: 14 additions & 2 deletions src/inference_endpoint/core/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,18 +89,23 @@ class QueryResult(msgspec.Struct, tag="query_result", kw_only=True, frozen=True)
Attributes:
id: Query identifier (matches the originating Query.id).
response_output: Generated text response from the endpoint (None if error).
Can be a string, or a tuple of strings. If it is a string,
it is assumed to be a non-streaming response. If it is a
tuple of strings, it is assumed to be a streamed response,
where the first element is the first chunk, which will not
be included in the TPOT measurements.
metadata: Additional response metadata (token counts, model info, etc.).
error: Error message if query failed (None if successful).
completed_at: High-resolution timestamp (nanoseconds, monotonic clock).
Auto-set in __post_init__ to prevent tampering.
Auto-set in __post_init__ to prevent tampering.

Note:
The completed_at field is intentionally set internally to prevent
benchmark result manipulation. Users must not override this timestamp.
"""

id: str = ""
response_output: str | None = None
response_output: str | tuple[str, ...] | None = None
metadata: dict[str, Any] = msgspec.field(default_factory=dict)
error: str | None = None
completed_at: float = msgspec.UNSET
Expand All @@ -119,6 +124,13 @@ def __post_init__(self):
# Timestamp must be generated internally
msgspec.structs.force_setattr(self, "completed_at", time.monotonic_ns())

# A list can be passed on, but we need to convert it to a tuple to maintain immutability,
# and for serialization to work properly.
if isinstance(self.response_output, list):
msgspec.structs.force_setattr(
self, "response_output", tuple(self.response_output)
)


class StreamChunk(msgspec.Struct, tag="stream_chunk", kw_only=True):
"""A single chunk from a streaming inference response.
Expand Down
9 changes: 7 additions & 2 deletions src/inference_endpoint/endpoint_client/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,11 +394,16 @@ async def _handle_streaming_request(self, query: Query) -> None:
first_chunk_sent = True

# Send final complete response
final_output = "".join(accumulated_content)
response_output = []
if accumulated_content:
response_output.append(accumulated_content[0])
if len(accumulated_content) > 1:
response_output.append("".join(accumulated_content[1:]))

await self._response_socket.send(
QueryResult(
id=query.id,
response_output=final_output,
response_output=response_output,
metadata={"first_chunk": not first_chunk_sent, "final_chunk": True},
)
)
Expand Down
3 changes: 1 addition & 2 deletions src/inference_endpoint/load_generator/events.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class SessionEvent(Event):
LOADGEN_ISSUE_CALLED = "loadgen_issue_called"
LOADGEN_STOP = "loadgen_stop"
LOADGEN_DATA_LOAD = "loadgen_data_load"

# TODO: Add an event to record errors occurring
ERROR = "error"


class SampleEvent(Event):
Expand Down
5 changes: 4 additions & 1 deletion src/inference_endpoint/load_generator/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from typing import Any

from ..core.types import QueryResult, StreamChunk
from ..metrics.recorder import EventRecorder
from ..metrics.recorder import EventRecorder, record_exception
from .events import SampleEvent

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -149,6 +149,7 @@ def stream_chunk_complete(self, chunk: StreamChunk) -> None:
SampleEvent.FIRST_CHUNK,
timestamp_ns,
sample_uuid=chunk.id,
output=chunk.response_chunk,
)
hooks = self.first_chunk_hooks
else:
Expand Down Expand Up @@ -179,6 +180,8 @@ def query_result_complete(self, result: QueryResult) -> None:
if result.error is not None:
logger.error(f"Error in request {result.id}: {result.error}")

record_exception(result.error, result.id)

EventRecorder.record_event(
SampleEvent.COMPLETE,
timestamp_ns,
Expand Down
59 changes: 52 additions & 7 deletions src/inference_endpoint/metrics/recorder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import shutil
import sqlite3
import threading
import time
import uuid
from functools import partial
from pathlib import Path
Expand Down Expand Up @@ -265,14 +266,34 @@ def commit_buffer():
should_commit = True
else:
# Regular event - add to buffer
# Format: (sample_uuid, event_type, timestamp_ns, output)
event_buffer.append(item[:-1])
if item[-1] is not None and item[1] == SampleEvent.COMPLETE.value:
output_buffer.append(
{
"s_uuid": item[0],
"output": item[-1],
}
)
if item[-1] is not None:
if item[1] == SampleEvent.FIRST_CHUNK.value:
# In post-processing, we use this to validate that the first chunk of the response output is the same as the data in the FIRST_CHUNK_RECEIVED event
output_buffer.append(
{"s_uuid": item[0], "first_chunk": item[-1]}
)
elif item[1] == SampleEvent.COMPLETE.value:
output_data = item[-1]
if not isinstance(output_data, list | tuple | str):
raise TypeError(
f"QueryResult.response_output should be a list or tuple or str, but got {type(output_data)}"
)
output_buffer.append(
{
"s_uuid": item[0],
"output": output_data,
}
)
elif item[1] == SessionEvent.ERROR.value:
output_buffer.append(
{
"s_uuid": item[0],
"error_type": item[1],
"error_message": item[-1],
}
)
should_commit = len(event_buffer) >= self.txn_buffer_size

# Commit if buffer is full
Expand Down Expand Up @@ -435,3 +456,27 @@ def __enter__(self):
def __exit__(self, exc_type, exc_value, traceback):
"""Context manager exit - stops the writer thread."""
self.close()


def record_exception(
    exc_value: Exception | str,
    sample_uuid: str | None = None,
):
    """Record an exception as a SessionEvent.ERROR event on the live recorder.

    The event buffer is flushed immediately (force_commit=True) so the error
    is surfaced as soon as possible for any monitoring.

    Args:
        exc_value: The exception to record, or a string error message.
        sample_uuid: The sample uuid to record the error for.
    """
    # No live recorder means there is nowhere to record — silently skip.
    if EventRecorder.LIVE is not None:
        EventRecorder.record_event(
            SessionEvent.ERROR,
            time.monotonic_ns(),
            sample_uuid=sample_uuid,
            output=str(exc_value),
            force_commit=True,
        )
Loading
Loading