From e8ffd36579e867b8e652728f6880afe6963e7194 Mon Sep 17 00:00:00 2001
From: Patrick Gerard <75830792+grumpyp@users.noreply.github.com>
Date: Thu, 26 Sep 2024 19:31:08 +0200
Subject: [PATCH] Add warning for unexpected model output in batched prediction (#300)

* fix: add warning unexpected output from HF model (closes #294)

* add: warning if batched loop return string

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* move output check to unbatch_no_stream

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* restore format loops.py

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: lint E501 Line too long

* Update tests/test_batch.py

Co-authored-by: Aniket Maurya

* Update src/litserve/api.py

Co-authored-by: Aniket Maurya

* Update test to match new warning string

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Delete whitespace in warning string test_batch

* Update test_batch.py

* Update test_batch.py

* Update warning copy

* update test

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix test

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Aniket Maurya
---
 src/litserve/api.py |  7 +++++++
 tests/test_batch.py | 16 ++++++++++++++++
 2 files changed, 23 insertions(+)

diff --git a/src/litserve/api.py b/src/litserve/api.py
index 18c9bb7d..6c6699c1 100644
--- a/src/litserve/api.py
+++ b/src/litserve/api.py
@@ -63,6 +63,13 @@ def predict(self, x, **kwargs):
         pass
 
     def _unbatch_no_stream(self, output):
+        if isinstance(output, str):
+            warnings.warn(
+                "The 'predict' method returned a string instead of a list of predictions. "
+                "When batching is enabled, 'predict' must return a list to handle multiple inputs correctly. "
+                "Please update the 'predict' method to return a list of predictions to avoid unexpected behavior.",
+                UserWarning,
+            )
         return list(output)
 
     def _unbatch_stream(self, output_stream):
diff --git a/tests/test_batch.py b/tests/test_batch.py
index 5a1d1a1f..8b0865ff 100644
--- a/tests/test_batch.py
+++ b/tests/test_batch.py
@@ -148,6 +148,22 @@ def test_max_batch_size_warning():
         LitServer(SimpleTorchAPI(), accelerator="cpu", devices=1, timeout=2)
 
 
+def test_batch_predict_string_warning():
+    api = ls.test_examples.SimpleBatchedAPI()
+    api._sanitize(2, None)
+    api.predict = MagicMock(return_value="This is a string")
+
+    mock_input = torch.tensor([[1.0], [2.0]])
+
+    with pytest.warns(
+        UserWarning,
+        match="When batching is enabled, 'predict' must return a list to handle multiple inputs correctly.",
+    ):
+        # Simulate the behavior in run_batched_loop
+        y = api.predict(mock_input)
+        api.unbatch(y)
+
+
 class FakeResponseQueue:
     def put(self, *args):
         raise Exception("Exit loop")
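
For reference, a minimal sketch (not part of the patch) of the contract this warning enforces: when batching is enabled, `predict` must return a list with one prediction per input, so that `_unbatch_no_stream`'s `list(output)` fans results back out to the right clients. The `MyBatchedAPI` class and its stand-in model below are illustrative assumptions, not code from the litserve repository.

# Hypothetical example of a batching-correct LitAPI. With max_batch_size > 1
# and no custom `batch` method, litserve hands `predict` the collated inputs;
# returning a list here avoids the UserWarning added by this patch, whereas
# returning a single string would trigger it.
import litserve as ls

class MyBatchedAPI(ls.LitAPI):
    def setup(self, device):
        # Stand-in model: doubles each input value.
        self.model = lambda x: x * 2

    def decode_request(self, request):
        return request["input"]

    def predict(self, x):
        # x holds a batch of decoded inputs; return one prediction per input.
        return [self.model(item) for item in x]

    def encode_response(self, output):
        return {"output": output}

if __name__ == "__main__":
    server = ls.LitServer(MyBatchedAPI(), max_batch_size=4, batch_timeout=0.05)
    server.run(port=8000)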