Skip to content

Commit

Permalink
TrussServer supports request/repsonse (#1148)
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten authored Sep 19, 2024
1 parent a121e69 commit b6f8959
Show file tree
Hide file tree
Showing 37 changed files with 1,511 additions and 1,046 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ jobs:
- uses: ./.github/actions/setup-python/
- run: poetry install
- name: run tests
run: poetry run pytest -v --cov=truss -m 'not integration' --junitxml=report.xml
run: poetry run pytest --durations=0 -m 'not integration' --junitxml=report.xml
- name: Publish Test Report # Not sure how to display this in the UI for non PRs.
uses: mikepenz/action-junit-report@v4
if: always()
Expand Down
6 changes: 3 additions & 3 deletions docker/base_images/base_image.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
FROM nvidia/cuda:12.2.2-base-ubuntu20.04
ENV CUDNN_VERSION=8.9.5.29
ENV CUDA=12.2
ENV LD_LIBRARY_PATH /usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH
ENV LD_LIBRARY_PATH=/usr/local/cuda/extras/CUPTI/lib64:$LD_LIBRARY_PATH

RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2004/x86_64/3bf863cc.pub && \
apt-get update && apt-get install -y --no-install-recommends \
Expand All @@ -21,7 +21,7 @@ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/
rm -rf /var/lib/apt/lists/*

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED True
ENV PYTHONUNBUFFERED=True
ENV DEBIAN_FRONTEND=noninteractive

RUN apt update && \
Expand Down Expand Up @@ -49,7 +49,7 @@ FROM python:{{python_version}}
RUN apt update && apt install -y

# Allow statements and log messages to immediately appear in the Knative logs
ENV PYTHONUNBUFFERED True
ENV PYTHONUNBUFFERED=True
{% endif %}


Expand Down
858 changes: 437 additions & 421 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.35"
version = "0.9.36rc01"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand Down
6 changes: 6 additions & 0 deletions truss/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import warnings
from pathlib import Path

from pydantic import PydanticDeprecatedSince20
from single_source import get_version

# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)


__version__ = get_version(__name__, Path(__file__).parent.parent)


Expand Down
6 changes: 5 additions & 1 deletion truss/config/trt_llm.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,17 @@
import json
import logging
import warnings
from enum import Enum
from typing import Optional

from huggingface_hub.errors import HFValidationError
from huggingface_hub.utils import validate_repo_id
from pydantic import BaseModel, validator
from pydantic import BaseModel, PydanticDeprecatedSince20, validator
from rich.console import Console

# Suppress Pydantic V1 warnings, because we have to use it for backwards compat.
warnings.filterwarnings("ignore", category=PydanticDeprecatedSince20)

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

Expand Down
4 changes: 4 additions & 0 deletions truss/remote/baseten/service.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import enum
import time
import urllib.parse
import warnings
from typing import (
Any,
Dict,
Expand All @@ -17,6 +18,9 @@
from truss.truss_handle import TrussHandle
from truss.util.errors import RemoteNetworkError

# "classes created inside an enum will not become a member" -> intended here anyway.
warnings.filterwarnings("ignore", category=DeprecationWarning, message=".*enum.*")

DEFAULT_STREAM_ENCODING = "utf-8"


Expand Down
7 changes: 4 additions & 3 deletions truss/remote/remote_factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import inspect

try:
from configparser import DEFAULTSECT, ConfigParser # type: ignore
except ImportError:
# We need to do this for old python.
from configparser import DEFAULTSECT
from configparser import SafeConfigParser as ConfigParser
except ImportError:
# We need to do this for py312 and onwards.
from configparser import DEFAULTSECT, ConfigParser # type: ignore


from functools import partial
from operator import is_not
Expand Down
8 changes: 4 additions & 4 deletions truss/templates/base.Dockerfile.jinja
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
ARG PYVERSION={{config.python_version}}
FROM {{base_image_name_and_tag}} as truss_server
FROM {{base_image_name_and_tag}} AS truss_server

ENV PYTHON_EXECUTABLE {{ config.base_image.python_executable_path or 'python3' }}
ENV PYTHON_EXECUTABLE={{ config.base_image.python_executable_path or 'python3' }}

{% block fail_fast %}
RUN grep -w 'ID=debian\|ID_LIKE=debian' /etc/os-release || { echo "ERROR: Supplied base image is not a debian image"; exit 1; }
Expand Down Expand Up @@ -52,7 +52,7 @@ RUN pip install -r {{config_requirements_filename}} --no-cache-dir && rm -rf /ro



ENV APP_HOME /app
ENV APP_HOME=/app
WORKDIR $APP_HOME


Expand All @@ -68,7 +68,7 @@ COPY ./{{config.bundled_packages_dir}} /packages


{% for env_var_name, env_var_value in config.environment_variables.items() %}
ENV {{ env_var_name }} {{ env_var_value }}
ENV {{ env_var_name }}={{ env_var_value }}
{% endfor %}

{% block run %}
Expand Down
6 changes: 3 additions & 3 deletions truss/templates/cache.Dockerfile.jinja
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
FROM python:3.11-slim as cache_warmer
FROM python:3.11-slim AS cache_warmer

RUN mkdir -p /app/model_cache
WORKDIR /app

{% if hf_access_token %}
ENV HUGGING_FACE_HUB_TOKEN {{hf_access_token}}
ENV HUGGING_FACE_HUB_TOKEN={{hf_access_token}}
{% endif %}

RUN apt-get -y update; apt-get -y install curl; curl -s https://baseten-public.s3.us-west-2.amazonaws.com/bin/b10cp-5fe8dc7da-linux-amd64 -o /app/b10cp; chmod +x /app/b10cp
ENV B10CP_PATH_TRUSS /app/b10cp
ENV B10CP_PATH_TRUSS=/app/b10cp
COPY ./cache_requirements.txt /app/cache_requirements.txt
RUN pip install -r /app/cache_requirements.txt --no-cache-dir && rm -rf /root/.cache/pip
COPY ./cache_warmer.py /cache_warmer.py
Expand Down
2 changes: 1 addition & 1 deletion truss/templates/control/control/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def run(self):
"inference_server_home": self._inf_serv_home,
"inference_server_process_args": [
self._python_executable_path,
f"{self._inf_serv_home}/inference_server.py",
f"{self._inf_serv_home}/main.py",
],
"control_server_host": "0.0.0.0",
"control_server_port": self._control_server_port,
Expand Down
14 changes: 7 additions & 7 deletions truss/templates/server.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
{% block base_image_patch %}
# If user base image is supplied in config, apply build commands from truss base image
{% if config.base_image %}
ENV PYTHONUNBUFFERED True
ENV PYTHONUNBUFFERED=True
ENV DEBIAN_FRONTEND=noninteractive

RUN apt update && \
Expand Down Expand Up @@ -90,14 +90,14 @@ COPY ./{{ config.model_module_dir }} /app/model

{% block run %}
{%- if config.live_reload %}
ENV HASH_TRUSS {{truss_hash}}
ENV CONTROL_SERVER_PORT 8080
ENV INFERENCE_SERVER_PORT 8090
ENV HASH_TRUSS={{truss_hash}}
ENV CONTROL_SERVER_PORT=8080
ENV INFERENCE_SERVER_PORT=8090
ENV SERVER_START_CMD="/control/.env/bin/python3 /control/control/server.py"
ENTRYPOINT ["/control/.env/bin/python3", "/control/control/server.py"]
{%- else %}
ENV INFERENCE_SERVER_PORT 8080
ENV SERVER_START_CMD="{{(config.base_image.python_executable_path or "python3") ~ " /app/inference_server.py"}}"
ENTRYPOINT ["{{config.base_image.python_executable_path or "python3"}}", "/app/inference_server.py"]
ENV INFERENCE_SERVER_PORT=8080
ENV SERVER_START_CMD="{{(config.base_image.python_executable_path or "python3") ~ " /app/main.py"}}"
ENTRYPOINT ["{{config.base_image.python_executable_path or "python3"}}", "/app/main.py"]
{%- endif %}
{% endblock %}
117 changes: 59 additions & 58 deletions truss/templates/server/common/errors.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,21 @@
import asyncio
import contextlib
import logging
import sys
from http import HTTPStatus
from types import TracebackType
from typing import (
Callable,
Coroutine,
Generator,
Mapping,
NoReturn,
Optional,
TypeVar,
Tuple,
Type,
Union,
overload,
)

import fastapi
import starlette.responses
from fastapi import HTTPException
from fastapi.responses import JSONResponse
from typing_extensions import ParamSpec

# See https://github.com/basetenlabs/baseten/blob/master/docs/Error-Propagation.md
_TRUSS_SERVER_SERVICE_ID = 4
Expand Down Expand Up @@ -51,13 +51,21 @@ class UserCodeError(Exception):
pass


class ModelDefinitionError(TypeError):
"""When the user-defined truss model does not meet the contract."""


def _make_baseten_error_headers(error_code: int) -> Mapping[str, str]:
return {
"X-BASETEN-ERROR-SOURCE": f"{_TRUSS_SERVER_SERVICE_ID:02}",
"X-BASETEN-ERROR-CODE": f"{error_code:03}",
}


def add_error_headers_to_user_response(response: starlette.responses.Response) -> None:
response.headers.update(_make_baseten_error_headers(_BASETEN_CLIENT_ERROR_CODE))


def _make_baseten_response(
http_status: int,
info: Union[str, Exception],
Expand All @@ -71,9 +79,7 @@ def _make_baseten_response(
)


async def exception_handler(
request: fastapi.Request, exc: Exception
) -> fastapi.Response:
async def exception_handler(_: fastapi.Request, exc: Exception) -> fastapi.Response:
if isinstance(exc, ModelMissingError):
return _make_baseten_response(
HTTPStatus.NOT_FOUND.value, exc, _BASETEN_DOWNSTREAM_ERROR_CODE
Expand All @@ -88,6 +94,12 @@ async def exception_handler(
exc,
_BASETEN_CLIENT_ERROR_CODE,
)
if isinstance(exc, ModelDefinitionError):
return _make_baseten_response(
HTTPStatus.PRECONDITION_FAILED.value,
f"{type(exc).__name__}: {str(exc)}",
_BASETEN_DOWNSTREAM_ERROR_CODE,
)
if isinstance(exc, UserCodeError):
return _make_baseten_response(
HTTPStatus.INTERNAL_SERVER_ERROR.value,
Expand All @@ -113,60 +125,49 @@ async def exception_handler(
NotImplementedError,
InputParsingError,
UserCodeError,
ModelDefinitionError,
fastapi.HTTPException,
}


def _intercept_user_exception(exc: Exception, logger: logging.Logger) -> NoReturn:
# Note that logger.exception logs the stacktrace, such that the user can
# debug this error from the logs.
# TODO: consider removing the wrapper function from the stack trace.
if isinstance(exc, HTTPException):
logger.exception("Model raised HTTPException", stacklevel=2)
raise exc
else:
logger.exception("Internal Server Error", stacklevel=2)
raise UserCodeError(str(exc))


_P = ParamSpec("_P")
_R = TypeVar("_R")
_R_async = TypeVar("_R_async", bound=Coroutine) # Return type for async functions


@overload
def intercept_exceptions(
func: Callable[_P, _R], logger: logging.Logger
) -> Callable[_P, _R]: ...
def filter_traceback(
model_file_name: str,
) -> Union[
Tuple[Type[BaseException], BaseException, TracebackType],
Tuple[None, None, None],
]:
exc_type, exc_value, tb = sys.exc_info()
if tb is None:
return exc_type, exc_value, tb # type: ignore[return-value]

# Walk the traceback until we find the frame ending with 'model.py'
current_tb: Optional[TracebackType] = tb
while current_tb is not None:
filename = current_tb.tb_frame.f_code.co_filename
if filename.endswith(model_file_name):
# Return exception info with traceback starting from current_tb
return exc_type, exc_value, current_tb # type: ignore[return-value]
current_tb = current_tb.tb_next

@overload
def intercept_exceptions(
func: Callable[_P, _R_async], logger: logging.Logger
) -> Callable[_P, _R_async]: ...
# If `model_file_name` not found, return the original exception info
return exc_type, exc_value, tb # type: ignore[return-value]


@contextlib.contextmanager
def intercept_exceptions(
func: Callable[_P, _R], logger: logging.Logger
) -> Callable[_P, _R]:
"""Converts all exceptions to 500-`HTTPException` and logs them.
If exception is already `HTTPException`, re-raises exception as is.
"""
if asyncio.iscoroutinefunction(func):

async def inner_async(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return await func(*args, **kwargs)
except Exception as e:
_intercept_user_exception(e, logger)

return inner_async # type: ignore[return-value]
else:

def inner_sync(*args: _P.args, **kwargs: _P.kwargs) -> _R:
try:
return func(*args, **kwargs)
except Exception as e:
_intercept_user_exception(e, logger)

return inner_sync
logger: logging.Logger, model_file_name: str
) -> Generator[None, None, None]:
try:
yield
# Note that logger.error logs the stacktrace, such that the user can
# debug this error from the logs.
except HTTPException:
logger.error(
"Model raised HTTPException", exc_info=filter_traceback(model_file_name)
)
raise
except Exception as exc:
logger.error(
"Internal Server Error", exc_info=filter_traceback(model_file_name)
)
raise UserCodeError(str(exc))
2 changes: 1 addition & 1 deletion truss/templates/server/common/tracing.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def shutdown(self) -> None:
_truss_tracer: Optional[trace.Tracer] = None


def get_truss_tracer(secrets: secrets_resolver.SecretsResolver, config) -> trace.Tracer:
def get_truss_tracer(secrets: secrets_resolver.Secrets, config) -> trace.Tracer:
"""Creates a cached tracer (i.e. runtime-singleton) to be used for truss
internal tracing.
Expand Down
Loading

0 comments on commit b6f8959

Please sign in to comment.