Skip to content

Commit

Permalink
Unify response validation (0b01001001#339)
Browse files Browse the repository at this point in the history
* Remove special list response handling

* Unify response validation across plugins

* @kemingy's suggestions
  • Loading branch information
jean-edouard-boulanger authored Sep 21, 2023
1 parent 3cac8f3 commit 63e62c5
Show file tree
Hide file tree
Showing 22 changed files with 519 additions and 257 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ format:
lint:
isort --check --diff --project=spectree ${SOURCE_FILES}
black --check --diff ${SOURCE_FILES}
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=21
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=15
mypy --install-types --non-interactive ${MYPY_SOURCE_FILES}

.PHONY: test doc
3 changes: 3 additions & 0 deletions spectree/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,3 +41,6 @@ class FunctionDecorator(Protocol):
deprecated: bool
path_parameter_descriptions: Optional[Mapping[str, str]]
_decorator: Any


JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]]
67 changes: 66 additions & 1 deletion spectree/plugins/base.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import logging
from dataclasses import dataclass
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -8,9 +9,11 @@
NamedTuple,
Optional,
TypeVar,
Union,
)

from .._types import ModelType
from .._pydantic import ValidationError, is_root_model, serialize_model_instance
from .._types import JsonType, ModelType, OptionalModelType
from ..config import Configuration
from ..response import Response

Expand Down Expand Up @@ -122,3 +125,65 @@ def get_func_operation_id(self, func: Callable, path: str, method: str):
if not operation_id:
operation_id = f"{method.lower()}_{path.replace('/', '_')}"
return operation_id


@dataclass(frozen=True)
class RawResponsePayload:
payload: Union[JsonType, bytes]


@dataclass(frozen=True)
class ResponseValidationResult:
payload: Any


def validate_response(
skip_validation: bool,
validation_model: OptionalModelType,
response_payload: Any,
):
"""Validate a given `response_payload` against a `validation_model`.
:param skip_validation: When set to true, validation is not carried out
and the input `response_payload` is returned as-is. This is equivalent
to not providing a `validation_model`.
:param validation_model: Pydantic model used to validate the provided
`response_payload`.
:param response_payload: Validated response payload. A `RawResponsePayload`
should be provided when the plugin view function returned an already
JSON-serialized response payload.
"""
final_response_payload = None
if isinstance(response_payload, RawResponsePayload):
final_response_payload = response_payload.payload
elif skip_validation or validation_model is None:
final_response_payload = response_payload

if not skip_validation and validation_model and not final_response_payload:
if isinstance(response_payload, validation_model):
skip_validation = True
final_response_payload = serialize_model_instance(response_payload)
elif is_root_model(validation_model) and not isinstance(
response_payload, validation_model
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
response_payload = validation_model(__root__=response_payload)
except ValidationError:
raise
else:
skip_validation = True
final_response_payload = serialize_model_instance(response_payload)
else:
final_response_payload = response_payload

if validation_model and not skip_validation:
validator = (
validation_model.parse_raw
if isinstance(final_response_payload, bytes)
else validation_model.parse_obj
)
validator(final_response_payload)

return ResponseValidationResult(payload=final_response_payload)
108 changes: 32 additions & 76 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,12 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon import Response as FalconResponse
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN

from .._pydantic import (
BaseModel,
ValidationError,
is_root_model,
serialize_model_instance,
)
from .._pydantic import ValidationError
from .._types import ModelType
from ..response import Response
from .base import BasePlugin
from .base import BasePlugin, validate_response


class OpenAPI:
Expand Down Expand Up @@ -194,53 +188,6 @@ def request_validation(self, req, query, json, form, headers, cookies):
req_form = {x.name: x.stream.read() for x in req.get_media()}
req.context.form = form.parse_obj(req_form)

def response_validation(
self,
response_spec: Optional[Response],
falcon_response: FalconResponse,
skip_validation: bool,
) -> None:
if not skip_validation and response_spec and response_spec.has_model():
model = falcon_response.media
status = int(falcon_response.status[:3])
expect_model = response_spec.find_model(status)
if response_spec.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = response_spec.get_expected_list_item_type(
status
)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
falcon_response.media = [
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in model
]
elif (
expect_model
and is_root_model(expect_model)
and not isinstance(model, expect_model)
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
model = expect_model(__root__=model)
except ValidationError:
raise
else:
falcon_response.media = serialize_model_instance(model)
skip_validation = True
elif expect_model and isinstance(falcon_response.media, expect_model):
falcon_response.media = serialize_model_instance(model)
skip_validation = True

if self._data_set_manually(falcon_response):
skip_validation = True
if expect_model and not skip_validation:
expect_model.parse_obj(falcon_response.media)

def validate(
self,
func: Callable,
Expand Down Expand Up @@ -279,20 +226,25 @@ def validate(

func(*args, **kwargs)

try:
self.response_validation(
response_spec=resp,
falcon_response=_resp,
skip_validation=skip_validation,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
if not self._data_set_manually(_resp):
try:
status = int(_resp.status[:3])
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=_resp.media,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
else:
_resp.media = response_validation_result.payload

after(_req, _resp, resp_validation_error, _self)

def _data_set_manually(self, resp):
@staticmethod
def _data_set_manually(resp):
return (resp.text is not None or resp.data is not None) and resp.media is None

def bypass(self, func, method):
Expand Down Expand Up @@ -375,15 +327,19 @@ async def validate(

await func(*args, **kwargs)

try:
self.response_validation(
response_spec=resp,
falcon_response=_resp,
skip_validation=skip_validation,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
if not self._data_set_manually(_resp):
try:
status = int(_resp.status[:3])
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=_resp.media,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
else:
_resp.media = response_validation_result.payload

after(_req, _resp, resp_validation_error, _self)
84 changes: 27 additions & 57 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,11 @@
from flask import Blueprint, abort, current_app, jsonify, make_response, request
from werkzeug.routing import parse_converter_args

from .._pydantic import (
BaseModel,
ValidationError,
is_root_model,
serialize_model_instance,
)
from .._pydantic import ValidationError
from .._types import ModelType
from ..response import Response
from ..utils import get_multidict_items, werkzeug_parse_rule
from .base import BasePlugin, Context
from .base import BasePlugin, Context, RawResponsePayload, validate_response


class FlaskPlugin(BasePlugin):
Expand Down Expand Up @@ -206,63 +201,38 @@ def validate(

status = 200
rest = []
if resp and isinstance(result, tuple) and isinstance(result[0], BaseModel):
if resp and isinstance(result, tuple):
if len(result) > 1:
model, status, *rest = result
response_payload, status, *rest = result
else:
model = result[0]
response_payload = result[0]
else:
model = result
response_payload = result

if not skip_validation and resp and not isinstance(result, flask.Response):
expect_model = resp.find_model(status)
if resp.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = resp.get_expected_list_item_type(status)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
result = (
[
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in model
],
try:
response_validation_result = validate_response(
skip_validation=skip_validation,
validation_model=resp.find_model(status) if resp else None,
response_payload=(
RawResponsePayload(payload=response_payload.get_json())
if (
isinstance(response_payload, flask.Response)
and not skip_validation
)
else response_payload
),
)
except ValidationError:
response = make_response(
jsonify({"message": "response validation error"}), 500
)
else:
response = make_response(
(
response_validation_result.payload,
status,
*rest,
)
elif (
expect_model
and is_root_model(expect_model)
and not isinstance(model, expect_model)
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
model = expect_model(__root__=model)
except ValidationError as err:
resp_validation_error = err
else:
skip_validation = True
result = (serialize_model_instance(model), status, *rest)
elif expect_model and isinstance(model, expect_model):
skip_validation = True
result = (serialize_model_instance(model), status, *rest)

response = make_response(result)

if resp and resp.has_model() and not resp_validation_error:
model = resp.find_model(response.status_code)
if model and not skip_validation:
try:
model.parse_obj(response.get_json())
except ValidationError as err:
resp_validation_error = err

if resp_validation_error:
response = make_response(
jsonify({"message": "response validation error"}), 500
)

after(request, response, resp_validation_error, None)
Expand Down
Loading

0 comments on commit 63e62c5

Please sign in to comment.