diff --git a/tests/test_batch.py b/tests/test_batch.py index 7aa243c0..28fed15a 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -151,19 +151,25 @@ def test_max_batch_size_warning(): def test_check_predict_output_warnings(): api = SimpleBatchLitAPI() api.predict = MagicMock(return_value="This is a string") - + mock_input = torch.tensor([[1.0], [2.0]]) - - with pytest.warns(UserWarning, match="The 'predict' method returned a string instead of a list of predictions. " - "When batching is enabled, 'predict' should return a list of predictions. " - "To avoid unexpected behavior, ensure 'predict' returns a list of predictions, or implement 'LitAPI.unbatch' correctly."): + + with pytest.warns( + UserWarning, + match="The 'predict' method returned a string instead of a list of predictions. " + "When batching is enabled, 'predict' should return a list of predictions. " + "To avoid unexpected behavior, ensure 'predict' returns a list of predictions, or implement 'LitAPI.unbatch' correctly.", + ): # Simulate the behavior in run_batched_loop y = api.predict(mock_input) if isinstance(y, str): import warnings - warnings.warn("The 'predict' method returned a string instead of a list of predictions. " - "When batching is enabled, 'predict' should return a list of predictions. " - "To avoid unexpected behavior, ensure 'predict' returns a list of predictions, or implement 'LitAPI.unbatch' correctly.") + + warnings.warn( + "The 'predict' method returned a string instead of a list of predictions. " + "When batching is enabled, 'predict' should return a list of predictions. " + "To avoid unexpected behavior, ensure 'predict' returns a list of predictions, or implement 'LitAPI.unbatch' correctly." + ) class FakeResponseQueue: