API: Move OAI to APIRouter
This makes the API more modular so that other API implementations can be
added in the future.

Signed-off-by: kingbri <bdashore3@proton.me>
bdashore3 committed Apr 6, 2024
1 parent 8bdc191 commit 5bb4995
Showing 3 changed files with 74 additions and 64 deletions.
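
In short, the commit converts the module-level FastAPI app in endpoints/OAI/app.py into an APIRouter (now endpoints/OAI/router.py) and moves app construction, CORS middleware, and uvicorn startup into a new endpoints/server.py, which mounts the router at startup. A minimal sketch of that pattern, assuming a hypothetical example endpoint rather than the actual TabbyAPI routes:

import uvicorn
from fastapi import APIRouter, FastAPI

# Router module: declares endpoints without owning the app
router = APIRouter()


@router.get("/v1/example")  # hypothetical route, for illustration only
async def example():
    return {"ok": True}


# Server module: builds the shared app and mounts any routers
app = FastAPI(title="Example")
app.include_router(router)

if __name__ == "__main__":
    uvicorn.run(app, host="127.0.0.1", port=8000)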
89 changes: 26 additions & 63 deletions endpoints/OAI/app.py → endpoints/OAI/router.py
@@ -1,8 +1,6 @@
import asyncio
import pathlib
import uvicorn
from fastapi import FastAPI, Depends, HTTPException, Header, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi import APIRouter, Depends, HTTPException, Header, Request
from functools import partial
from loguru import logger
from sse_starlette import EventSourceResponse
@@ -15,7 +13,6 @@
call_with_semaphore,
generate_with_semaphore,
)
from common.logger import UVICORN_LOG_CONFIG
from common.networking import handle_request_error, run_with_request_disconnect
from common.templating import (
get_all_templates,
@@ -56,23 +53,8 @@
from endpoints.OAI.utils.model import get_model_list, stream_model_load
from endpoints.OAI.utils.lora import get_lora_list

app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
)

# ALlow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
router = APIRouter()


async def check_model_container():
@@ -90,8 +72,8 @@ async def check_model_container():


# Model list endpoint
@app.get("/v1/models", dependencies=[Depends(check_api_key)])
@app.get("/v1/model/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/models", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/list", dependencies=[Depends(check_api_key)])
async def list_models():
"""Lists all models in the model directory."""
model_config = config.model_config()
@@ -108,7 +90,7 @@ async def list_models():


# Currently loaded model endpoint
@app.get(
@router.get(
"/v1/model",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -142,7 +124,7 @@ async def get_current_model():
return model_card


@app.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/model/draft/list", dependencies=[Depends(check_api_key)])
async def list_draft_models():
"""Lists all draft models in the model directory."""
draft_model_dir = unwrap(
@@ -156,7 +138,7 @@ async def list_draft_models():


# Load model endpoint
@app.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
@router.post("/v1/model/load", dependencies=[Depends(check_admin_key)])
async def load_model(request: Request, data: ModelLoadRequest):
"""Loads a model into the model container."""

@@ -209,7 +191,7 @@ async def load_model(request: Request, data: ModelLoadRequest):


# Unload model endpoint
@app.post(
@router.post(
"/v1/model/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@@ -218,15 +200,15 @@ async def unload_model():
await model.unload_model()


@app.get("/v1/templates", dependencies=[Depends(check_api_key)])
@app.get("/v1/template/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/templates", dependencies=[Depends(check_api_key)])
@router.get("/v1/template/list", dependencies=[Depends(check_api_key)])
async def get_templates():
templates = get_all_templates()
template_strings = list(map(lambda template: template.stem, templates))
return TemplateList(data=template_strings)


@app.post(
@router.post(
"/v1/template/switch",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@@ -252,7 +234,7 @@ async def switch_template(data: TemplateSwitchRequest):
raise HTTPException(400, error_message) from e


@app.post(
@router.post(
"/v1/template/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@@ -263,15 +245,15 @@ async def unload_template():


# Sampler override endpoints
@app.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@app.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/overrides", dependencies=[Depends(check_api_key)])
@router.get("/v1/sampling/override/list", dependencies=[Depends(check_api_key)])
async def list_sampler_overrides():
"""API wrapper to list all currently applied sampler overrides"""

return sampling.overrides


@app.post(
@router.post(
"/v1/sampling/override/switch",
dependencies=[Depends(check_admin_key)],
)
@@ -300,7 +282,7 @@ async def switch_sampler_override(data: SamplerOverrideSwitchRequest):
raise HTTPException(400, error_message)


@app.post(
@router.post(
"/v1/sampling/override/unload",
dependencies=[Depends(check_admin_key)],
)
@@ -311,8 +293,8 @@ async def unload_sampler_override():


# Lora list endpoint
@app.get("/v1/loras", dependencies=[Depends(check_api_key)])
@app.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
@router.get("/v1/loras", dependencies=[Depends(check_api_key)])
@router.get("/v1/lora/list", dependencies=[Depends(check_api_key)])
async def get_all_loras():
"""Lists all LoRAs in the lora directory."""
lora_path = pathlib.Path(unwrap(config.lora_config().get("lora_dir"), "loras"))
@@ -322,7 +304,7 @@ async def get_all_loras():


# Currently loaded loras endpoint
@app.get(
@router.get(
"/v1/lora",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -344,7 +326,7 @@ async def get_active_loras():


# Load lora endpoint
@app.post(
@router.post(
"/v1/lora/load",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@@ -388,7 +370,7 @@ async def load_lora(data: LoraLoadRequest):


# Unload lora endpoint
@app.post(
@router.post(
"/v1/lora/unload",
dependencies=[Depends(check_admin_key), Depends(check_model_container)],
)
@@ -399,7 +381,7 @@ async def unload_loras():


# Encode tokens endpoint
@app.post(
@router.post(
"/v1/token/encode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -413,7 +395,7 @@ async def encode_tokens(data: TokenEncodeRequest):


# Decode tokens endpoint
@app.post(
@router.post(
"/v1/token/decode",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -425,7 +407,7 @@ async def decode_tokens(data: TokenDecodeRequest):
return response


@app.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
@router.get("/v1/auth/permission", dependencies=[Depends(check_api_key)])
async def get_key_permission(
x_admin_key: Optional[str] = Header(None),
x_api_key: Optional[str] = Header(None),
@@ -452,7 +434,7 @@ async def get_key_permission(


# Completions endpoint
@app.post(
@router.post(
"/v1/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -488,7 +470,7 @@ async def completion_request(request: Request, data: CompletionRequest):


# Chat completions endpoint
@app.post(
@router.post(
"/v1/chat/completions",
dependencies=[Depends(check_api_key), Depends(check_model_container)],
)
@@ -536,22 +518,3 @@ async def chat_completion_request(request: Request, data: ChatCompletionRequest)
disconnect_message="Chat completion generation cancelled by user.",
)
return response


async def start_api(host: str, port: int):
"""Isolated function to start the API server"""

# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
)
server = uvicorn.Server(config)

await server.serve()
47 changes: 47 additions & 0 deletions endpoints/server.py
@@ -0,0 +1,47 @@
import uvicorn
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from loguru import logger

from common.logger import UVICORN_LOG_CONFIG
from endpoints.OAI.router import router as OAIRouter

app = FastAPI(
title="TabbyAPI",
summary="An OAI compatible exllamav2 API that's both lightweight and fast",
description=(
"This docs page is not meant to send requests! Please use a service "
"like Postman or a frontend UI."
),
)

# ALlow CORS requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)


async def start_api(host: str, port: int):
"""Isolated function to start the API server"""

# TODO: Move OAI API to a separate folder
logger.info(f"Developer documentation: http://{host}:{port}/redoc")
logger.info(f"Completions: http://{host}:{port}/v1/completions")
logger.info(f"Chat completions: http://{host}:{port}/v1/chat/completions")

# Add OAI router
app.include_router(OAIRouter)

config = uvicorn.Config(
app,
host=host,
port=port,
log_config=UVICORN_LOG_CONFIG,
)
server = uvicorn.Server(config)

await server.serve()
2 changes: 1 addition & 1 deletion main.py
@@ -15,7 +15,7 @@
from common.networking import is_port_in_use
from common.signals import signal_handler
from common.utils import unwrap
from endpoints.OAI.app import start_api
from endpoints.server import start_api


async def entrypoint(args: Optional[dict] = None):

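Because the server now composes routers at startup, the other API implementations mentioned in the commit message could live in their own packages and be mounted alongside the OAI router. A hedged sketch of how that might look, using an entirely hypothetical second router and prefix that are not part of this commit:

# endpoints/Hypothetical/router.py (illustrative only)
from fastapi import APIRouter

router = APIRouter()


@router.get("/status")
async def status():
    """Hypothetical endpoint for a second API implementation."""
    return {"status": "ok"}


# In endpoints/server.py, a hypothetical extension of start_api:
# from endpoints.Hypothetical.router import router as HypotheticalRouter
# app.include_router(HypotheticalRouter, prefix="/hypothetical")

Mounting with a prefix would keep the new surface separate from the existing /v1 OAI routes while reusing the same FastAPI app, middleware, and uvicorn configuration.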