From 67ab44dc9379c323fd2a3f7f223091a83f9eb0a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= Date: Tue, 13 Feb 2024 20:08:46 +0100 Subject: [PATCH] fix(3082): Exception handlers extract details from non-litestar exceptions (#3106) Fix exception middleware --- litestar/app.py | 9 +++ litestar/middleware/exceptions/middleware.py | 61 +++++++++++++------ tests/unit/test_app.py | 10 +-- tests/unit/test_exceptions.py | 43 ++++++++++--- .../test_exception_handler_middleware.py | 5 +- 5 files changed, 91 insertions(+), 37 deletions(-) diff --git a/litestar/app.py b/litestar/app.py index 48f3c33283..1ece9f016b 100644 --- a/litestar/app.py +++ b/litestar/app.py @@ -421,6 +421,15 @@ def __init__( if self.pdb_on_exception: warn_pdb_on_exception() + try: + from starlette.exceptions import HTTPException as StarletteHTTPException + + from litestar.middleware.exceptions.middleware import _starlette_exception_handler + + config.exception_handlers.setdefault(StarletteHTTPException, _starlette_exception_handler) + except ImportError: + pass + super().__init__( after_request=config.after_request, after_response=config.after_response, diff --git a/litestar/middleware/exceptions/middleware.py b/litestar/middleware/exceptions/middleware.py index 4828cfd8d5..f3ff1572b6 100644 --- a/litestar/middleware/exceptions/middleware.py +++ b/litestar/middleware/exceptions/middleware.py @@ -9,7 +9,7 @@ from litestar.datastructures import Headers from litestar.enums import MediaType, ScopeType -from litestar.exceptions import WebSocketException +from litestar.exceptions import HTTPException, LitestarException, WebSocketException from litestar.middleware.cors import CORSMiddleware from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response from litestar.serialization import encode_json @@ -20,6 +20,8 @@ if TYPE_CHECKING: + from starlette.exceptions import HTTPException as StarletteHTTPException + from litestar import Response from litestar.app import Litestar from litestar.connection import Request @@ -58,15 +60,16 @@ def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Excepti if not exception_handlers: return None - status_code: int | None = getattr(exc, "status_code", None) - if status_code and (exception_handler := exception_handlers.get(status_code)): - return exception_handler + default_handler: ExceptionHandler | None = None + if isinstance(exc, HTTPException): + if exception_handler := exception_handlers.get(exc.status_code): + return exception_handler + else: + default_handler = exception_handlers.get(HTTP_500_INTERNAL_SERVER_ERROR) return next( (exception_handlers[cast("Type[Exception]", cls)] for cls in getmro(type(exc)) if cls in exception_handlers), - exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR] - if not hasattr(exc, "status_code") and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers - else None, + default_handler, ) @@ -107,6 +110,17 @@ def to_response(self, request: Request | None = None) -> Response: ) +def _starlette_exception_handler(request: Request[Any, Any, Any], exc: StarletteHTTPException) -> Response: + return create_exception_response( + request=request, + exc=HTTPException( + detail=exc.detail, + status_code=exc.status_code, + headers=exc.headers, + ), + ) + + def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -> Response: """Construct a response from an exception. @@ -122,11 +136,23 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) - Returns: Response: HTTP response constructed from exception details. """ - status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) - if status_code == HTTP_500_INTERNAL_SERVER_ERROR: - detail = "Internal Server Error" + headers: dict[str, Any] | None + extra: dict[str, Any] | list | None + + if isinstance(exc, HTTPException): + status_code = exc.status_code + headers = exc.headers + extra = exc.extra else: - detail = getattr(exc, "detail", repr(exc)) + status_code = HTTP_500_INTERNAL_SERVER_ERROR + headers = None + extra = None + + detail = ( + exc.detail + if isinstance(exc, LitestarException) and status_code != HTTP_500_INTERNAL_SERVER_ERROR + else "Internal Server Error" + ) try: media_type = request.route_handler.media_type @@ -136,8 +162,8 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) - content = ExceptionResponseContent( status_code=status_code, detail=detail, - headers=getattr(exc, "headers", None), - extra=getattr(exc, "extra", None), + headers=headers, + extra=extra, media_type=media_type, ) return content.to_response(request=request) @@ -246,12 +272,13 @@ async def handle_websocket_exception(send: Send, exc: Exception) -> None: Returns: None. """ + code = 4000 + HTTP_500_INTERNAL_SERVER_ERROR + reason = "Internal Server Error" if isinstance(exc, WebSocketException): code = exc.code reason = exc.detail - else: - code = 4000 + getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) - reason = getattr(exc, "detail", repr(exc)) + elif isinstance(exc, LitestarException): + reason = exc.detail event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason} await send(event) @@ -266,7 +293,7 @@ def default_http_exception_handler(self, request: Request, exc: Exception) -> Re Returns: An HTTP response. """ - status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR) + status_code = exc.status_code if isinstance(exc, HTTPException) else HTTP_500_INTERNAL_SERVER_ERROR if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self._get_debug_scope(request.scope): return create_debug_response(request=request, exc=exc) return create_exception_response(request=request, exc=exc) diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py index 282b4a715c..f64bb18ff6 100644 --- a/tests/unit/test_app.py +++ b/tests/unit/test_app.py @@ -164,15 +164,7 @@ def test_app_config_object_used(app_config_object: AppConfig, monkeypatch: pytes # have been accessed during app instantiation. property_mocks: List[Tuple[str, Mock]] = [] for field in fields(AppConfig): - if field.name == "response_cache_config": - property_mock = PropertyMock(return_value=ResponseCacheConfig()) - if field.name in ["event_emitter_backend", "response_cache_config"]: - property_mock = PropertyMock(return_value=Mock()) - else: - # default iterable return value allows the mock properties that need to be iterated over in - # `Litestar.__init__()` to not blow up, for other properties it shouldn't matter what the value is for the - # sake of this test. - property_mock = PropertyMock(return_value=[]) + property_mock = PropertyMock() property_mocks.append((field.name, property_mock)) monkeypatch.setattr(type(app_config_object), field.name, property_mock, raising=False) diff --git a/tests/unit/test_exceptions.py b/tests/unit/test_exceptions.py index e1b5174674..fdfb064258 100644 --- a/tests/unit/test_exceptions.py +++ b/tests/unit/test_exceptions.py @@ -111,15 +111,13 @@ def test_create_exception_response_utility_litestar_http_exception(media_type: M @pytest.mark.parametrize("media_type", [MediaType.JSON, MediaType.TEXT]) def test_create_exception_response_utility_starlette_http_exception(media_type: MediaType) -> None: - exc = StarletteHTTPException(detail="starlette http exception", status_code=HTTP_400_BAD_REQUEST) - request = RequestFactory(handler_kwargs={"media_type": media_type}).get() - response = create_exception_response(request=request, exc=exc) - assert response.status_code == HTTP_400_BAD_REQUEST - assert response.media_type == media_type - if media_type == MediaType.JSON: - assert response.content == {"status_code": 400, "detail": "starlette http exception"} - else: - assert response.content == b'{"status_code":400,"detail":"starlette http exception"}' + @get("/", media_type=media_type) + def handler() -> str: + raise StarletteHTTPException(status_code=400) + + with create_test_client(handler) as client: + response = client.get("/", headers={"Accept": media_type}) + assert response.json() == {"status_code": 400, "detail": "Bad Request"} @pytest.mark.parametrize("media_type", [MediaType.JSON, MediaType.TEXT]) @@ -171,3 +169,30 @@ def handler() -> None: assert response.json().get("details").startswith("Traceback (most recent call last") else: assert response.text.startswith("Traceback (most recent call last") + + +def test_non_litestar_exception_with_status_code_is_500() -> None: + # https://github.com/litestar-org/litestar/issues/3082 + class MyException(Exception): + status_code: int = 400 + + @get("/") + def handler() -> None: + raise MyException("hello") + + with create_test_client([handler]) as client: + assert client.get("/").status_code == 500 + + +def test_non_litestar_exception_with_detail_is_not_included() -> None: + # https://github.com/litestar-org/litestar/issues/3082 + class MyException(Exception): + status_code: int = 400 + detail: str = "hello" + + @get("/") + def handler() -> None: + raise MyException() + + with create_test_client([handler], debug=False) as client: + assert client.get("/", headers={"Accept": MediaType.JSON}).json().get("detail") == "Internal Server Error" diff --git a/tests/unit/test_middleware/test_exception_handler_middleware.py b/tests/unit/test_middleware/test_exception_handler_middleware.py index 85cbe4c84b..2c0b1334d8 100644 --- a/tests/unit/test_middleware/test_exception_handler_middleware.py +++ b/tests/unit/test_middleware/test_exception_handler_middleware.py @@ -12,7 +12,7 @@ from litestar.logging.config import LoggingConfig, StructLoggingConfig from litestar.middleware.exceptions import ExceptionHandlerMiddleware from litestar.middleware.exceptions._debug_response import get_symbol_name -from litestar.middleware.exceptions.middleware import get_exception_handler +from litestar.middleware.exceptions.middleware import _starlette_exception_handler, get_exception_handler from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR from litestar.testing import TestClient, create_test_client from litestar.types import ExceptionHandlersMap @@ -124,7 +124,8 @@ def exception_handler(request: Request, exc: Exception) -> Response: app = Litestar(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None) assert app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0].exception_handlers == { # type: ignore - Exception: exception_handler + Exception: exception_handler, + StarletteHTTPException: _starlette_exception_handler, }