Skip to content

Commit

Permalink
fix(3082): Exception handlers extract details from non-litestar excep…
Browse files Browse the repository at this point in the history
…tions (#3106)

Fix exception middleware
  • Loading branch information
provinzkraut authored Feb 13, 2024
1 parent b266f5a commit 67ab44d
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 37 deletions.
9 changes: 9 additions & 0 deletions litestar/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,15 @@ def __init__(
if self.pdb_on_exception:
warn_pdb_on_exception()

try:
from starlette.exceptions import HTTPException as StarletteHTTPException

from litestar.middleware.exceptions.middleware import _starlette_exception_handler

config.exception_handlers.setdefault(StarletteHTTPException, _starlette_exception_handler)
except ImportError:
pass

super().__init__(
after_request=config.after_request,
after_response=config.after_response,
Expand Down
61 changes: 44 additions & 17 deletions litestar/middleware/exceptions/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from litestar.datastructures import Headers
from litestar.enums import MediaType, ScopeType
from litestar.exceptions import WebSocketException
from litestar.exceptions import HTTPException, LitestarException, WebSocketException
from litestar.middleware.cors import CORSMiddleware
from litestar.middleware.exceptions._debug_response import _get_type_encoders_for_request, create_debug_response
from litestar.serialization import encode_json
Expand All @@ -20,6 +20,8 @@


if TYPE_CHECKING:
from starlette.exceptions import HTTPException as StarletteHTTPException

from litestar import Response
from litestar.app import Litestar
from litestar.connection import Request
Expand Down Expand Up @@ -58,15 +60,16 @@ def get_exception_handler(exception_handlers: ExceptionHandlersMap, exc: Excepti
if not exception_handlers:
return None

status_code: int | None = getattr(exc, "status_code", None)
if status_code and (exception_handler := exception_handlers.get(status_code)):
return exception_handler
default_handler: ExceptionHandler | None = None
if isinstance(exc, HTTPException):
if exception_handler := exception_handlers.get(exc.status_code):
return exception_handler
else:
default_handler = exception_handlers.get(HTTP_500_INTERNAL_SERVER_ERROR)

return next(
(exception_handlers[cast("Type[Exception]", cls)] for cls in getmro(type(exc)) if cls in exception_handlers),
exception_handlers[HTTP_500_INTERNAL_SERVER_ERROR]
if not hasattr(exc, "status_code") and HTTP_500_INTERNAL_SERVER_ERROR in exception_handlers
else None,
default_handler,
)


Expand Down Expand Up @@ -107,6 +110,17 @@ def to_response(self, request: Request | None = None) -> Response:
)


def _starlette_exception_handler(request: Request[Any, Any, Any], exc: StarletteHTTPException) -> Response:
return create_exception_response(
request=request,
exc=HTTPException(
detail=exc.detail,
status_code=exc.status_code,
headers=exc.headers,
),
)


def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -> Response:
"""Construct a response from an exception.
Expand All @@ -122,11 +136,23 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -
Returns:
Response: HTTP response constructed from exception details.
"""
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
if status_code == HTTP_500_INTERNAL_SERVER_ERROR:
detail = "Internal Server Error"
headers: dict[str, Any] | None
extra: dict[str, Any] | list | None

if isinstance(exc, HTTPException):
status_code = exc.status_code
headers = exc.headers
extra = exc.extra
else:
detail = getattr(exc, "detail", repr(exc))
status_code = HTTP_500_INTERNAL_SERVER_ERROR
headers = None
extra = None

detail = (
exc.detail
if isinstance(exc, LitestarException) and status_code != HTTP_500_INTERNAL_SERVER_ERROR
else "Internal Server Error"
)

try:
media_type = request.route_handler.media_type
Expand All @@ -136,8 +162,8 @@ def create_exception_response(request: Request[Any, Any, Any], exc: Exception) -
content = ExceptionResponseContent(
status_code=status_code,
detail=detail,
headers=getattr(exc, "headers", None),
extra=getattr(exc, "extra", None),
headers=headers,
extra=extra,
media_type=media_type,
)
return content.to_response(request=request)
Expand Down Expand Up @@ -246,12 +272,13 @@ async def handle_websocket_exception(send: Send, exc: Exception) -> None:
Returns:
None.
"""
code = 4000 + HTTP_500_INTERNAL_SERVER_ERROR
reason = "Internal Server Error"
if isinstance(exc, WebSocketException):
code = exc.code
reason = exc.detail
else:
code = 4000 + getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
reason = getattr(exc, "detail", repr(exc))
elif isinstance(exc, LitestarException):
reason = exc.detail

event: WebSocketCloseEvent = {"type": "websocket.close", "code": code, "reason": reason}
await send(event)
Expand All @@ -266,7 +293,7 @@ def default_http_exception_handler(self, request: Request, exc: Exception) -> Re
Returns:
An HTTP response.
"""
status_code = getattr(exc, "status_code", HTTP_500_INTERNAL_SERVER_ERROR)
status_code = exc.status_code if isinstance(exc, HTTPException) else HTTP_500_INTERNAL_SERVER_ERROR
if status_code == HTTP_500_INTERNAL_SERVER_ERROR and self._get_debug_scope(request.scope):
return create_debug_response(request=request, exc=exc)
return create_exception_response(request=request, exc=exc)
Expand Down
10 changes: 1 addition & 9 deletions tests/unit/test_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,15 +164,7 @@ def test_app_config_object_used(app_config_object: AppConfig, monkeypatch: pytes
# have been accessed during app instantiation.
property_mocks: List[Tuple[str, Mock]] = []
for field in fields(AppConfig):
if field.name == "response_cache_config":
property_mock = PropertyMock(return_value=ResponseCacheConfig())
if field.name in ["event_emitter_backend", "response_cache_config"]:
property_mock = PropertyMock(return_value=Mock())
else:
# default iterable return value allows the mock properties that need to be iterated over in
# `Litestar.__init__()` to not blow up, for other properties it shouldn't matter what the value is for the
# sake of this test.
property_mock = PropertyMock(return_value=[])
property_mock = PropertyMock()
property_mocks.append((field.name, property_mock))
monkeypatch.setattr(type(app_config_object), field.name, property_mock, raising=False)

Expand Down
43 changes: 34 additions & 9 deletions tests/unit/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,15 +111,13 @@ def test_create_exception_response_utility_litestar_http_exception(media_type: M

@pytest.mark.parametrize("media_type", [MediaType.JSON, MediaType.TEXT])
def test_create_exception_response_utility_starlette_http_exception(media_type: MediaType) -> None:
exc = StarletteHTTPException(detail="starlette http exception", status_code=HTTP_400_BAD_REQUEST)
request = RequestFactory(handler_kwargs={"media_type": media_type}).get()
response = create_exception_response(request=request, exc=exc)
assert response.status_code == HTTP_400_BAD_REQUEST
assert response.media_type == media_type
if media_type == MediaType.JSON:
assert response.content == {"status_code": 400, "detail": "starlette http exception"}
else:
assert response.content == b'{"status_code":400,"detail":"starlette http exception"}'
@get("/", media_type=media_type)
def handler() -> str:
raise StarletteHTTPException(status_code=400)

with create_test_client(handler) as client:
response = client.get("/", headers={"Accept": media_type})
assert response.json() == {"status_code": 400, "detail": "Bad Request"}


@pytest.mark.parametrize("media_type", [MediaType.JSON, MediaType.TEXT])
Expand Down Expand Up @@ -171,3 +169,30 @@ def handler() -> None:
assert response.json().get("details").startswith("Traceback (most recent call last")
else:
assert response.text.startswith("Traceback (most recent call last")


def test_non_litestar_exception_with_status_code_is_500() -> None:
# https://github.com/litestar-org/litestar/issues/3082
class MyException(Exception):
status_code: int = 400

@get("/")
def handler() -> None:
raise MyException("hello")

with create_test_client([handler]) as client:
assert client.get("/").status_code == 500


def test_non_litestar_exception_with_detail_is_not_included() -> None:
# https://github.com/litestar-org/litestar/issues/3082
class MyException(Exception):
status_code: int = 400
detail: str = "hello"

@get("/")
def handler() -> None:
raise MyException()

with create_test_client([handler], debug=False) as client:
assert client.get("/", headers={"Accept": MediaType.JSON}).json().get("detail") == "Internal Server Error"
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from litestar.logging.config import LoggingConfig, StructLoggingConfig
from litestar.middleware.exceptions import ExceptionHandlerMiddleware
from litestar.middleware.exceptions._debug_response import get_symbol_name
from litestar.middleware.exceptions.middleware import get_exception_handler
from litestar.middleware.exceptions.middleware import _starlette_exception_handler, get_exception_handler
from litestar.status_codes import HTTP_400_BAD_REQUEST, HTTP_500_INTERNAL_SERVER_ERROR
from litestar.testing import TestClient, create_test_client
from litestar.types import ExceptionHandlersMap
Expand Down Expand Up @@ -124,7 +124,8 @@ def exception_handler(request: Request, exc: Exception) -> Response:

app = Litestar(route_handlers=[handler], exception_handlers={Exception: exception_handler}, openapi_config=None)
assert app.asgi_router.root_route_map_node.children["/"].asgi_handlers["GET"][0].exception_handlers == { # type: ignore
Exception: exception_handler
Exception: exception_handler,
StarletteHTTPException: _starlette_exception_handler,
}


Expand Down

0 comments on commit 67ab44d

Please sign in to comment.