Skip to content

Commit

Permalink
Handle invalid body
Browse files Browse the repository at this point in the history
  • Loading branch information
provinzkraut committed Feb 13, 2024
1 parent b266f5a commit d8f7989
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 15 deletions.
47 changes: 39 additions & 8 deletions litestar/data_extractors.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Coroutine, Literal, TypedDict, cast
import inspect
from typing import TYPE_CHECKING, Any, Callable, Coroutine, Iterable, Literal, TypedDict, cast

from litestar._parsers import parse_cookie_string
from litestar.connection.request import Request
Expand Down Expand Up @@ -70,6 +71,7 @@ class ConnectionDataExtractor:
"parse_query",
"obfuscate_headers",
"obfuscate_cookies",
"skip_parse_malformed_body",
)

def __init__(
Expand All @@ -88,6 +90,7 @@ def __init__(
obfuscate_headers: set[str] | None = None,
parse_body: bool = False,
parse_query: bool = False,
skip_parse_malformed_body: bool = False,
) -> None:
"""Initialize ``ConnectionDataExtractor``
Expand All @@ -106,9 +109,11 @@ def __init__(
obfuscate_cookies: cookie keys to obfuscate. Obfuscated values are replaced with '*****'.
parse_body: Whether to parse the body value or return the raw byte string, (for requests only).
parse_query: Whether to parse query parameters or return the raw byte string.
skip_parse_malformed_body: Whether to skip parsing the body if it is malformed
"""
self.parse_body = parse_body
self.parse_query = parse_query
self.skip_parse_malformed_body = skip_parse_malformed_body
self.obfuscate_headers = {h.lower() for h in (obfuscate_headers or set())}
self.obfuscate_cookies = {c.lower() for c in (obfuscate_cookies or set())}
self.connection_extractors: dict[str, Callable[[ASGIConnection[Any, Any, Any, Any]], Any]] = {}
Expand Down Expand Up @@ -153,6 +158,25 @@ def __call__(self, connection: ASGIConnection[Any, Any, Any, Any]) -> ExtractedR
)
return cast("ExtractedRequestData", {key: extractor(connection) for key, extractor in extractors.items()})

async def extract(
self, connection: ASGIConnection[Any, Any, Any, Any], fields: Iterable[str]
) -> ExtractedRequestData:
extractors = (
{**self.connection_extractors, **self.request_extractors} # type: ignore
if isinstance(connection, Request)
else self.connection_extractors
)
data = {}
for key, extractor in extractors.items():
if key not in fields:
continue
if inspect.iscoroutinefunction(extractor):
value = await extractor(connection)
else:
value = extractor(connection)
data[key] = value
return data

@staticmethod
def extract_scheme(connection: ASGIConnection[Any, Any, Any, Any]) -> str:
"""Extract the scheme from an ``ASGIConnection``
Expand Down Expand Up @@ -272,13 +296,20 @@ async def extract_body(self, request: Request[Any, Any, Any]) -> Any:
return None
if not self.parse_body:
return await request.body()
request_encoding_type = request.content_type[0]
if request_encoding_type == RequestEncodingType.JSON:
return await request.json()
form_data = await request.form()
if request_encoding_type == RequestEncodingType.URL_ENCODED:
return dict(form_data)
return {key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()}
try:
request_encoding_type = request.content_type[0]
if request_encoding_type == RequestEncodingType.JSON:
return await request.json()
form_data = await request.form()
if request_encoding_type == RequestEncodingType.URL_ENCODED:
return dict(form_data)
return {
key: repr(value) if isinstance(value, UploadFile) else value for key, value in form_data.multi_items()
}
except Exception as exc:
if self.skip_parse_on_exception:
return await request.body()
raise exc


class ExtractedResponseData(TypedDict, total=False):
Expand Down
11 changes: 5 additions & 6 deletions litestar/middleware/logging.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from __future__ import annotations

from dataclasses import dataclass, field
from inspect import isawaitable
from typing import TYPE_CHECKING, Any, Iterable

from litestar.constants import (
Expand Down Expand Up @@ -81,6 +80,7 @@ def __init__(self, app: ASGIApp, config: LoggingMiddlewareConfig) -> None:
obfuscate_headers=self.config.request_headers_to_obfuscate,
parse_body=self.is_struct_logger,
parse_query=self.is_struct_logger,
skip_parse_malformed_body=True,
)
self.response_extractor = ResponseDataExtractor(
extract_body="body" in self.config.response_log_fields,
Expand Down Expand Up @@ -172,12 +172,11 @@ async def extract_request_data(self, request: Request) -> dict[str, Any]:

data: dict[str, Any] = {"message": self.config.request_log_message}
serializer = get_serializer_from_scope(request.scope)
extracted_data = self.request_extractor(connection=request)

extracted_data = await self.request_extractor.extract(connection=request, fields=self.config.request_log_fields)

for key in self.config.request_log_fields:
value = extracted_data.get(key)
if isawaitable(value):
value = await value
data[key] = self._serialize_value(serializer, value)
data[key] = self._serialize_value(serializer, extracted_data.get(key))
return data

def extract_response_data(self, scope: Scope) -> dict[str, Any]:
Expand Down
16 changes: 15 additions & 1 deletion tests/unit/test_middleware/test_logging_middleware.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from logging import INFO
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Any, Dict

import pytest
from structlog.testing import capture_logs
Expand Down Expand Up @@ -286,3 +286,17 @@ async def get_session() -> None:
assert response.status_code == HTTP_200_OK
assert "session" in client.cookies
assert client.cookies["session"] == session_id


def test_structlog_invalid_request_body_handled():
# https://github.com/litestar-org/litestar/issues/3063
@post("/")
async def hello_world(data: dict[str, Any]) -> dict[str, Any]:
return data

with create_test_client(
route_handlers=[hello_world],
logging_config=StructLoggingConfig(log_exceptions="always"),
middleware=[LoggingMiddlewareConfig().middleware],
) as client:
assert client.post("/", headers={"Content-Type": "application/json"}, content=b'{"a": "b",}').status_code == 400

0 comments on commit d8f7989

Please sign in to comment.