feat: supported extensible routes #133

Merged
merged 10 commits on Jul 29, 2024
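Summary of the change: deployment routes are now registered individually when `add_chat_completion` / `add_embeddings` are called, instead of a fixed set of wildcard `/openai/deployments/{deployment_id}/...` routes dispatched through class-level registries. Both methods now return the `DIALApp` instance so registrations can be chained, and the optional `tokenize` / `truncate_prompt` routes are created only when the implementation actually overrides those methods.

Below is a minimal usage sketch of the resulting API. The `EchoApplication` class and the deployment name `"echo"` are illustrative, not part of this PR; the imports follow the SDK's public layout.

```python
from aidial_sdk import DIALApp
from aidial_sdk.chat_completion import ChatCompletion, Request, Response


class EchoApplication(ChatCompletion):
    # Hypothetical deployment: echoes the last user message back.
    async def chat_completion(
        self, request: Request, response: Response
    ) -> None:
        with response.create_single_choice() as choice:
            choice.append_content(request.messages[-1].content or "")


# add_* now returns the app, so registrations can be chained.
# POST /openai/deployments/echo/chat/completions is created here, while
# /tokenize and /truncate_prompt are skipped because EchoApplication
# does not override those methods.
app = DIALApp().add_chat_completion("echo", EchoApplication())
```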
16 changes: 0 additions & 16 deletions aidial_sdk/_errors.py
@@ -6,22 +6,6 @@
 from aidial_sdk.utils.errors import json_error
 
 
-def missing_deployment_error() -> DIALException:
-    return DIALException(
-        status_code=404,
-        code="deployment_not_found",
-        message="The API deployment for this resource does not exist.",
-    )
-
-
-def missing_endpoint_error(endpoint: str) -> DIALException:
-    return DIALException(
-        status_code=404,
-        code="endpoint_not_found",
-        message=f"The deployment doesn't implement '{endpoint}' endpoint.",
-    )
-
-
 def pydantic_validation_exception_handler(
     request: Request, exc: Exception
 ) -> JSONResponse:
216 changes: 106 additions & 110 deletions aidial_sdk/application.py
@@ -2,16 +2,14 @@
 import re
 import warnings
 from logging import Filter, LogRecord
-from typing import Dict, Optional, Type, TypeVar
+from typing import Any, Callable, Coroutine, Literal, Optional, Type, TypeVar
 
 from fastapi import FastAPI, HTTPException, Request
 from fastapi.responses import JSONResponse, Response, StreamingResponse
 
 from aidial_sdk._errors import (
     dial_exception_handler,
     fastapi_exception_handler,
-    missing_deployment_error,
-    missing_endpoint_error,
     pydantic_validation_exception_handler,
 )
 from aidial_sdk.chat_completion.base import ChatCompletion
@@ -29,6 +27,7 @@
 from aidial_sdk.header_propagator import HeaderPropagator
 from aidial_sdk.pydantic_v1 import ValidationError
 from aidial_sdk.telemetry.types import TelemetryConfig
+from aidial_sdk.utils._reflection import get_method_implementation
 from aidial_sdk.utils.log_config import LogConfig
 from aidial_sdk.utils.logging import log_debug, set_log_deployment
 from aidial_sdk.utils.streaming import merge_chunks
@@ -50,8 +49,6 @@ def filter(self, record: LogRecord):
 
 
 class DIALApp(FastAPI):
-    chat_completion_impls: Dict[str, ChatCompletion] = {}
-    embeddings_impls: Dict[str, Embeddings] = {}
 
     def __init__(
         self,
@@ -88,38 +85,6 @@ def __init__(
         self.add_api_route(path, DIALApp._healthcheck, methods=["GET"])
         logging.getLogger("uvicorn.access").addFilter(PathFilter(path))
 
-        self.add_api_route(
-            "/openai/deployments/{deployment_id}/embeddings",
-            self._embeddings,
-            methods=["POST"],
-        )
-
-        self.add_api_route(
-            "/openai/deployments/{deployment_id}/chat/completions",
-            self._chat_completion,
-            methods=["POST"],
-        )
-
-        self.add_api_route(
-            "/openai/deployments/{deployment_id}/rate",
-            self._rate_response,
-            methods=["POST"],
-        )
-
-        self.add_api_route(
-            "/openai/deployments/{deployment_id}/tokenize",
-            self._chat_completion_endpoint_factory("tokenize", TokenizeRequest),
-            methods=["POST"],
-        )
-
-        self.add_api_route(
-            "/openai/deployments/{deployment_id}/truncate_prompt",
-            self._chat_completion_endpoint_factory(
-                "truncate_prompt", TruncatePromptRequest
-            ),
-            methods=["POST"],
-        )
-
         self.add_exception_handler(
             ValidationError, pydantic_validation_exception_handler
         )
@@ -139,99 +104,130 @@ def configure_telemetry(self, config: TelemetryConfig):
 
         init_telemetry(app=self, config=config)
 
-    def add_embeddings(self, deployment_name: str, impl: Embeddings) -> None:
-        self.embeddings_impls[deployment_name] = impl
+    def add_embeddings(
+        self, deployment_name: str, impl: Embeddings
+    ) -> "DIALApp":
+        self.add_api_route(
+            f"/openai/deployments/{deployment_name}/embeddings",
+            self._embeddings(deployment_name, impl),
+            methods=["POST"],
+        )
+
+        return self
 
     def add_chat_completion(
         self, deployment_name: str, impl: ChatCompletion
-    ) -> None:
-        self.chat_completion_impls[deployment_name] = impl
+    ) -> "DIALApp":
 
-    def _chat_completion_endpoint_factory(
-        self, endpoint: str, request_type: Type["RequestType"]
-    ):
-        async def _handler(
-            deployment_id: str, original_request: Request
-        ) -> Response:
-            set_log_deployment(deployment_id)
-            deployment = self._get_chat_completion(deployment_id)
+        self.add_api_route(
+            f"/openai/deployments/{deployment_name}/chat/completions",
            self._chat_completion(deployment_name, impl),
+            methods=["POST"],
+        )
 
-            request = await request_type.from_request(original_request)
+        self.add_api_route(
+            f"/openai/deployments/{deployment_name}/rate",
+            self._rate_response(deployment_name, impl),
+            methods=["POST"],
+        )
 
-            endpoint_impl = getattr(deployment, endpoint, None)
-            if not endpoint_impl:
-                raise missing_endpoint_error(endpoint)
+        if endpoint_impl := get_method_implementation(impl, "tokenize"):
+            self.add_api_route(
+                f"/openai/deployments/{deployment_name}/tokenize",
+                self._endpoint_factory(
+                    deployment_name, endpoint_impl, "tokenize", TokenizeRequest
+                ),
+                methods=["POST"],
+            )
 
-            try:
-                response = await endpoint_impl(request)
-            except NotImplementedError:
-                raise missing_endpoint_error(endpoint)
+        if endpoint_impl := get_method_implementation(impl, "truncate_prompt"):
+            self.add_api_route(
+                f"/openai/deployments/{deployment_name}/truncate_prompt",
+                self._endpoint_factory(
+                    deployment_name,
+                    endpoint_impl,
+                    "truncate_prompt",
+                    TruncatePromptRequest,
+                ),
+                methods=["POST"],
+            )
+
+        return self
+
+    def _endpoint_factory(
+        self,
+        deployment_id: str,
+        endpoint_impl: Callable[[RequestType], Coroutine[Any, Any, Any]],
+        endpoint: Literal["tokenize", "truncate_prompt"],
+        request_type: Type["RequestType"],
+    ):
+        async def _handler(original_request: Request) -> Response:
+            set_log_deployment(deployment_id)
+
+            request = await request_type.from_request(
+                original_request, deployment_id
+            )
+            response = await endpoint_impl(request)
 
             response_json = response.dict()
             log_debug(f"response [{endpoint}]: {response_json}")
             return JSONResponse(content=response_json)
 
         return _handler
 
-    async def _rate_response(
-        self, deployment_id: str, original_request: Request
-    ) -> Response:
-        set_log_deployment(deployment_id)
-        deployment = self._get_chat_completion(deployment_id)
-
-        request = await RateRequest.from_request(original_request)
-
-        await deployment.rate_response(request)
-        return Response(status_code=200)
-
-    async def _embeddings(
-        self, deployment_id: str, original_request: Request
-    ) -> Response:
-        set_log_deployment(deployment_id)
-        deployment = self._get_embeddings(deployment_id)
-        request = await EmbeddingsRequest.from_request(original_request)
-        response = await deployment.embeddings(request)
-        response_json = response.dict()
-        return JSONResponse(content=response_json)
-
-    async def _chat_completion(
-        self, deployment_id: str, original_request: Request
-    ) -> Response:
-        set_log_deployment(deployment_id)
-        deployment = self._get_chat_completion(deployment_id)
-
-        request = await ChatCompletionRequest.from_request(original_request)
-
-        response = ChatCompletionResponse(request)
-        first_chunk = await response._generator(
-            deployment.chat_completion, request
-        )
-
-        if request.stream:
-            return StreamingResponse(
-                response._generate_stream(first_chunk),
-                media_type="text/event-stream",
-            )
-        else:
-            response_json = await merge_chunks(
-                response._generate_stream(first_chunk)
-            )
-
-            log_debug(f"response: {response_json}")
-            return JSONResponse(content=response_json)
+    def _rate_response(self, deployment_id: str, impl: ChatCompletion):
+        async def _handler(original_request: Request):
+            set_log_deployment(deployment_id)
+
+            request = await RateRequest.from_request(
+                original_request, deployment_id
+            )
+
+            await impl.rate_response(request)
+            return Response(status_code=200)
+
+        return _handler
+
+    def _chat_completion(self, deployment_id: str, impl: ChatCompletion):
+        async def _handler(original_request: Request):
+            set_log_deployment(deployment_id)
+
+            request = await ChatCompletionRequest.from_request(
+                original_request, deployment_id
+            )
+
+            response = ChatCompletionResponse(request)
+            first_chunk = await response._generator(
+                impl.chat_completion, request
+            )
+
+            if request.stream:
+                return StreamingResponse(
+                    response._generate_stream(first_chunk),
+                    media_type="text/event-stream",
+                )
+            else:
+                response_json = await merge_chunks(
+                    response._generate_stream(first_chunk)
+                )
+
+                log_debug(f"response: {response_json}")
+                return JSONResponse(content=response_json)
+
+        return _handler
+
+    def _embeddings(self, deployment_id: str, impl: Embeddings):
+        async def _handler(original_request: Request):
+            set_log_deployment(deployment_id)
+            request = await EmbeddingsRequest.from_request(
+                original_request, deployment_id
+            )
+            response = await impl.embeddings(request)
+            response_json = response.dict()
+            return JSONResponse(content=response_json)
+
+        return _handler
 
     @staticmethod
     async def _healthcheck() -> JSONResponse:
         return JSONResponse(content={"status": "ok"})
-
-    def _get_chat_completion(self, deployment_id: str) -> ChatCompletion:
-        impl = self.chat_completion_impls.get(deployment_id, None)
-        if not impl:
-            raise missing_deployment_error()
-        return impl
-
-    def _get_embeddings(self, deployment_id: str) -> Embeddings:
-        impl = self.embeddings_impls.get(deployment_id, None)
-        if not impl:
-            raise missing_deployment_error()
-        return impl
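A note on the pattern above: `_chat_completion`, `_embeddings`, `_rate_response`, and `_endpoint_factory` are no longer request handlers themselves but factories that return a per-deployment `_handler` closure. The deployment id and implementation are bound at registration time, which removes the request-time registry lookup (and with it the `deployment_not_found` / `endpoint_not_found` error paths). A self-contained sketch of the idea, with simplified types that are not part of the SDK:

```python
import asyncio
from typing import Awaitable, Callable


def make_handler(
    deployment_id: str, impl: Callable[[str], Awaitable[str]]
) -> Callable[[str], Awaitable[str]]:
    # Each registered route gets its own handler: deployment_id and impl
    # are captured by the closure, so nothing is looked up in a shared
    # registry when a request arrives.
    async def _handler(body: str) -> str:
        return await impl(f"[{deployment_id}] {body}")

    return _handler


async def echo(text: str) -> str:
    return text


handler = make_handler("echo", echo)
print(asyncio.run(handler("hello")))  # -> [echo] hello
```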
16 changes: 5 additions & 11 deletions aidial_sdk/deployment/from_request_mixin.py
@@ -15,13 +15,15 @@
 class FromRequestMixin(ABC, ExtraForbidModel):
     @classmethod
     @abstractmethod
-    async def from_request(cls: Type[T], request: fastapi.Request) -> T:
+    async def from_request(
+        cls: Type[T], request: fastapi.Request, deployment_id: str
+    ) -> T:
         pass
 
 
 class FromRequestBasicMixin(FromRequestMixin):
     @classmethod
-    async def from_request(cls, request: fastapi.Request):
+    async def from_request(cls, request: fastapi.Request, deployment_id: str):
         return cls(**(await _get_request_body(request)))
 
 
@@ -60,15 +62,7 @@ def jwt(self) -> Optional[str]:
         return self.jwt_secret.get_secret_value() if self.jwt_secret else None
 
     @classmethod
-    async def from_request(cls, request: fastapi.Request):
-        deployment_id = request.path_params.get("deployment_id")
-        if deployment_id is None or not isinstance(deployment_id, str):
-            raise DIALException(
-                status_code=404,
-                type="invalid_path",
-                message="Invalid path",
-            )
-
+    async def from_request(cls, request: fastapi.Request, deployment_id: str):
         headers = request.headers.mutablecopy()
 
         api_key = headers.get("Api-Key")
34 changes: 34 additions & 0 deletions aidial_sdk/utils/_reflection.py
@@ -0,0 +1,34 @@
+from typing import Any, Optional
+
+
+def has_method_implemented(obj: Any, method_name: str) -> bool:
+    """
+    Determine if a method is overridden in an object instance or
+    if it is inherited from its class.
+    """
+
+    base_method = None
+    for cls in type(obj).__mro__[1:]:
+        base_method = getattr(cls, method_name, None)
+        if base_method is not None:
+            break
+
+    this_method = getattr(obj, method_name, None)
+
+    if base_method is None or this_method is None:
+        return False
+
+    if hasattr(base_method, "__code__") and hasattr(this_method, "__code__"):
+        return base_method.__code__ != this_method.__code__
+
+    return base_method != this_method
+
+
+def get_method_implementation(obj: Any, method_name: str) -> Optional[Any]:
+    """
+    Get the method implementation of an object instance.
+    """
+
+    if has_method_implemented(obj, method_name):
+        return getattr(obj, method_name)
+    return None
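This helper drives the conditional route registration in `application.py`: a `tokenize` / `truncate_prompt` route is created only when the deployment class actually overrides the method. A quick, self-contained illustration of the detection logic (the `Base` / `Child` classes are made up for the example):

```python
class Base:
    def tokenize(self):
        raise NotImplementedError


class Child(Base):
    def tokenize(self):
        return []


# Child overrides tokenize, so its __code__ differs from Base's.
assert has_method_implemented(Child(), "tokenize")
assert get_method_implementation(Child(), "tokenize") is not None

# Base has no ancestor (besides object) providing tokenize,
# so its own definition does not count as an override.
assert not has_method_implemented(Base(), "tokenize")
assert get_method_implementation(Base(), "tokenize") is None
```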
8 changes: 6 additions & 2 deletions aidial_sdk/utils/logging.py
@@ -1,11 +1,15 @@
 import logging
 from contextvars import ContextVar
+from typing import Optional
 
 logger = logging.getLogger("aidial_sdk")
-deployment_id = ContextVar("deployment_id", default=None)
+
+deployment_id: ContextVar[Optional[str]] = ContextVar(
+    "deployment_id", default=None
+)
 
 
-def set_log_deployment(new_deployment_id):
+def set_log_deployment(new_deployment_id: str):
     deployment_id.set(new_deployment_id)
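For context, `deployment_id` is a `ContextVar` so that each in-flight request can tag its log records with the deployment it targets without threading the id through every call. A minimal, self-contained sketch of that mechanism (the filter class and format string are illustrative, not the SDK's actual `LogConfig`):

```python
import logging
from contextvars import ContextVar
from typing import Optional

deployment_id: ContextVar[Optional[str]] = ContextVar(
    "deployment_id", default=None
)


class DeploymentFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Attach the current deployment id (or "-") to every record.
        record.deployment = deployment_id.get() or "-"
        return True


logging.basicConfig(format="%(deployment)s | %(message)s")
log = logging.getLogger("demo")
log.addFilter(DeploymentFilter())

deployment_id.set("echo")
log.warning("handling request")  # prints: echo | handling request
```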
@@ -20,6 +20,10 @@ def raise_exception(exception_type: str):
 
 
 class BrokenApplication(ChatCompletion):
+    """
+    Application which breaks immediately after receiving a request.
+    """
+
     async def chat_completion(
         self, request: Request, response: Response
     ) -> None: