From 63e62c5f5f294cdc76a2e5e24d4735c82ef5ac80 Mon Sep 17 00:00:00 2001 From: Jean-Edouard BOULANGER Date: Thu, 21 Sep 2023 02:50:39 +0100 Subject: [PATCH] Unify response validation (#339) * Remove special list response handling * Unify response validation across plugins * @kemingy's suggestions --- Makefile | 2 +- spectree/_types.py | 3 + spectree/plugins/base.py | 67 ++++++- spectree/plugins/falcon_plugin.py | 108 ++++------- spectree/plugins/flask_plugin.py | 84 +++------ spectree/plugins/quart_plugin.py | 87 +++------ spectree/plugins/starlette_plugin.py | 32 ++-- .../test_plugin_spec[flask][full_spec].json | 4 +- ...ugin_spec[flask_blueprint][full_spec].json | 4 +- ...st_plugin_spec[flask_view][full_spec].json | 4 +- ...est_plugin_spec[starlette][full_spec].json | 4 +- tests/common.py | 31 +++- tests/flask_imports/dry_plugin_flask.py | 22 ++- tests/quart_imports/dry_plugin_quart.py | 18 +- tests/test_base_plugin.py | 171 ++++++++++++++++++ tests/test_plugin_falcon.py | 28 ++- tests/test_plugin_falcon_asgi.py | 3 +- tests/test_plugin_flask.py | 19 +- tests/test_plugin_flask_blueprint.py | 20 +- tests/test_plugin_flask_view.py | 15 +- tests/test_plugin_quart.py | 15 +- tests/test_plugin_starlette.py | 35 +++- 22 files changed, 519 insertions(+), 257 deletions(-) create mode 100644 tests/test_base_plugin.py diff --git a/Makefile b/Makefile index a52a4b50..a12db61b 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ format: lint: isort --check --diff --project=spectree ${SOURCE_FILES} black --check --diff ${SOURCE_FILES} - flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=21 + flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=15 mypy --install-types --non-interactive ${MYPY_SOURCE_FILES} .PHONY: test doc diff --git a/spectree/_types.py b/spectree/_types.py index 7aae8809..323c9955 100644 --- a/spectree/_types.py +++ b/spectree/_types.py @@ -41,3 +41,6 @@ class FunctionDecorator(Protocol): deprecated: bool path_parameter_descriptions: Optional[Mapping[str, str]] _decorator: Any + + +JsonType = Union[None, int, str, bool, List["JsonType"], Dict[str, "JsonType"]] diff --git a/spectree/plugins/base.py b/spectree/plugins/base.py index 819c68e9..6a7bce8e 100644 --- a/spectree/plugins/base.py +++ b/spectree/plugins/base.py @@ -1,4 +1,5 @@ import logging +from dataclasses import dataclass from typing import ( TYPE_CHECKING, Any, @@ -8,9 +9,11 @@ NamedTuple, Optional, TypeVar, + Union, ) -from .._types import ModelType +from .._pydantic import ValidationError, is_root_model, serialize_model_instance +from .._types import JsonType, ModelType, OptionalModelType from ..config import Configuration from ..response import Response @@ -122,3 +125,65 @@ def get_func_operation_id(self, func: Callable, path: str, method: str): if not operation_id: operation_id = f"{method.lower()}_{path.replace('/', '_')}" return operation_id + + +@dataclass(frozen=True) +class RawResponsePayload: + payload: Union[JsonType, bytes] + + +@dataclass(frozen=True) +class ResponseValidationResult: + payload: Any + + +def validate_response( + skip_validation: bool, + validation_model: OptionalModelType, + response_payload: Any, +): + """Validate a given `response_payload` against a `validation_model`. + + :param skip_validation: When set to true, validation is not carried out + and the input `response_payload` is returned as-is. This is equivalent + to not providing a `validation_model`. + :param validation_model: Pydantic model used to validate the provided + `response_payload`. + :param response_payload: Validated response payload. A `RawResponsePayload` + should be provided when the plugin view function returned an already + JSON-serialized response payload. + """ + final_response_payload = None + if isinstance(response_payload, RawResponsePayload): + final_response_payload = response_payload.payload + elif skip_validation or validation_model is None: + final_response_payload = response_payload + + if not skip_validation and validation_model and not final_response_payload: + if isinstance(response_payload, validation_model): + skip_validation = True + final_response_payload = serialize_model_instance(response_payload) + elif is_root_model(validation_model) and not isinstance( + response_payload, validation_model + ): + # Make it possible to return an instance of the model __root__ type + # (i.e. not the root model itself). + try: + response_payload = validation_model(__root__=response_payload) + except ValidationError: + raise + else: + skip_validation = True + final_response_payload = serialize_model_instance(response_payload) + else: + final_response_payload = response_payload + + if validation_model and not skip_validation: + validator = ( + validation_model.parse_raw + if isinstance(final_response_payload, bytes) + else validation_model.parse_obj + ) + validator(final_response_payload) + + return ResponseValidationResult(payload=final_response_payload) diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index 283e03c6..4d702496 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -4,18 +4,12 @@ from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints from falcon import HTTP_400, HTTP_415, HTTPError -from falcon import Response as FalconResponse from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN -from .._pydantic import ( - BaseModel, - ValidationError, - is_root_model, - serialize_model_instance, -) +from .._pydantic import ValidationError from .._types import ModelType from ..response import Response -from .base import BasePlugin +from .base import BasePlugin, validate_response class OpenAPI: @@ -194,53 +188,6 @@ def request_validation(self, req, query, json, form, headers, cookies): req_form = {x.name: x.stream.read() for x in req.get_media()} req.context.form = form.parse_obj(req_form) - def response_validation( - self, - response_spec: Optional[Response], - falcon_response: FalconResponse, - skip_validation: bool, - ) -> None: - if not skip_validation and response_spec and response_spec.has_model(): - model = falcon_response.media - status = int(falcon_response.status[:3]) - expect_model = response_spec.find_model(status) - if response_spec.expect_list_result(status) and isinstance(model, list): - expected_list_item_type = response_spec.get_expected_list_item_type( - status - ) - if all(isinstance(entry, expected_list_item_type) for entry in model): - skip_validation = True - falcon_response.media = [ - ( - serialize_model_instance(entry) - if isinstance(entry, BaseModel) - else entry - ) - for entry in model - ] - elif ( - expect_model - and is_root_model(expect_model) - and not isinstance(model, expect_model) - ): - # Make it possible to return an instance of the model __root__ type - # (i.e. not the root model itself). - try: - model = expect_model(__root__=model) - except ValidationError: - raise - else: - falcon_response.media = serialize_model_instance(model) - skip_validation = True - elif expect_model and isinstance(falcon_response.media, expect_model): - falcon_response.media = serialize_model_instance(model) - skip_validation = True - - if self._data_set_manually(falcon_response): - skip_validation = True - if expect_model and not skip_validation: - expect_model.parse_obj(falcon_response.media) - def validate( self, func: Callable, @@ -279,20 +226,25 @@ def validate( func(*args, **kwargs) - try: - self.response_validation( - response_spec=resp, - falcon_response=_resp, - skip_validation=skip_validation, - ) - except ValidationError as err: - resp_validation_error = err - _resp.status = HTTP_500 - _resp.media = err.errors() + if not self._data_set_manually(_resp): + try: + status = int(_resp.status[:3]) + response_validation_result = validate_response( + skip_validation=skip_validation, + validation_model=resp.find_model(status) if resp else None, + response_payload=_resp.media, + ) + except ValidationError as err: + resp_validation_error = err + _resp.status = HTTP_500 + _resp.media = err.errors() + else: + _resp.media = response_validation_result.payload after(_req, _resp, resp_validation_error, _self) - def _data_set_manually(self, resp): + @staticmethod + def _data_set_manually(resp): return (resp.text is not None or resp.data is not None) and resp.media is None def bypass(self, func, method): @@ -375,15 +327,19 @@ async def validate( await func(*args, **kwargs) - try: - self.response_validation( - response_spec=resp, - falcon_response=_resp, - skip_validation=skip_validation, - ) - except ValidationError as err: - resp_validation_error = err - _resp.status = HTTP_500 - _resp.media = err.errors() + if not self._data_set_manually(_resp): + try: + status = int(_resp.status[:3]) + response_validation_result = validate_response( + skip_validation=skip_validation, + validation_model=resp.find_model(status) if resp else None, + response_payload=_resp.media, + ) + except ValidationError as err: + resp_validation_error = err + _resp.status = HTTP_500 + _resp.media = err.errors() + else: + _resp.media = response_validation_result.payload after(_req, _resp, resp_validation_error, _self) diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index beca0c61..2226d4b6 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -4,16 +4,11 @@ from flask import Blueprint, abort, current_app, jsonify, make_response, request from werkzeug.routing import parse_converter_args -from .._pydantic import ( - BaseModel, - ValidationError, - is_root_model, - serialize_model_instance, -) +from .._pydantic import ValidationError from .._types import ModelType from ..response import Response from ..utils import get_multidict_items, werkzeug_parse_rule -from .base import BasePlugin, Context +from .base import BasePlugin, Context, RawResponsePayload, validate_response class FlaskPlugin(BasePlugin): @@ -206,63 +201,38 @@ def validate( status = 200 rest = [] - if resp and isinstance(result, tuple) and isinstance(result[0], BaseModel): + if resp and isinstance(result, tuple): if len(result) > 1: - model, status, *rest = result + response_payload, status, *rest = result else: - model = result[0] + response_payload = result[0] else: - model = result + response_payload = result - if not skip_validation and resp and not isinstance(result, flask.Response): - expect_model = resp.find_model(status) - if resp.expect_list_result(status) and isinstance(model, list): - expected_list_item_type = resp.get_expected_list_item_type(status) - if all(isinstance(entry, expected_list_item_type) for entry in model): - skip_validation = True - result = ( - [ - ( - serialize_model_instance(entry) - if isinstance(entry, BaseModel) - else entry - ) - for entry in model - ], + try: + response_validation_result = validate_response( + skip_validation=skip_validation, + validation_model=resp.find_model(status) if resp else None, + response_payload=( + RawResponsePayload(payload=response_payload.get_json()) + if ( + isinstance(response_payload, flask.Response) + and not skip_validation + ) + else response_payload + ), + ) + except ValidationError: + response = make_response( + jsonify({"message": "response validation error"}), 500 + ) + else: + response = make_response( + ( + response_validation_result.payload, status, *rest, ) - elif ( - expect_model - and is_root_model(expect_model) - and not isinstance(model, expect_model) - ): - # Make it possible to return an instance of the model __root__ type - # (i.e. not the root model itself). - try: - model = expect_model(__root__=model) - except ValidationError as err: - resp_validation_error = err - else: - skip_validation = True - result = (serialize_model_instance(model), status, *rest) - elif expect_model and isinstance(model, expect_model): - skip_validation = True - result = (serialize_model_instance(model), status, *rest) - - response = make_response(result) - - if resp and resp.has_model() and not resp_validation_error: - model = resp.find_model(response.status_code) - if model and not skip_validation: - try: - model.parse_obj(response.get_json()) - except ValidationError as err: - resp_validation_error = err - - if resp_validation_error: - response = make_response( - jsonify({"message": "response validation error"}), 500 ) after(request, response, resp_validation_error, None) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index 526cabb7..ddd1748c 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -1,19 +1,15 @@ import inspect from typing import Any, Callable, Mapping, Optional, Tuple, get_type_hints +import quart from quart import Blueprint, abort, current_app, jsonify, make_response, request from werkzeug.routing import parse_converter_args -from .._pydantic import ( - BaseModel, - ValidationError, - is_root_model, - serialize_model_instance, -) +from .._pydantic import ValidationError from .._types import ModelType from ..response import Response from ..utils import get_multidict_items, werkzeug_parse_rule -from .base import BasePlugin, Context +from .base import BasePlugin, Context, RawResponsePayload, validate_response class QuartPlugin(BasePlugin): @@ -217,63 +213,40 @@ async def validate( status = 200 rest = [] - if resp and isinstance(result, tuple) and isinstance(result[0], BaseModel): + if resp and isinstance(result, tuple): if len(result) > 1: - model, status, *rest = result + response_payload, status, *rest = result else: - model = result[0] + response_payload = result[0] else: - model = result + response_payload = result - if not skip_validation and resp: - expect_model = resp.find_model(status) - if resp.expect_list_result(status) and isinstance(model, list): - expected_list_item_type = resp.get_expected_list_item_type(status) - if all(isinstance(entry, expected_list_item_type) for entry in model): - skip_validation = True - result = ( - [ - ( - serialize_model_instance(entry) - if isinstance(entry, BaseModel) - else entry - ) - for entry in model - ], + try: + response_validation_result = validate_response( + skip_validation=skip_validation, + validation_model=resp.find_model(status) if resp else None, + response_payload=( + RawResponsePayload( + payload=(await response.get_json()) if response else None + ) + if ( + isinstance(response_payload, quart.Response) + and not skip_validation + ) + else response_payload + ), + ) + except ValidationError: + response = await make_response( + jsonify({"message": "response validation error"}), 500 + ) + else: + response = await make_response( + ( + response_validation_result.payload, status, *rest, ) - elif ( - expect_model - and is_root_model(expect_model) - and not isinstance(model, expect_model) - ): - # Make it possible to return an instance of the model __root__ type - # (i.e. not the root model itself). - try: - model = expect_model(__root__=model) - except ValidationError as err: - resp_validation_error = err - else: - skip_validation = True - result = (serialize_model_instance(model), status, *rest) - elif expect_model and isinstance(model, expect_model): - skip_validation = True - result = (serialize_model_instance(model), status, *rest) - - response = await make_response(result) - - if resp and resp.has_model() and not resp_validation_error: - model = resp.find_model(response.status_code) - if model and not skip_validation: - try: - model.parse_obj(await response.get_json()) - except ValidationError as err: - resp_validation_error = err - - if resp_validation_error: - response = await make_response( - jsonify({"message": "response validation error"}), 500 ) after(request, response, resp_validation_error, None) diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index 4fa4c7f4..d4006532 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -12,27 +12,22 @@ from .._pydantic import BaseModel, ValidationError, serialize_model_instance from .._types import ModelType from ..response import Response -from .base import BasePlugin, Context +from .base import BasePlugin, Context, RawResponsePayload, validate_response METHODS = {"get", "post", "put", "patch", "delete"} Route = namedtuple("Route", ["path", "methods", "func"]) +class _PydanticResponseModel(BaseModel): + __root__: Any + + def PydanticResponse(content): class _PydanticResponse(JSONResponse): def render(self, content) -> bytes: self._model_class = content.__class__ return super().render( - [ - ( - serialize_model_instance(entry) - if isinstance(entry, BaseModel) - else entry - ) - for entry in content - ] - if isinstance(content, list) - else serialize_model_instance(content) + serialize_model_instance(_PydanticResponseModel(__root__=content)) ) return _PydanticResponse(content) @@ -141,13 +136,14 @@ async def validate( ): skip_validation = True - model = resp.find_model(response.status_code) - if model and not skip_validation: - try: - model.parse_raw(response.body) - except ValidationError as err: - resp_validation_error = err - response = JSONResponse(err.errors(), 500) + try: + validate_response( + skip_validation=skip_validation, + validation_model=resp.find_model(response.status_code), + response_payload=RawResponsePayload(payload=response.body), + ) + except ValidationError as err: + response = JSONResponse(err.errors(), 500) after(request, response, resp_validation_error, instance) diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json index 19bd9452..45f40d4f 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json @@ -763,7 +763,7 @@ } ], "responses": { - "200": { + "202": { "content": { "application/json": { "schema": { @@ -771,7 +771,7 @@ } } }, - "description": "OK" + "description": "Accepted" }, "422": { "content": { diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json index 81ca061b..8a472b59 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json @@ -763,7 +763,7 @@ } ], "responses": { - "200": { + "202": { "content": { "application/json": { "schema": { @@ -771,7 +771,7 @@ } } }, - "description": "OK" + "description": "Accepted" }, "422": { "content": { diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index 15b9cd56..9a4ceb35 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -747,7 +747,7 @@ } ], "responses": { - "200": { + "202": { "content": { "application/json": { "schema": { @@ -755,7 +755,7 @@ } } }, - "description": "OK" + "description": "Accepted" }, "422": { "content": { diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json index 6c4cd074..aca73a73 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json @@ -763,7 +763,7 @@ } ], "responses": { - "200": { + "202": { "content": { "application/json": { "schema": { @@ -771,7 +771,7 @@ } } }, - "description": "OK" + "description": "Accepted" }, "422": { "content": { diff --git a/tests/common.py b/tests/common.py index a6847e51..4290d186 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,5 +1,7 @@ +import xml.etree.ElementTree as ET +from dataclasses import dataclass from enum import Enum, IntEnum -from typing import Any, Dict, List, Union +from typing import Any, Dict, List, Union, cast from spectree import BaseFile, ExternalDocs, SecurityScheme, SecuritySchemeData, Tag from spectree._pydantic import BaseModel, Field, root_validator @@ -202,3 +204,30 @@ def get_root_resp_data(pre_serialize: bool, return_what: str): if "__root__" in data: data = data["__root__"] return data + + +@dataclass(frozen=True) +class UserXmlData: + name: str + score: List[int] + + @staticmethod + def parse_xml(data: str) -> "UserXmlData": + root = ET.ElementTree(ET.fromstring(data)).getroot() + assert root.tag == "user" + children = [node for node in root] + assert len(children) == 2 + assert children[0].tag == "name" + assert children[1].tag == "x_score" + return UserXmlData( + name=cast(str, children[0].text), + score=[int(entry) for entry in cast(str, children[1].text).split(",")], + ) + + def dump_xml(self) -> str: + return f""" + + {self.name} + {','.join(str(entry) for entry in self.score)} + + """ diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index bd036a00..19f0122e 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -2,20 +2,30 @@ import pytest +from tests.common import UserXmlData -def test_flask_skip_validation(client): - client.set_cookie(key="pub", value="abcdefg") +@pytest.mark.parametrize("response_format", ["json", "xml"]) +def test_flask_skip_validation(client, response_format: str): + client.set_cookie(key="pub", value="abcdefg") + assert response_format in ("json", "xml") resp = client.post( - "/api/user_skip/flask?order=1", + f"/api/user_skip/flask?order=1&response_format={response_format}", json=dict(name="flask", limit=10), content_type="application/json", ) - assert resp.status_code == 200, resp.json + assert resp.status_code == 200 assert resp.headers.get("X-Validation") is None assert resp.headers.get("X-API") == "OK" - assert resp.json["name"] == "flask" - assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True) + if response_format == "json": + assert resp.content_type == "application/json" + assert resp.json["name"] == "flask" + assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True) + else: + assert resp.content_type == "text/xml" + user_xml_data = UserXmlData.parse_xml(resp.text) + assert user_xml_data.name == "flask" + assert user_xml_data.score == sorted(user_xml_data.score, reverse=True) def test_flask_return_model(client): diff --git a/tests/quart_imports/dry_plugin_quart.py b/tests/quart_imports/dry_plugin_quart.py index 57098d16..d33ffd09 100644 --- a/tests/quart_imports/dry_plugin_quart.py +++ b/tests/quart_imports/dry_plugin_quart.py @@ -2,13 +2,16 @@ import pytest +from tests.common import UserXmlData -def test_quart_skip_validation(client): + +@pytest.mark.parametrize("response_format", ["json", "xml"]) +def test_quart_skip_validation(client, response_format: str): client.set_cookie("quart", "pub", "abcdefg") resp = asyncio.run( client.post( - "/api/user_skip/quart?order=1", + f"/api/user_skip/quart?order=1&response_format={response_format}", json=dict(name="quart", limit=10), headers={"Content-Type": "application/json"}, ) @@ -17,8 +20,15 @@ def test_quart_skip_validation(client): assert resp.status_code == 200, resp_json assert resp.headers.get("X-Validation") is None assert resp.headers.get("X-API") == "OK" - assert resp_json["name"] == "quart" - assert resp_json["x_score"] == sorted(resp_json["x_score"], reverse=True) + if response_format == "json": + assert resp.content_type == "application/json" + assert resp.json["name"] == "quart" + assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True) + else: + assert resp.content_type == "text/xml" + user_xml_data = UserXmlData.parse_xml(resp.text) + assert user_xml_data.name == "quart" + assert user_xml_data.score == sorted(user_xml_data.score, reverse=True) def test_quart_return_model(client): diff --git a/tests/test_base_plugin.py b/tests/test_base_plugin.py new file mode 100644 index 00000000..27a27560 --- /dev/null +++ b/tests/test_base_plugin.py @@ -0,0 +1,171 @@ +from contextlib import nullcontext as does_not_raise +from dataclasses import dataclass +from typing import Any, Union + +import pytest + +from spectree._pydantic import ValidationError +from spectree._types import OptionalModelType +from spectree.plugins.base import ( + RawResponsePayload, + ResponseValidationResult, + validate_response, +) +from spectree.utils import gen_list_model +from tests.common import JSON, Resp, RootResp, StrDict + +RespList = gen_list_model(Resp) + + +@dataclass(frozen=True) +class DummyResponse: + payload: bytes + content_type: str + + +@pytest.mark.parametrize( + [ + "skip_validation", + "validation_model", + "response_payload", + "expected_result", + ], + [ + ( + False, + Resp, + Resp(name="user1", score=[1, 2]), + ResponseValidationResult({"name": "user1", "score": [1, 2]}), + ), + ( + False, + Resp, + {"name": "user1", "score": [1, 2]}, + ResponseValidationResult({"name": "user1", "score": [1, 2]}), + ), + ( + False, + Resp, + RawResponsePayload({"name": "user1", "score": [1, 2]}), + ResponseValidationResult({"name": "user1", "score": [1, 2]}), + ), + ( + False, + Resp, + {}, + ValidationError, + ), + ( + False, + Resp, + {"name": "user1"}, + ValidationError, + ), + ( + False, + RootResp, + [1, 2, 3], + ResponseValidationResult([1, 2, 3]), + ), + ( + False, + RootResp, + RawResponsePayload([1, 2, 3]), + ResponseValidationResult([1, 2, 3]), + ), + ( + False, + StrDict, + StrDict(__root__={"key1": "value1", "key2": "value2"}), + ResponseValidationResult({"key1": "value1", "key2": "value2"}), + ), + ( + False, + RootResp, + {"name": "user2", "limit": 1}, + ResponseValidationResult({"name": "user2", "limit": 1}), + ), + ( + False, + RootResp, + RawResponsePayload({"name": "user2", "limit": 1}), + ResponseValidationResult({"name": "user2", "limit": 1}), + ), + ( + False, + RootResp, + JSON(name="user3", limit=5), + ResponseValidationResult({"name": "user3", "limit": 5}), + ), + ( + False, + RootResp, + RootResp(__root__=JSON(name="user4", limit=23)), + ResponseValidationResult({"name": "user4", "limit": 23}), + ), + ( + False, + RootResp, + {}, + ValidationError, + ), + ( + False, + RespList, + [], + ResponseValidationResult([]), + ), + ( + False, + RespList, + [{"name": "user5", "score": [5, 10]}], + ResponseValidationResult([{"name": "user5", "score": [5, 10]}]), + ), + ( + False, + RespList, + [Resp(name="user6", score=[10, 20]), Resp(name="user7", score=[30, 40])], + ResponseValidationResult( + [ + {"name": "user6", "score": [10, 20]}, + {"name": "user7", "score": [30, 40]}, + ] + ), + ), + ( + False, + None, + {"user_id": "user1", "locale": "en-gb"}, + ResponseValidationResult({"user_id": "user1", "locale": "en-gb"}), + ), + ( + True, + None, + DummyResponse(payload="".encode(), content_type="text/html"), + ResponseValidationResult( + DummyResponse( + payload="".encode(), content_type="text/html" + ) + ), + ), + ], +) +def test_validate_response( + skip_validation: bool, + validation_model: OptionalModelType, + response_payload: Any, + expected_result: Union[ResponseValidationResult, ValidationError], +): + runtime_expectation = ( + pytest.raises(ValidationError) + if expected_result == ValidationError + else does_not_raise() + ) + with runtime_expectation: + result = validate_response( + skip_validation=skip_validation, + validation_model=validation_model, + response_payload=response_payload, + ) + assert isinstance(result, ResponseValidationResult) + assert result == expected_result diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index d937c2ed..48f81088 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -1,8 +1,9 @@ from random import randint from typing import List +import falcon import pytest -from falcon import App, testing +from falcon import HTTP_202, App, testing from spectree import Response, SpecTree @@ -16,6 +17,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_root_resp_data, ) @@ -43,6 +45,7 @@ def on_get(self, req, resp): description """ resp.media = {"msg": "pong"} + resp.status = HTTP_202 class UserScore: @@ -114,11 +117,17 @@ def on_get(self, req, resp, name): skip_validation=True, ) def on_post(self, req, resp, name, query: Query, json: JSON, cookies: Cookies): + response_format = req.params.get("response_format") + assert response_format in ("json", "xml") score = [randint(0, req.context.json.limit) for _ in range(5)] score.sort(reverse=req.context.query.order) assert req.context.cookies.pub == "abcdefg" assert req.cookies["pub"] == "abcdefg" - resp.media = {"name": req.context.json.name, "x_score": score} + if response_format == "json": + resp.media = {"name": req.context.json.name, "x_score": score} + else: + resp.content_type = falcon.MEDIA_XML + resp.text = UserXmlData(name=req.context.json.name, score=score).dump_xml() class UserScoreModel: @@ -267,6 +276,7 @@ def test_falcon_validate(client): resp = client.simulate_request( "GET", "/ping", headers={"lang": "en-US", "Content-Type": "text/plain"} ) + assert resp.status_code == 202 assert resp.json == {"msg": "pong"} assert resp.headers.get("X-Error") is None assert resp.headers.get("X-Name") == "health check" @@ -302,16 +312,22 @@ def test_falcon_validate(client): assert resp.headers.get("X-Name") == "sorted random score" -def test_falcon_skip_validation(client): +@pytest.mark.parametrize("response_format", ["json", "xml"]) +def test_falcon_skip_validation(client, response_format: str): resp = client.simulate_request( "POST", - "/api/user_skip/falcon?order=1", + f"/api/user_skip/falcon?order=1&response_format={response_format}", json=dict(name="falcon", limit=10), headers={"Cookie": "pub=abcdefg"}, ) - assert resp.json["name"] == "falcon" - assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True) assert resp.headers.get("X-Name") == "sorted random score" + if response_format == "json": + assert resp.json["name"] == "falcon" + assert resp.json["x_score"] == sorted(resp.json["x_score"], reverse=True) + else: + user_xml_data = UserXmlData.parse_xml(resp.text) + assert user_xml_data.name == "falcon" + assert user_xml_data.score == sorted(user_xml_data.score, reverse=True) def test_falcon_return_model(client): diff --git a/tests/test_plugin_falcon_asgi.py b/tests/test_plugin_falcon_asgi.py index e0c2cfa2..3752454f 100644 --- a/tests/test_plugin_falcon_asgi.py +++ b/tests/test_plugin_falcon_asgi.py @@ -2,7 +2,7 @@ from typing import List import pytest -from falcon import testing +from falcon import HTTP_202, testing from falcon.asgi import App from spectree import Response, SpecTree @@ -317,6 +317,7 @@ async def on_get(self, req, resp): description """ resp.media = {"msg": "pong"} + resp.status = HTTP_202 app = App() app.add_route("/ping", Ping()) diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index 9adaf946..511c5210 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -19,6 +19,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_root_resp_data, ) @@ -43,25 +44,28 @@ def api_after_handler(req, resp, err, _): api = SpecTree("flask", before=before_handler, after=after_handler, annotations=True) app = Flask(__name__) app.config["TESTING"] = True +app.config["DEBUG"] = True api_secure = SpecTree("flask", security_schemes=SECURITY_SCHEMAS) app_secure = Flask(__name__) app_secure.config["TESTING"] = True +app_secure.config["DEBUG"] = True api_global_secure = SpecTree( "flask", security_schemes=SECURITY_SCHEMAS, security={"auth_apiKey": []} ) app_global_secure = Flask(__name__) app_global_secure.config["TESTING"] = True +app_global_secure.config["DEBUG"] = True @app.route("/ping") -@api.validate(headers=Headers, resp=Response(HTTP_200=StrDict), tags=["test", "health"]) +@api.validate(headers=Headers, resp=Response(HTTP_202=StrDict), tags=["test", "health"]) def ping(): """summary description""" - return jsonify(msg="pong") + return jsonify(msg="pong"), 202 @app.route("/api/file_upload", methods=["POST"]) @@ -119,11 +123,19 @@ def user_score_annotated(name, query: Query, json: JSON, form: Form, cookies: Co skip_validation=True, ) def user_score_skip_validation(name): + response_format = request.args.get("response_format") + assert response_format in ("json", "xml") score = [randint(0, request.context.json.limit) for _ in range(5)] score.sort(reverse=(request.context.query.order == Order.desc)) assert request.context.cookies.pub == "abcdefg" assert request.cookies["pub"] == "abcdefg" - return jsonify(name=request.context.json.name, x_score=score) + if response_format == "json": + return jsonify(name=request.context.json.name, x_score=score) + else: + return app.response_class( + UserXmlData(name=request.context.json.name, score=score).dump_xml(), + content_type="text/xml", + ) @app.route("/api/user_model/", methods=["POST"]) @@ -219,6 +231,7 @@ def test_client_and_api(request): api = SpecTree(*api_args, **api_kwargs) app = Flask(__name__) app.config["TESTING"] = True + app.config["DEBUG"] = True @app.route("/ping") @api.validate(**endpoint_kwargs) diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 20d65f24..2a5ac3a2 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -1,6 +1,7 @@ from random import randint from typing import List +import flask import pytest from flask import Blueprint, Flask, jsonify, request @@ -18,6 +19,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_paths, get_root_resp_data, @@ -45,12 +47,12 @@ def api_after_handler(req, resp, err, _): @app.route("/ping") -@api.validate(headers=Headers, resp=Response(HTTP_200=StrDict), tags=["test", "health"]) +@api.validate(headers=Headers, resp=Response(HTTP_202=StrDict), tags=["test", "health"]) def ping(): """summary description""" - return jsonify(msg="pong") + return jsonify(msg="pong"), 202 @app.route("/api/file_upload", methods=["POST"]) @@ -108,11 +110,19 @@ def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies, form: skip_validation=True, ) def user_score_skip_validation(name): + response_format = request.args.get("response_format") + assert response_format in ("json", "xml") score = [randint(0, request.context.json.limit) for _ in range(5)] score.sort(reverse=True if request.context.query.order == Order.desc else False) assert request.context.cookies.pub == "abcdefg" assert request.cookies["pub"] == "abcdefg" - return jsonify(name=request.context.json.name, x_score=score) + if response_format == "json": + return jsonify(name=request.context.json.name, x_score=score) + else: + return flask.Response( + UserXmlData(name=request.context.json.name, score=score).dump_xml(), + content_type="text/xml", + ) @app.route("/api/user_model/", methods=["POST"]) @@ -180,6 +190,8 @@ def return_root(): api.register(app) flask_app = Flask(__name__) +flask_app.config["DEBUG"] = True +flask_app.config["TESTING"] = True flask_app.register_blueprint(app) with flask_app.app_context(): api.spec @@ -203,7 +215,7 @@ def test_blueprint_prefix(client, prefix): assert resp.headers.get("X-Error") == "Validation Error" resp = client.get(prefix + "/ping", headers={"lang": "en-US"}) - assert resp.status_code == 200 + assert resp.status_code == 202 assert resp.json == {"msg": "pong"} assert resp.headers.get("X-Error") is None assert resp.headers.get("X-Validation") == "Pass" diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index bea69871..2e2f4b6e 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -19,6 +19,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_root_resp_data, ) @@ -47,13 +48,13 @@ def api_after_handler(req, resp, err, _): class Ping(MethodView): @api.validate( - headers=Headers, resp=Response(HTTP_200=StrDict), tags=["test", "health"] + headers=Headers, resp=Response(HTTP_202=StrDict), tags=["test", "health"] ) def get(self): """summary description""" - return jsonify(msg="pong") + return jsonify(msg="pong"), 202 class FileUploadView(MethodView): @@ -111,12 +112,20 @@ class UserSkip(MethodView): skip_validation=True, ) def post(self, name, query: Query, json: JSON, form: Form, cookies: Cookies): + response_format = request.args.get("response_format") + assert response_format in ("json", "xml") data_src = json or form score = [randint(0, int(data_src.limit)) for _ in range(5)] score.sort(reverse=(query.order == Order.desc)) assert cookies.pub == "abcdefg" assert request.cookies["pub"] == "abcdefg" - return jsonify(name=data_src.name, x_score=score) + if response_format == "json": + return jsonify(name=request.context.json.name, x_score=score) + else: + return app.response_class( + UserXmlData(name=request.context.json.name, score=score).dump_xml(), + content_type="text/xml", + ) class UserModel(MethodView): diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index 6c1dd404..f2f69a87 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -17,6 +17,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_root_resp_data, ) @@ -56,7 +57,7 @@ async def ping(): """summary description""" - return jsonify(msg="pong") + return jsonify(msg="pong"), 202 @app.route("/api/user/", methods=["POST"]) @@ -101,11 +102,21 @@ async def user_score_annotated(name, query: Query, json: JSON, cookies: Cookies) skip_validation=True, ) async def user_score_skip_validation(name): + response_format = request.args.get("response_format") + assert response_format in ("json", "xml") score = [randint(0, request.context.json.limit) for _ in range(5)] score.sort(reverse=True if request.context.query.order == Order.desc else False) assert request.context.cookies.pub == "abcdefg" assert request.cookies["pub"] == "abcdefg" - return jsonify(name=request.context.json.name, x_score=score) + if response_format == "json": + return jsonify(name=request.context.json.name, x_score=score) + else: + return app.response_class( + response=UserXmlData( + name=request.context.json.name, score=score + ).dump_xml(), + content_type="text/xml", + ) @app.route("/api/user_model/", methods=["POST"]) diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index 9e7f9a19..39b3deda 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -6,6 +6,7 @@ from starlette.applications import Starlette from starlette.endpoints import HTTPEndpoint from starlette.responses import JSONResponse +from starlette.responses import Response as StarletteResponse from starlette.routing import Mount, Route from starlette.staticfiles import StaticFiles from starlette.testclient import TestClient @@ -24,6 +25,7 @@ Resp, RootResp, StrDict, + UserXmlData, api_tag, get_root_resp_data, ) @@ -52,7 +54,7 @@ class Ping(HTTPEndpoint): @api.validate( headers=Headers, - resp=Response(HTTP_200=StrDict), + resp=Response(HTTP_202=StrDict), tags=["test", "health"], after=method_handler, ) @@ -60,7 +62,7 @@ def get(self, request): """summary description""" - return JSONResponse({"msg": "pong"}) + return JSONResponse({"msg": "pong"}, status_code=202) @api.validate( @@ -108,11 +110,18 @@ async def user_score_annotated(request, query: Query, json: JSON, cookies: Cooki skip_validation=True, ) async def user_score_skip(request): + response_format = request.query_params.get("response_format") score = [randint(0, request.context.json.limit) for _ in range(5)] score.sort(reverse=True if request.context.query.order == Order.desc else False) assert request.context.cookies.pub == "abcdefg" assert request.cookies["pub"] == "abcdefg" - return JSONResponse({"name": request.context.json.name, "x_score": score}) + if response_format == "json": + return JSONResponse({"name": request.context.json.name, "x_score": score}) + else: + return StarletteResponse( + UserXmlData(name=request.context.json.name, score=score).dump_xml(), + media_type="text/xml", + ) @api.validate( @@ -239,6 +248,7 @@ def test_starlette_validate(client): resp = client.get("/ping", headers={"lang": "en-US"}) assert resp.json() == {"msg": "pong"} + assert resp.status_code == 202 assert resp.headers.get("X-Error") is None assert resp.headers.get("X-Name") == "Ping" assert resp.headers.get("X-Validation") is None @@ -268,16 +278,23 @@ def test_starlette_validate(client): assert resp.headers.get("X-Validation") == "Pass" -def test_starlette_skip_validation(client): +@pytest.mark.parametrize("response_format", ["json", "xml"]) +def test_starlette_skip_validation(client, response_format: str): client.cookies = dict(pub="abcdefg") + assert response_format in ("json", "xml") resp = client.post( - "/api/user_skip/starlette?order=1", + f"/api/user_skip/starlette?order=1&response_format={response_format}", json=dict(name="starlette", limit=10), ) - resp_body = resp.json() - assert resp_body["name"] == "starlette" - assert resp_body["x_score"] == sorted(resp_body["x_score"], reverse=True) assert resp.headers.get("X-Validation") == "Pass" + if response_format == "json": + resp_body = resp.json() + assert resp_body["name"] == "starlette" + assert resp_body["x_score"] == sorted(resp_body["x_score"], reverse=True) + else: + user_xml_data = UserXmlData.parse_xml(resp.text) + assert user_xml_data.name == "starlette" + assert user_xml_data.score == sorted(user_xml_data.score, reverse=True) def test_starlette_return_model(client): @@ -316,7 +333,7 @@ def get(self, request): """summary description""" - return JSONResponse({"msg": "pong"}) + return JSONResponse({"msg": "pong"}, status_code=202) app = Starlette(routes=[Route("/ping", Ping)]) api.register(app)