improve error handling from inference worker processes (#337)
* remove load_and_raise

* fix tests

* fixes
aniketmaurya authored Oct 22, 2024
1 parent 11fed4a commit 11ea0ce
Showing 7 changed files with 40 additions and 36 deletions.
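The previous error path serialized every worker-side exception by hand (`dump_exception` → queue → `load_and_raise`), and FastAPI's `HTTPException` broke that round trip because it does not reconstruct cleanly from pickle. This commit lets the multiprocessing queues do the (un)pickling implicitly and special-cases `HTTPException` via `PickleableHTTPException`. A minimal repro of the underlying issue, assuming FastAPI/Starlette's stock `HTTPException` (the exact failure mode varies by Starlette version):

```python
# Why a stock HTTPException cannot cross a process boundary as-is:
# BaseException pickles as (cls, self.args), but HTTPException never
# stores status_code in args, so rebuilding it on the other side fails
# (or yields a corrupted object, depending on the Starlette version).
import pickle

from fastapi import HTTPException

exc = HTTPException(status_code=429, detail="rate limited")
payload = pickle.dumps(exc)  # serializing succeeds
try:
    restored = pickle.loads(payload)  # reconstruction is the broken step
    print(type(restored), getattr(restored, "status_code", None))
except Exception as err:
    print(f"unpickling failed: {err!r}")
```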
39 changes: 31 additions & 8 deletions src/litserve/loops.py
@@ -26,7 +26,7 @@
 from litserve import LitAPI
 from litserve.callbacks import CallbackRunner, EventTypes
 from litserve.specs.base import LitSpec
-from litserve.utils import LitAPIStatus, dump_exception
+from litserve.utils import LitAPIStatus, PickleableHTTPException
 
 mp.allow_connection_pickling()
 
@@ -147,14 +147,20 @@ def run_single_loop(
             callback_runner.trigger_event(EventTypes.AFTER_ENCODE_RESPONSE, lit_api=lit_api)
 
             response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
+
+        except HTTPException as e:
+            response_queues[response_queue_id].put((
+                uid,
+                (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+            ))
+
         except Exception as e:
             logger.exception(
                 "LitAPI ran into an error while processing the request uid=%s.\n"
                 "Please check the error trace for more details.",
                 uid,
             )
-            err_pkl = dump_exception(e)
-            response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))
+            response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
 
 
 def run_batched_loop(
@@ -221,14 +227,20 @@ def run_batched_loop(
             for response_queue_id, uid, y_enc in y_enc_list:
                 response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
 
+        except HTTPException as e:
+            for response_queue_id, uid in zip(response_queue_ids, uids):
+                response_queues[response_queue_id].put((
+                    uid,
+                    (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+                ))
+
         except Exception as e:
             logger.exception(
                 "LitAPI ran into an error while processing the batched request.\n"
                 "Please check the error trace for more details."
             )
-            err_pkl = dump_exception(e)
             for response_queue_id, uid in zip(response_queue_ids, uids):
-                response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))
+                response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
 
 
 def run_streaming_loop(
@@ -283,13 +295,19 @@
                 y_enc = lit_api.format_encoded_response(y_enc)
                 response_queues[response_queue_id].put((uid, (y_enc, LitAPIStatus.OK)))
             response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
+
+        except HTTPException as e:
+            response_queues[response_queue_id].put((
+                uid,
+                (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+            ))
         except Exception as e:
             logger.exception(
                 "LitAPI ran into an error while processing the streaming request uid=%s.\n"
                 "Please check the error trace for more details.",
                 uid,
             )
-            response_queues[response_queue_id].put((uid, (dump_exception(e), LitAPIStatus.ERROR)))
+            response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
 
 
 def run_batched_streaming_loop(
@@ -357,13 +375,18 @@ def run_batched_streaming_loop(
             for response_queue_id, uid in zip(response_queue_ids, uids):
                 response_queues[response_queue_id].put((uid, ("", LitAPIStatus.FINISH_STREAMING)))
 
+        except HTTPException as e:
+            response_queues[response_queue_id].put((
+                uid,
+                (PickleableHTTPException.from_exception(e), LitAPIStatus.ERROR),
+            ))
+
         except Exception as e:
             logger.exception(
                 "LitAPI ran into an error while processing the streaming batched request.\n"
                 "Please check the error trace for more details."
             )
-            err_pkl = dump_exception(e)
-            response_queues[response_queue_id].put((uid, (err_pkl, LitAPIStatus.ERROR)))
+            response_queues[response_queue_id].put((uid, (e, LitAPIStatus.ERROR)))
 
 
 def inference_worker(
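Each loop now catches `HTTPException` ahead of the generic handler and converts it with `PickleableHTTPException.from_exception` so the status code and detail survive the queue. A sketch of how such a wrapper can be built; the actual `litserve.utils.PickleableHTTPException` may differ:

```python
# Sketch of a pickle-friendly HTTPException (assumed implementation; the
# real litserve.utils.PickleableHTTPException may look different).
import pickle

from fastapi import HTTPException


class PickleableHTTPException(HTTPException):
    @classmethod
    def from_exception(cls, exc: HTTPException) -> "PickleableHTTPException":
        return cls(status_code=exc.status_code, detail=exc.detail)

    def __reduce__(self):
        # Spell out reconstruction instead of relying on BaseException's
        # default, which would call HTTPException() without status_code.
        return (type(self), (self.status_code, self.detail))


if __name__ == "__main__":
    exc = PickleableHTTPException.from_exception(HTTPException(404, "not found"))
    restored = pickle.loads(pickle.dumps(exc))
    assert (restored.status_code, restored.detail) == (404, "not found")
```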
4 changes: 2 additions & 2 deletions src/litserve/server.py
@@ -44,7 +44,7 @@
 from litserve.python_client import client_template
 from litserve.specs import OpenAISpec
 from litserve.specs.base import LitSpec
-from litserve.utils import LitAPIStatus, call_after_stream, load_and_raise
+from litserve.utils import LitAPIStatus, call_after_stream
 
 mp.allow_connection_pickling()
 
@@ -362,7 +362,7 @@ async def predict(request: self.request_type) -> self.response_type:
             response, status = self.response_buffer.pop(uid)
 
             if status == LitAPIStatus.ERROR:
-                load_and_raise(response)
+                raise response
             self._callback_runner.trigger_event(EventTypes.ON_RESPONSE, litserver=self)
             return response
 
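With the workers enqueuing exception objects directly, the server-side handler reduces to a plain `raise response`: multiprocessing queues pickle on `put()` and unpickle on `get()`, so the exception arrives ready to re-raise. A self-contained illustration of that round trip (names are illustrative, not litserve's actual API):

```python
# Illustration of the new error path: an exception raised in a worker
# process crosses an mp.Queue intact and is re-raised in the parent.
# The uid/status tuple layout mirrors the loops above; names are
# illustrative only.
import multiprocessing as mp


def worker(q):
    try:
        raise ValueError("boom in the worker")
    except Exception as e:
        q.put(("uid-1", (e, "ERROR")))


if __name__ == "__main__":
    q = mp.Queue()
    p = mp.Process(target=worker, args=(q,))
    p.start()
    uid, (response, status) = q.get()
    p.join()
    if status == "ERROR":
        raise response  # same move as the new code in predict()
```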
6 changes: 3 additions & 3 deletions src/litserve/specs/openai.py
@@ -27,7 +27,7 @@
 from pydantic import BaseModel, Field
 
 from litserve.specs.base import LitSpec
-from litserve.utils import LitAPIStatus, azip, load_and_raise
+from litserve.utils import LitAPIStatus, azip
 
 if typing.TYPE_CHECKING:
     from litserve import LitServer
@@ -380,7 +380,7 @@ async def streaming_completion(self, request: ChatCompletionRequest, pipe_respon
             # iterate over n choices
             for i, (response, status) in enumerate(streaming_response):
                 if status == LitAPIStatus.ERROR:
-                    load_and_raise(response)
+                    raise response
                 encoded_response = json.loads(response)
                 logger.debug(encoded_response)
                 chat_msg = ChoiceDelta(**encoded_response)
@@ -424,7 +424,7 @@ async def non_streaming_completion(self, request: ChatCompletionRequest, generat
         usage = None
         async for response, status in streaming_response:
             if status == LitAPIStatus.ERROR:
-                load_and_raise(response)
+                raise response
             # data from LitAPI.encode_response
             encoded_response = json.loads(response)
             logger.debug(encoded_response)
12 changes: 0 additions & 12 deletions src/litserve/utils.py
@@ -48,18 +48,6 @@ def dump_exception(exception):
     return pickle.dumps(exception)
 
 
-def load_and_raise(response):
-    try:
-        exception = pickle.loads(response) if isinstance(response, bytes) else response
-        raise exception
-    except pickle.PickleError:
-        logger.exception(
-            f"main process failed to load the exception from the parallel worker process. "
-            f"{response} couldn't be unpickled."
-        )
-        raise
-
-
 async def azip(*async_iterables):
     iterators = [ait.__aiter__() for ait in async_iterables]
     while True:
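`dump_exception` is kept here, but `load_and_raise` is deleted outright: ordinary exceptions round-trip through pickle on their own, so once the workers stopped pre-pickling there was nothing left to load. A quick check of that assumption:

```python
# Ordinary exceptions survive a pickle round trip unaided, which is why
# an explicit load-then-raise helper is no longer needed.
import pickle

e = ValueError("bad input")
restored = pickle.loads(pickle.dumps(e))
assert isinstance(restored, ValueError)
assert restored.args == ("bad input",)
```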
4 changes: 2 additions & 2 deletions tests/test_batch.py
@@ -166,7 +166,7 @@ def test_batch_predict_string_warning():
 
 class FakeResponseQueue:
     def put(self, *args):
-        raise Exception("Exit loop")
+        raise StopIteration("exit loop")
 
 
 def test_batched_loop():
@@ -188,7 +188,7 @@ def test_batched_loop():
         lit_api_mock,
         lit_api_mock,
         requests_queue,
-        FakeResponseQueue(),
+        [FakeResponseQueue()],
         max_batch_size=2,
         batch_timeout=4,
         callback_runner=NOOP_CB_RUNNER,
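Two test adjustments follow from the loop changes: the loops take a list of response queues indexed by `response_queue_id`, so the fake queue is now wrapped in a list, and the fake's `put` raises to break out of the otherwise infinite loop. A condensed sketch of that pattern (simplified stand-ins, not the actual test code):

```python
# Condensed sketch of the test pattern: the batched loop never returns,
# so the fake queue raises from put() to terminate it, and the loop
# indexes response_queues[response_queue_id] (hence the list wrapper).
import pytest


class FakeResponseQueue:
    def put(self, *args):
        raise StopIteration("exit loop")


def run_loop(response_queues):
    while True:  # stand-in for run_batched_loop's infinite loop
        response_queues[0].put(("uid", ("y", "OK")))


with pytest.raises(StopIteration):
    run_loop([FakeResponseQueue()])
```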
9 changes: 1 addition & 8 deletions tests/test_lit_server.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import asyncio
-import pickle
 import re
 import sys
 from unittest.mock import MagicMock, patch
@@ -315,13 +314,7 @@ def encode_response(self, output, context):
 
 
 @pytest.mark.asyncio
-@patch("litserve.server.load_and_raise")
-async def test_inject_context(mocked_load_and_raise):
-    def dummy_load_and_raise(resp):
-        raise pickle.loads(resp)
-
-    mocked_load_and_raise.side_effect = dummy_load_and_raise
-
+async def test_inject_context():
     # Test context injection with single loop
     api = IdentityAPI()
     server = LitServer(api)
2 changes: 1 addition & 1 deletion tests/test_loops.py
@@ -118,7 +118,7 @@ def put(self, item):
         response, status = args
         if status == LitAPIStatus.FINISH_STREAMING:
             raise StopIteration("interrupt iteration")
-        if status == LitAPIStatus.ERROR and b"interrupt iteration" in response:
+        if status == LitAPIStatus.ERROR and isinstance(response, StopIteration):
             assert self.count // 2 == self.num_streamed_outputs, (
                 f"Loop count must have incremented for " f"{self.num_streamed_outputs} times."
             )
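The assertion here tracks the transport change: errors used to arrive as pickled bytes, so the test searched the payload; they now arrive as exception objects, so an `isinstance` check is the natural replacement. For instance:

```python
# Before: errors crossed the queue as pickled bytes, so tests matched on
# the payload. After: the exception object arrives intact, so tests can
# assert on its type directly.
import pickle

err = StopIteration("interrupt iteration")

assert b"interrupt iteration" in pickle.dumps(err)  # old-style check
assert isinstance(err, StopIteration)               # new-style check
```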
