Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix custom HTTPException with argument names #302

Merged
merged 4 commits into from
Sep 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions src/litserve/loops.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import inspect
import logging
import multiprocessing as mp
import pickle
import sys
import time
from queue import Empty, Queue
Expand All @@ -27,10 +26,11 @@
from litserve import LitAPI
from litserve.callbacks import CallbackRunner, EventTypes
from litserve.specs.base import LitSpec
from litserve.utils import LitAPIStatus
from litserve.utils import LitAPIStatus, dump_exception

mp.allow_connection_pickling()


try:
import uvloop

Expand Down Expand Up @@ -153,7 +153,7 @@ def run_single_loop(
"Please check the error trace for more details.",
uid,
)
err_pkl = pickle.dumps(e)
err_pkl = dump_exception(e)
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


Expand Down Expand Up @@ -226,7 +226,7 @@ def run_batched_loop(
"LitAPI ran into an error while processing the batched request.\n"
"Please check the error trace for more details."
)
err_pkl = pickle.dumps(e)
err_pkl = dump_exception(e)
for response_queue_id, uid in zip(response_queue_ids, uids):
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))

Expand Down Expand Up @@ -289,7 +289,7 @@ def run_streaming_loop(
"Please check the error trace for more details.",
uid,
)
response_queues[response_queue_id].put((uid, (pickle.dumps(e), LitAPIStatus.ERROR)))
response_queues[response_queue_id].put((uid, (dump_exception(e), LitAPIStatus.ERROR)))


def run_batched_streaming_loop(
Expand Down Expand Up @@ -362,7 +362,7 @@ def run_batched_streaming_loop(
"LitAPI ran into an error while processing the streaming batched request.\n"
"Please check the error trace for more details."
)
err_pkl = pickle.dumps(e)
err_pkl = dump_exception(e)
response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))


Expand Down
17 changes: 17 additions & 0 deletions src/litserve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,23 @@ class LitAPIStatus:
FINISH_STREAMING = "FINISH_STREAMING"


class PickleableHTTPException(HTTPException):
@staticmethod
def from_exception(exc: HTTPException):
status_code = exc.status_code
detail = exc.detail
return PickleableHTTPException(status_code, detail)

def __reduce__(self):
return (HTTPException, (self.status_code, self.detail))


def dump_exception(exception):
if isinstance(exception, HTTPException):
exception = PickleableHTTPException.from_exception(exception)
return pickle.dumps(exception)


def load_and_raise(response):
try:
exception = pickle.loads(response) if isinstance(response, bytes) else response
Expand Down
11 changes: 11 additions & 0 deletions tests/test_lit_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,13 +368,24 @@ def decode_request(self, request):
raise HTTPException(501, "decode request is bad")


class TestHTTPExceptionAPI2(ls.test_examples.SimpleLitAPI):
def decode_request(self, request):
raise HTTPException(status_code=400, detail="decode request is bad")


def test_http_exception():
server = LitServer(TestHTTPExceptionAPI())
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 501, "Server raises 501 error"
assert response.text == '{"detail":"decode request is bad"}', "decode request is bad"

server = LitServer(TestHTTPExceptionAPI2())
with wrap_litserve_start(server) as server, TestClient(server.app) as client:
response = client.post("/predict", json={"input": 4.0})
assert response.status_code == 400, "Server raises 400 error"
assert response.text == '{"detail":"decode request is bad"}', "decode request is bad"


class RequestIdMiddleware(BaseHTTPMiddleware):
def __init__(self, app: ASGIApp, length: int) -> None:
Expand Down
15 changes: 15 additions & 0 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import pickle

from fastapi import HTTPException

from litserve.utils import dump_exception


def test_dump_exception():
e1 = dump_exception(HTTPException(status_code=404, detail="Not Found"))
assert isinstance(e1, bytes)

exc = HTTPException(400, "Custom Lit error")
isinstance(pickle.loads(dump_exception(exc)), HTTPException)
assert pickle.loads(dump_exception(exc)).detail == "Custom Lit error"
assert pickle.loads(dump_exception(exc)).status_code == 400
Loading