From d8f7989b137a63998add59c3fe9a0fa6869d05e0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Janek=20Nouvertn=C3=A9?= <25355197+provinzkraut@users.noreply.github.com> Date: Tue, 13 Feb 2024 20:03:34 +0100 Subject: [PATCH] Handle invalid body --- litestar/data_extractors.py | 47 +++++++++++++++---- litestar/middleware/logging.py | 11 ++--- .../test_logging_middleware.py | 16 ++++++- 3 files changed, 59 insertions(+), 15 deletions(-) diff --git a/litestar/data_extractors.py b/litestar/data_extractors.py index 6d4b182133..c291875c37 100644 --- a/litestar/data_extractors.py +++ b/litestar/data_extractors.py @@ -1,6 +1,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypedDict, cast +import inspect +from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast from litestar._parsers import parse_cookie_string from litestar.connection.request import Request @@ -70,6 +71,7 @@ class ConnectionDataExtractor: "parse_query", "obfuscate_headers", "obfuscate_cookies", + "skip_parse_malformed_body", ) def __init__( @@ -88,6 +90,7 @@ def __init__( obfuscate_headers: set[str] | None = None, parse_body: bool = False, parse_query: bool = False, + skip_parse_malformed_body: bool = False, ) -> None: """Initialize ``ConnectionDataExtractor`` @@ -106,9 +109,11 @@ def __init__( obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'. parse_body: Whether to parse the body value or return the raw byte string, (for requests only). parse_query: Whether to parse query parameters or return the raw byte string. + skip_parse_malformed_body: Whether to skip parsing the body if it is malformed """ self.parse_body = parse_body self.parse_query = parse_query + self.skip_parse_malformed_body = skip_parse_malformed_body self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())} self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())} self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {} @@ -153,6 +158,25 @@ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedR ) return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()}) + async def extract( + self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str] + ) -> ExtractedRequestData: + extractors = ( + {**self.connection_extractors, **self.request_extractors} # type: ignore + if isinstance(connection, Request) + else self.connection_extractors + ) + data = {} + for key, extractor in extractors.items(): + if key not in fields: + continue + if inspect.iscoroutinefunction(extractor): + value = await extractor(connection) + else: + value = extractor(connection) + data[key] = value + return data + @staticmethod def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str: """Extract the scheme from an ``ASGIConnection`` @@ -272,13 +296,20 @@ async def extract_body(self, request: Request[Any, Any, Any]) -> Any: return None if not self.parse_body: return await request.body() - request_encoding_type = request.content_type[0] - if request_encoding_type == RequestEncodingType.JSON: - return await request.json() - form_data = await request.form() - if request_encoding_type == RequestEncodingType.URL_ENCODED: - return dict(form_data) - return {key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()} + try: + request_encoding_type = request.content_type[0] + if request_encoding_type == RequestEncodingType.JSON: + return await request.json() + form_data = await request.form() + if request_encoding_type == RequestEncodingType.URL_ENCODED: + return dict(form_data) + return { + key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items() + } + except Exception as exc: + if self.skip_parse_on_exception: + return await request.body() + raise exc class ExtractedResponseData(TypedDict, total=False): diff --git a/litestar/middleware/logging.py b/litestar/middleware/logging.py index d52963b60e..dc827e303e 100644 --- a/litestar/middleware/logging.py +++ b/litestar/middleware/logging.py @@ -1,7 +1,6 @@ from __future__ import annotations from dataclasses import dataclass, field -from inspect import isawaitable from typing import TYPE_CHECKING, Any, Iterable from litestar.constants import ( @@ -81,6 +80,7 @@ def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None: obfuscate_headers=self.config.request_headers_to_obfuscate, parse_body=self.is_struct_logger, parse_query=self.is_struct_logger, + skip_parse_malformed_body=True, ) self.response_extractor = ResponseDataExtractor( extract_body="body" in self.config.response_log_fields, @@ -172,12 +172,11 @@ async def extract_request_data(self, request: Request) -> dict[str, Any]: data: dict[str, Any] = {"message": self.config.request_log_message} serializer = get_serializer_from_scope(request.scope) - extracted_data = self.request_extractor(connection=request) + + extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields) + for key in self.config.request_log_fields: - value = extracted_data.get(key) - if isawaitable(value): - value = await value - data[key] = self._serialize_value(serializer, value) + data[key] = self._serialize_value(serializer, extracted_data.get(key)) return data def extract_response_data(self, scope: Scope) -> dict[str, Any]: diff --git a/tests/unit/test_middleware/test_logging_middleware.py b/tests/unit/test_middleware/test_logging_middleware.py index 98b1bee3e1..7aaaa85076 100644 --- a/tests/unit/test_middleware/test_logging_middleware.py +++ b/tests/unit/test_middleware/test_logging_middleware.py @@ -1,5 +1,5 @@ from logging import INFO -from typing import TYPE_CHECKING, Dict +from typing import TYPE_CHECKING, Any, Dict import pytest from structlog.testing import capture_logs @@ -286,3 +286,17 @@ async def get_session() -> None: assert response.status_code == HTTP_200_OK assert "session" in client.cookies assert client.cookies["session"] == session_id + + +def test_structlog_invalid_request_body_handled(): + # https://github.com/litestar-org/litestar/issues/3063 + @post("/") + async def hello_world(data: dict[str, Any]) -> dict[str, Any]: + return data + + with create_test_client( + route_handlers=[hello_world], + logging_config=StructLoggingConfig(log_exceptions="always"), + middleware=[LoggingMiddlewareConfig().middleware], + ) as client: + assert client.post("/", headers={"Content-Type": "application/json"}, content=b'{"a": "b",}').status_code == 400