diff --git a/src/litserve/loops.py b/src/litserve/loops.py index 7e65ba9f..d01c3d6f 100644 --- a/src/litserve/loops.py +++ b/src/litserve/loops.py @@ -26,7 +26,7 @@ from litserve import LitAPI from litserve.callbacks import CallbackRunner, EventTypes from litserve.specs.base import LitSpec -from litserve.utils import LitAPIStatus, dump_exception +from litserve.utils import LitAPIStatus, PickleableHTTPException mp.allow_connection_pickling() @@ -147,14 +147,20 @@ def run_single_loop( callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) + + except HTTPException as e: + response_queues[response_queue_id].put(( + uid, + (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR), + )) + except Exception as e: logger.exception( "LitAPI ran into an error while processing the request uid=%s.\n" "Please check the error trace for more details.", uid, ) - err_pkl = dump_exception(e) - response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) def run_batched_loop( @@ -221,14 +227,20 @@ def run_batched_loop( for response_queue_id, uid, y_enc in y_enc_list: response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) + except HTTPException as e: + for response_queue_id, uid in zip(response_queue_ids, uids): + response_queues[response_queue_id].put(( + uid, + (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR), + )) + except Exception as e: logger.exception( "LitAPI ran into an error while processing the batched request.\n" "Please check the error trace for more details." ) - err_pkl = dump_exception(e) for response_queue_id, uid in zip(response_queue_ids, uids): - response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) def run_streaming_loop( @@ -283,13 +295,19 @@ def run_streaming_loop( y_enc = lit_api.format_encoded_response(y_enc) response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK))) response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + + except HTTPException as e: + response_queues[response_queue_id].put(( + uid, + (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR), + )) except Exception as e: logger.exception( "LitAPI ran into an error while processing the streaming request uid=%s.\n" "Please check the error trace for more details.", uid, ) - response_queues[response_queue_id].put((uid, (dump_exception(e), LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) def run_batched_streaming_loop( @@ -357,13 +375,18 @@ def run_batched_streaming_loop( for response_queue_id, uid in zip(response_queue_ids, uids): response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING))) + except HTTPException as e: + response_queues[response_queue_id].put(( + uid, + (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR), + )) + except Exception as e: logger.exception( "LitAPI ran into an error while processing the streaming batched request.\n" "Please check the error trace for more details." ) - err_pkl = dump_exception(e) - response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR))) + response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR))) def inference_worker( diff --git a/src/litserve/server.py b/src/litserve/server.py index 9ab90965..fdc8554f 100644 --- a/src/litserve/server.py +++ b/src/litserve/server.py @@ -44,7 +44,7 @@ from litserve.python_client import client_template from litserve.specs import OpenAISpec from litserve.specs.base import LitSpec -from litserve.utils import LitAPIStatus, call_after_stream, load_and_raise +from litserve.utils import LitAPIStatus, call_after_stream mp.allow_connection_pickling() @@ -362,7 +362,7 @@ async def predict(request: self.request_type) -> self.response_type: response, status = self.response_buffer.pop(uid) if status == LitAPIStatus.ERROR: - load_and_raise(response) + raise response self._callback_runner.trigger_event(EventTypes.ON_RESPONSE, litserver=self) return response diff --git a/src/litserve/specs/openai.py b/src/litserve/specs/openai.py index dd2172bc..af802f7e 100644 --- a/src/litserve/specs/openai.py +++ b/src/litserve/specs/openai.py @@ -27,7 +27,7 @@ from pydantic import BaseModel, Field from litserve.specs.base import LitSpec -from litserve.utils import LitAPIStatus, azip, load_and_raise +from litserve.utils import LitAPIStatus, azip if typing.TYPE_CHECKING: from litserve import LitServer @@ -380,7 +380,7 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon # iterate over n choices for i, (response, status) in enumerate(streaming_response): if status == LitAPIStatus.ERROR: - load_and_raise(response) + raise response encoded_response = json.loads(response) logger.debug(encoded_response) chat_msg = ChoiceDelta(**encoded_response) @@ -424,7 +424,7 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat usage = None async for response, status in streaming_response: if status == LitAPIStatus.ERROR: - load_and_raise(response) + raise response # data from LitAPI.encode_response encoded_response = json.loads(response) logger.debug(encoded_response) diff --git a/src/litserve/utils.py b/src/litserve/utils.py index 73c84367..b644c00d 100644 --- a/src/litserve/utils.py +++ b/src/litserve/utils.py @@ -48,18 +48,6 @@ def dump_exception(exception): return pickle.dumps(exception) -def load_and_raise(response): - try: - exception = pickle.loads(response) if isinstance(response, bytes) else response - raise exception - except pickle.PickleError: - logger.exception( - f"main process failed to load the exception from the parallel worker process. " - f"{response} couldn't be unpickled." - ) - raise - - async def azip(*async_iterables): iterators = [ait.__aiter__() for ait in async_iterables] while True: diff --git a/tests/test_batch.py b/tests/test_batch.py index b853f354..15d12b6e 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -166,7 +166,7 @@ def test_batch_predict_string_warning(): class FakeResponseQueue: def put(self, *args): - raise Exception("Exit loop") + raise StopIteration("exit loop") def test_batched_loop(): @@ -188,7 +188,7 @@ def test_batched_loop(): lit_api_mock, lit_api_mock, requests_queue, - FakeResponseQueue(), + [FakeResponseQueue()], max_batch_size=2, batch_timeout=4, callback_runner=NOOP_CB_RUNNER, diff --git a/tests/test_lit_server.py b/tests/test_lit_server.py index 8c316401..966f1ea0 100644 --- a/tests/test_lit_server.py +++ b/tests/test_lit_server.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import asyncio -import pickle import re import sys from unittest.mock import MagicMock, patch @@ -315,13 +314,7 @@ def encode_response(self, output, context): @pytest.mark.asyncio -@patch("litserve.server.load_and_raise") -async def test_inject_context(mocked_load_and_raise): - def dummy_load_and_raise(resp): - raise pickle.loads(resp) - - mocked_load_and_raise.side_effect = dummy_load_and_raise - +async def test_inject_context(): # Test context injection with single loop api = IdentityAPI() server = LitServer(api) diff --git a/tests/test_loops.py b/tests/test_loops.py index 4c272452..42ad603f 100644 --- a/tests/test_loops.py +++ b/tests/test_loops.py @@ -118,7 +118,7 @@ def put(self, item): response, status = args if status == LitAPIStatus.FINISH_STREAMING: raise StopIteration("interrupt iteration") - if status == LitAPIStatus.ERROR and b"interrupt iteration" in response: + if status == LitAPIStatus.ERROR and isinstance(response, StopIteration): assert self.count // 2 == self.num_streamed_outputs, ( f"Loop count must have incremented for " f"{self.num_streamed_outputs} times." )