Add tracing to Truss Server, fixes BT-11854 (#1104)
* Add tracing

* Fix tests.

* Debug why streaming hangs

* Fix tests, add truss config option.

* Add legacy export endpoint

* Add issue ID to requirements TODO
marius-baseten authored Sep 3, 2024
1 parent 9a84bba commit 6484697
Showing 22 changed files with 1,166 additions and 414 deletions.
789 changes: 540 additions & 249 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 4 additions & 1 deletion pyproject.toml
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.30rc5"
version = "0.9.30rc701"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
@@ -41,6 +41,9 @@ loguru = ">=0.7.2"
msgpack = ">=1.0.2"
msgpack-numpy = ">=0.4.8"
numpy = ">=1.23.5"
+opentelemetry-api = ">=1.25.0"
+opentelemetry-sdk = ">=1.25.0"
+opentelemetry-exporter-otlp = ">=1.25.0"
packaging = ">=20.9"
pathspec = ">=0.9.0"
psutil = ">=5.9.4"
5 changes: 4 additions & 1 deletion truss/contexts/image_builder/serving_image_builder.py
@@ -459,10 +459,13 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
    # are detected and cause a build failure. If there are no
    # requirements provided, we just pass an empty string,
    # as there's no need to install anything.
+    # TODO (BT-10217): above reasoning leads to inconsistencies. To get consistent
+    # images tentatively add server requirements always. This whole point needs
+    # more thought and potentially a re-design.
    user_provided_python_requirements = (
        base_server_requirements + spec.requirements_txt
        if spec.requirements
-        else ""
+        else base_server_requirements
    )
    if spec.requirements_file is not None:
        copy_into_build_dir(
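
A small sketch of the behavior change above (editor's illustration, not part of this diff; the requirement strings are made up and the variable names stand in for `spec.requirements` / `spec.requirements_txt`):

base_server_requirements = "fastapi==0.114.0\nuvicorn==0.30.0\n"  # assumed contents
user_requirements_txt = ""  # user provided no requirements
user_requirements: list = []

user_provided_python_requirements = (
    base_server_requirements + user_requirements_txt
    if user_requirements
    else base_server_requirements  # before this commit: else ""
)
# Server requirements are now pinned in the image even without user requirements.
assert user_provided_python_requirements == base_server_requirements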
165 changes: 165 additions & 0 deletions truss/templates/server/common/tracing.py
@@ -0,0 +1,165 @@
import contextlib
import json
import logging
import os
import pathlib
import time
from typing import Iterator, List, Optional, Sequence

import opentelemetry.exporter.otlp.proto.http.trace_exporter as oltp_exporter
import opentelemetry.sdk.resources as resources
import opentelemetry.sdk.trace as sdk_trace
import opentelemetry.sdk.trace.export as trace_export
from opentelemetry import context, trace
from shared import secrets_resolver

logger = logging.getLogger(__name__)

ATTR_NAME_DURATION = "duration_sec"
# "New" Jaeger exporter.
OTEL_EXPORTER_OTLP_ENDPOINT = "OTEL_EXPORTER_OTLP_ENDPOINT"
# "Old" Honeycomb exporter. We might want to temporarily export to both, but
# eventually remove this one.
OTEL_EXPORTER_OTLP_ENDPOINT_LEGACY = "OTEL_EXPORTER_OTLP_ENDPOINT_LEGACY"
# Writing trace data to a file is only intended for testing / debugging.
OTEL_TRACING_NDJSON_FILE = "OTEL_TRACING_NDJSON_FILE"
# Exporting trace data to a public honeycomb instance (not our cluster collector)
# is intended only for testing / debugging.
HONEYCOMB_DATASET = "HONEYCOMB_DATASET"
HONEYCOMB_API_KEY = "HONEYCOMB_API_KEY"

DEFAULT_ENABLE_TRACING_DATA = False # This should be in sync with truss_config.py.
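
# Editor's note (illustration, not part of this commit): given the lookup in
# `get_truss_tracer` below, tracing data is switched on via the truss config,
# e.g. in config.yaml:
#
#   runtime:
#     enable_tracing_data: true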


class JSONFileExporter(trace_export.SpanExporter):
    """Writes spans to a newline-delimited JSON file for debugging / testing."""

    def __init__(self, file_path: pathlib.Path):
        self._file = file_path.open("a")

    def export(
        self, spans: Sequence[sdk_trace.ReadableSpan]
    ) -> trace_export.SpanExportResult:
        for span in spans:
            # Get rid of newlines and whitespace.
            self._file.write(json.dumps(json.loads(span.to_json())))
            self._file.write("\n")
        self._file.flush()
        return trace_export.SpanExportResult.SUCCESS

    def shutdown(self) -> None:
        self._file.close()


_truss_tracer: Optional[trace.Tracer] = None


def get_truss_tracer(secrets: secrets_resolver.SecretsResolver, config) -> trace.Tracer:
    """Creates a cached tracer (i.e. runtime-singleton) to be used for truss
    internal tracing.

    The goal is to separate truss-internal tracing instrumentation
    completely from potential user-defined tracing - see also `detach_context` below.
    """
    enable_tracing_data = config.get("runtime", {}).get(
        "enable_tracing_data", DEFAULT_ENABLE_TRACING_DATA
    )

    global _truss_tracer
    if _truss_tracer:
        return _truss_tracer

    span_processors: List[sdk_trace.SpanProcessor] = []
    if otlp_endpoint := os.getenv(OTEL_EXPORTER_OTLP_ENDPOINT):
        logger.info(f"Exporting trace data to `{otlp_endpoint}`.")
        otlp_exporter = oltp_exporter.OTLPSpanExporter(endpoint=otlp_endpoint)
        otlp_processor = sdk_trace.export.BatchSpanProcessor(otlp_exporter)
        span_processors.append(otlp_processor)

    if otlp_endpoint := os.getenv(OTEL_EXPORTER_OTLP_ENDPOINT_LEGACY):
        logger.info(f"Exporting trace data to legacy endpoint `{otlp_endpoint}`.")
        otlp_exporter = oltp_exporter.OTLPSpanExporter(endpoint=otlp_endpoint)
        otlp_processor = sdk_trace.export.BatchSpanProcessor(otlp_exporter)
        span_processors.append(otlp_processor)

    if tracing_log_file := os.getenv(OTEL_TRACING_NDJSON_FILE):
        logger.info(f"Exporting trace data to file `{tracing_log_file}`.")
        json_file_exporter = JSONFileExporter(pathlib.Path(tracing_log_file))
        file_processor = sdk_trace.export.SimpleSpanProcessor(json_file_exporter)
        span_processors.append(file_processor)

    if (
        honeycomb_dataset := os.getenv(HONEYCOMB_DATASET)
    ) and HONEYCOMB_API_KEY in secrets:
        honeycomb_api_key = secrets[HONEYCOMB_API_KEY]
        logger.info("Exporting trace data to honeycomb.")
        honeycomb_exporter = oltp_exporter.OTLPSpanExporter(
            endpoint="https://api.honeycomb.io/v1/traces",
            headers={
                "x-honeycomb-team": honeycomb_api_key,
                "x-honeycomb-dataset": honeycomb_dataset,
            },
        )
        honeycomb_processor = sdk_trace.export.BatchSpanProcessor(honeycomb_exporter)
        span_processors.append(honeycomb_processor)

    if span_processors and enable_tracing_data:
        logger.info("Instantiating truss tracer.")
        resource = resources.Resource.create({resources.SERVICE_NAME: "TrussServer"})
        trace_provider = sdk_trace.TracerProvider(resource=resource)
        for sp in span_processors:
            trace_provider.add_span_processor(sp)
        tracer = trace_provider.get_tracer("truss_server")
    else:
        if enable_tracing_data:
            logger.info(
                "Using no-op tracing (tracing is enabled, but no exporters configured)."
            )
        else:
            logger.info("Using no-op tracing (tracing was disabled).")

        tracer = sdk_trace.NoOpTracer()

    _truss_tracer = tracer
    return _truss_tracer


@contextlib.contextmanager
def detach_context() -> Iterator[trace.Context]:
    """Breaks opentelemetry's context propagation.

    The goal is to separate truss-internal tracing instrumentation completely
    from potential user-defined tracing. Opentelemetry has a global state that
    makes "outer" span-contexts parents of nested spans. If user code in a
    truss model also uses tracing, those traces could easily become polluted
    with our internal contexts. Therefore, all user code (predict and
    pre/post-processing) should be wrapped in this context manager for
    isolation.
    """
    current_context = context.get_current()
    # Create an invalid tracing context. This forces tracing code inside this
    # context manager to create a new root tracing context.
    transient_token = context.attach(trace.set_span_in_context(trace.INVALID_SPAN))
    try:
        yield current_context
    finally:
        # Reattach the original context.
        context.detach(transient_token)
        context.attach(current_context)
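
# Editor's sketch (not part of this commit): the intended pattern is to open a
# truss-internal span and detach before calling user code, so any user-defined
# tracing starts a fresh root context instead of nesting under our spans:
#
#   with tracer.start_as_current_span("model-call") as span:
#       with detach_context():
#           result = user_model.predict(inputs)  # hypothetical user code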


@contextlib.contextmanager
def section_as_event(span: sdk_trace.Span, section_name: str) -> Iterator[None]:
    """Helper to record the start and end of a section as span events, plus
    its duration.

    Note that events are much cheaper to create than dedicated spans.
    """
    t0 = time.time()
    span.add_event(f"start: {section_name}")
    try:
        yield
    finally:
        t1 = time.time()
        span.add_event(
            f"done: {section_name}", attributes={ATTR_NAME_DURATION: t1 - t0}
        )
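
Taken together, a minimal end-to-end sketch of this module (editor's illustration, not part of this diff; the file path, config dict, and payload are made up, and the module's names are assumed to be imported):

import os

os.environ["OTEL_TRACING_NDJSON_FILE"] = "/tmp/truss_spans.ndjson"
config = {"runtime": {"enable_tracing_data": True}}
secrets = {}  # no HONEYCOMB_API_KEY, so the honeycomb exporter stays off

tracer = get_truss_tracer(secrets, config)
with tracer.start_as_current_span("example-request") as span:
    with section_as_event(span, "deserialize"):
        payload = {"inputs": [1, 2, 3]}
    with detach_context():
        # "User code" runs isolated from the span above.
        result = sum(payload["inputs"])
# Each finished span is appended to the NDJSON file by JSONFileExporter.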
98 changes: 60 additions & 38 deletions truss/templates/server/common/truss_server.py
@@ -14,12 +14,16 @@
import common.errors as errors
import shared.util as utils
import uvicorn
+from common import tracing
from common.termination_handler_middleware import TerminationHandlerMiddleware
from fastapi import Depends, FastAPI, HTTPException, Request
from fastapi.responses import ORJSONResponse, StreamingResponse
from fastapi.routing import APIRoute as FastAPIRoute
from model_wrapper import ModelWrapper
+from opentelemetry import propagate as otel_propagate
+from opentelemetry.sdk import trace as sdk_trace
from shared.logging import setup_logging
+from shared.secrets_resolver import SecretsResolver
from shared.serialization import (
    DeepNumpyEncoder,
    truss_msgpack_deserialize,
@@ -81,8 +85,9 @@ class BasetenEndpoints:
    to functions will remain unused except for backwards compatibility checks.
    """

-    def __init__(self, model: ModelWrapper) -> None:
+    def __init__(self, model: ModelWrapper, tracer: sdk_trace.Tracer) -> None:
        self._model = model
+        self._tracer = tracer

    def _safe_lookup_model(self, model_name: str) -> ModelWrapper:
        if model_name != self._model.name:
@@ -130,42 +135,57 @@ async def predict(
        model: ModelWrapper = self._safe_lookup_model(model_name)

        self.check_healthy(model)
-        body: Dict
-        if self.is_binary(request):
-            body = truss_msgpack_deserialize(body_raw)
-        else:
-            try:
-                body = json.loads(body_raw)
-            except json.JSONDecodeError as e:
-                error_message = f"Invalid JSON payload: {str(e)}"
-                logging.error(error_message)
-                raise HTTPException(status_code=400, detail=error_message)
-
-        # calls ModelWrapper.__call__, which runs validate, preprocess, predict, and postprocess
-        response: Union[Dict, Generator] = await model(
-            body,
-            headers=utils.transform_keys(request.headers, lambda key: key.lower()),
-        )
-
-        # In the case that the model returns a Generator object, return a
-        # StreamingResponse instead.
-        if isinstance(response, (AsyncGenerator, Generator)):
-            # media_type in StreamingResponse sets the Content-Type header
-            return StreamingResponse(response, media_type="application/octet-stream")
-
-        response_headers = {}
-        if self.is_binary(request):
-            response_headers["Content-Type"] = "application/octet-stream"
-            return Response(
-                content=truss_msgpack_serialize(response), headers=response_headers
-            )
-        else:
-            response_headers["Content-Type"] = "application/json"
-            return Response(
-                content=json.dumps(response, cls=DeepNumpyEncoder),
-                headers=response_headers,
-            )
+        trace_ctx = otel_propagate.extract(request.headers) or None
+        # This is the top-level span in the truss-server, so we set the context here.
+        # Nested spans "inherit" context automatically.
+        with self._tracer.start_as_current_span(
+            "predict-endpoint", context=trace_ctx
+        ) as span:
+            body: Dict
+            if self.is_binary(request):
+                with tracing.section_as_event(span, "binary-deserialize"):
+                    body = truss_msgpack_deserialize(body_raw)
+            else:
+                try:
+                    with tracing.section_as_event(span, "json-deserialize"):
+                        body = json.loads(body_raw)
+                except json.JSONDecodeError as e:
+                    error_message = f"Invalid JSON payload: {str(e)}"
+                    logging.error(error_message)
+                    raise HTTPException(status_code=400, detail=error_message)

+            # calls ModelWrapper.__call__, which runs validate, preprocess, predict, and postprocess
+            with tracing.section_as_event(span, "model-call"):
+                response: Union[Dict, Generator] = await model(
+                    body,
+                    headers=utils.transform_keys(
+                        request.headers, lambda key: key.lower()
+                    ),
+                )

+            # In the case that the model returns a Generator object, return a
+            # StreamingResponse instead.
+            if isinstance(response, (AsyncGenerator, Generator)):
+                # media_type in StreamingResponse sets the Content-Type header
+                return StreamingResponse(
+                    response, media_type="application/octet-stream"
+                )

+            response_headers = {}
+            if self.is_binary(request):
+                with tracing.section_as_event(span, "binary-serialize"):
+                    response_headers["Content-Type"] = "application/octet-stream"
+                    return Response(
+                        content=truss_msgpack_serialize(response),
+                        headers=response_headers,
+                    )
+            else:
+                with tracing.section_as_event(span, "json-serialize"):
+                    response_headers["Content-Type"] = "application/json"
+                    return Response(
+                        content=json.dumps(response, cls=DeepNumpyEncoder),
+                        headers=response_headers,
+                    )

    async def schema(self, model_name: str) -> Dict:
        model: ModelWrapper = self._safe_lookup_model(model_name)
@@ -206,10 +226,12 @@ def __init__(
        config: Dict,
        setup_json_logger: bool = True,
    ):
+        secrets = SecretsResolver.get_secrets(config)
+        tracer = tracing.get_truss_tracer(secrets, config)
        self.http_port = http_port
        self._config = config
-        self._model = ModelWrapper(self._config)
-        self._endpoints = BasetenEndpoints(self._model)
+        self._model = ModelWrapper(self._config, tracer)
+        self._endpoints = BasetenEndpoints(self._model, tracer)
        self._setup_json_logger = setup_json_logger

    def cleanup(self):
@@ -344,7 +366,7 @@ def start(self):
        # Call this so uvloop gets used
        cfg.setup_event_loop()

-        async def serve():
+        async def serve() -> None:
            serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
            serversocket.bind((cfg.host, cfg.port))
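
Because `predict` extracts incoming trace context via `otel_propagate.extract(request.headers)`, a caller can link its own trace to the server's "predict-endpoint" span by injecting W3C trace-context headers. A minimal client-side sketch (editor's illustration, not part of this diff; the URL and payload are made up, and the client is assumed to have its own tracer provider configured):

import requests
from opentelemetry import propagate, trace

tracer = trace.get_tracer("client")
with tracer.start_as_current_span("call-truss"):
    headers: dict = {}
    propagate.inject(headers)  # adds e.g. a `traceparent` header
    resp = requests.post(
        "http://localhost:8080/v1/models/model:predict",  # assumed local server
        json={"inputs": [1, 2, 3]},
        headers=headers,
    )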