From 3cac8f34f1ca2fe3f40371ca58253c084a074d12 Mon Sep 17 00:00:00 2001 From: Jean-Edouard BOULANGER Date: Tue, 19 Sep 2023 04:47:38 +0100 Subject: [PATCH] [Proposal] Support Pydantic root model responses (#338) * Support Pydantic root model responses * kemingy's suggestions * Add additional tests for _pydantic * Also skip second validation if response validation error was previously found (flask) * Fix typing issue in test_pydantic.py * Update 'ping' flask blueprint endpoint * Don't run first validation on Flask Response objects --- Makefile | 2 +- spectree/_pydantic.py | 42 +++++ spectree/plugins/falcon_plugin.py | 32 +++- spectree/plugins/flask_plugin.py | 42 ++++- spectree/plugins/quart_plugin.py | 41 ++++- spectree/plugins/starlette_plugin.py | 12 +- .../test_plugin_spec[falcon][full_spec].json | 63 ++++++++ .../test_plugin_spec[flask][full_spec].json | 63 ++++++++ ...ugin_spec[flask_blueprint][full_spec].json | 63 ++++++++ ...st_plugin_spec[flask_view][full_spec].json | 32 ++++ ...est_plugin_spec[starlette][full_spec].json | 63 ++++++++ tests/common.py | 27 +++- tests/flask_imports/dry_plugin_flask.py | 15 ++ tests/test_plugin_falcon.py | 31 ++++ tests/test_plugin_falcon_asgi.py | 33 ++++ tests/test_plugin_flask.py | 11 ++ tests/test_plugin_flask_blueprint.py | 11 ++ tests/test_plugin_flask_view.py | 11 ++ tests/test_plugin_quart.py | 11 ++ tests/test_plugin_starlette.py | 30 +++- tests/test_pydantic.py | 147 ++++++++++++++++++ 21 files changed, 755 insertions(+), 27 deletions(-) create mode 100644 tests/test_pydantic.py diff --git a/Makefile b/Makefile index ab74608c..a52a4b50 100644 --- a/Makefile +++ b/Makefile @@ -45,7 +45,7 @@ format: lint: isort --check --diff --project=spectree ${SOURCE_FILES} black --check --diff ${SOURCE_FILES} - flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=17 + flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=21 mypy --install-types --non-interactive ${MYPY_SOURCE_FILES} .PHONY: test doc diff --git a/spectree/_pydantic.py b/spectree/_pydantic.py index 70559fe7..a1f96579 100644 --- a/spectree/_pydantic.py +++ b/spectree/_pydantic.py @@ -1,6 +1,10 @@ +from typing import Any + from pydantic.version import VERSION as PYDANTIC_VERSION PYDANTIC2 = PYDANTIC_VERSION.startswith("2") +ROOT_FIELD = "__root__" + __all__ = [ "BaseModel", @@ -11,6 +15,11 @@ "BaseSettings", "EmailStr", "validator", + "is_root_model", + "is_root_model_instance", + "serialize_model_instance", + "is_base_model", + "is_base_model_instance", ] if PYDANTIC2: @@ -35,3 +44,36 @@ root_validator, validator, ) + + +def is_base_model(t: Any) -> bool: + """Check whether a type is a Pydantic BaseModel""" + try: + return issubclass(t, BaseModel) + except TypeError: + return False + + +def is_base_model_instance(value: Any) -> bool: + """Check whether a value is a Pydantic BaseModel instance.""" + return is_base_model(type(value)) + + +def is_root_model(t: Any) -> bool: + """Check whether a type is a Pydantic RootModel.""" + return is_base_model(t) and ROOT_FIELD in t.__fields__ + + +def is_root_model_instance(value: Any): + """Check whether a value is a Pydantic RootModel instance.""" + return is_root_model(type(value)) + + +def serialize_model_instance(value: BaseModel): + """Serialize a Pydantic BaseModel (equivalent of calling `.dict()` on a BaseModel, + but additionally takes care of stripping __root__ for root models. + """ + serialized = value.dict() + if is_root_model_instance(value) and ROOT_FIELD in serialized: + return serialized[ROOT_FIELD] + return serialized diff --git a/spectree/plugins/falcon_plugin.py b/spectree/plugins/falcon_plugin.py index c6163508..283e03c6 100644 --- a/spectree/plugins/falcon_plugin.py +++ b/spectree/plugins/falcon_plugin.py @@ -7,7 +7,12 @@ from falcon import Response as FalconResponse from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN -from .._pydantic import BaseModel, ValidationError +from .._pydantic import ( + BaseModel, + ValidationError, + is_root_model, + serialize_model_instance, +) from .._types import ModelType from ..response import Response from .base import BasePlugin @@ -195,7 +200,7 @@ def response_validation( falcon_response: FalconResponse, skip_validation: bool, ) -> None: - if response_spec and response_spec.has_model(): + if not skip_validation and response_spec and response_spec.has_model(): model = falcon_response.media status = int(falcon_response.status[:3]) expect_model = response_spec.find_model(status) @@ -206,12 +211,31 @@ def response_validation( if all(isinstance(entry, expected_list_item_type) for entry in model): skip_validation = True falcon_response.media = [ - (entry.dict() if isinstance(entry, BaseModel) else entry) + ( + serialize_model_instance(entry) + if isinstance(entry, BaseModel) + else entry + ) for entry in model ] + elif ( + expect_model + and is_root_model(expect_model) + and not isinstance(model, expect_model) + ): + # Make it possible to return an instance of the model __root__ type + # (i.e. not the root model itself). + try: + model = expect_model(__root__=model) + except ValidationError: + raise + else: + falcon_response.media = serialize_model_instance(model) + skip_validation = True elif expect_model and isinstance(falcon_response.media, expect_model): - falcon_response.media = model.dict() + falcon_response.media = serialize_model_instance(model) skip_validation = True + if self._data_set_manually(falcon_response): skip_validation = True if expect_model and not skip_validation: diff --git a/spectree/plugins/flask_plugin.py b/spectree/plugins/flask_plugin.py index 5da90289..beca0c61 100644 --- a/spectree/plugins/flask_plugin.py +++ b/spectree/plugins/flask_plugin.py @@ -1,9 +1,15 @@ from typing import Any, Callable, Mapping, Optional, Tuple, get_type_hints +import flask from flask import Blueprint, abort, current_app, jsonify, make_response, request from werkzeug.routing import parse_converter_args -from .._pydantic import BaseModel, ValidationError +from .._pydantic import ( + BaseModel, + ValidationError, + is_root_model, + serialize_model_instance, +) from .._types import ModelType from ..response import Response from ..utils import get_multidict_items, werkzeug_parse_rule @@ -208,7 +214,7 @@ def validate( else: model = result - if resp: + if not skip_validation and resp and not isinstance(result, flask.Response): expect_model = resp.find_model(status) if resp.expect_list_result(status) and isinstance(model, list): expected_list_item_type = resp.get_expected_list_item_type(status) @@ -216,28 +222,48 @@ def validate( skip_validation = True result = ( [ - (entry.dict() if isinstance(entry, BaseModel) else entry) + ( + serialize_model_instance(entry) + if isinstance(entry, BaseModel) + else entry + ) for entry in model ], status, *rest, ) + elif ( + expect_model + and is_root_model(expect_model) + and not isinstance(model, expect_model) + ): + # Make it possible to return an instance of the model __root__ type + # (i.e. not the root model itself). + try: + model = expect_model(__root__=model) + except ValidationError as err: + resp_validation_error = err + else: + skip_validation = True + result = (serialize_model_instance(model), status, *rest) elif expect_model and isinstance(model, expect_model): skip_validation = True - result = (model.dict(), status, *rest) + result = (serialize_model_instance(model), status, *rest) response = make_response(result) - if resp and resp.has_model(): + if resp and resp.has_model() and not resp_validation_error: model = resp.find_model(response.status_code) if model and not skip_validation: try: model.parse_obj(response.get_json()) except ValidationError as err: resp_validation_error = err - response = make_response( - jsonify({"message": "response validation error"}), 500 - ) + + if resp_validation_error: + response = make_response( + jsonify({"message": "response validation error"}), 500 + ) after(request, response, resp_validation_error, None) diff --git a/spectree/plugins/quart_plugin.py b/spectree/plugins/quart_plugin.py index 4894e974..526cabb7 100644 --- a/spectree/plugins/quart_plugin.py +++ b/spectree/plugins/quart_plugin.py @@ -4,7 +4,12 @@ from quart import Blueprint, abort, current_app, jsonify, make_response, request from werkzeug.routing import parse_converter_args -from .._pydantic import BaseModel, ValidationError +from .._pydantic import ( + BaseModel, + ValidationError, + is_root_model, + serialize_model_instance, +) from .._types import ModelType from ..response import Response from ..utils import get_multidict_items, werkzeug_parse_rule @@ -220,7 +225,7 @@ async def validate( else: model = result - if resp: + if not skip_validation and resp: expect_model = resp.find_model(status) if resp.expect_list_result(status) and isinstance(model, list): expected_list_item_type = resp.get_expected_list_item_type(status) @@ -228,28 +233,48 @@ async def validate( skip_validation = True result = ( [ - (entry.dict() if isinstance(entry, BaseModel) else entry) + ( + serialize_model_instance(entry) + if isinstance(entry, BaseModel) + else entry + ) for entry in model ], status, *rest, ) + elif ( + expect_model + and is_root_model(expect_model) + and not isinstance(model, expect_model) + ): + # Make it possible to return an instance of the model __root__ type + # (i.e. not the root model itself). + try: + model = expect_model(__root__=model) + except ValidationError as err: + resp_validation_error = err + else: + skip_validation = True + result = (serialize_model_instance(model), status, *rest) elif expect_model and isinstance(model, expect_model): skip_validation = True - result = (model.dict(), status, *rest) + result = (serialize_model_instance(model), status, *rest) response = await make_response(result) - if resp and resp.has_model(): + if resp and resp.has_model() and not resp_validation_error: model = resp.find_model(response.status_code) if model and not skip_validation: try: model.parse_obj(await response.get_json()) except ValidationError as err: resp_validation_error = err - response = await make_response( - jsonify({"message": "response validation error"}), 500 - ) + + if resp_validation_error: + response = await make_response( + jsonify({"message": "response validation error"}), 500 + ) after(request, response, resp_validation_error, None) diff --git a/spectree/plugins/starlette_plugin.py b/spectree/plugins/starlette_plugin.py index d8fc115a..4fa4c7f4 100644 --- a/spectree/plugins/starlette_plugin.py +++ b/spectree/plugins/starlette_plugin.py @@ -9,7 +9,7 @@ from starlette.responses import HTMLResponse, JSONResponse from starlette.routing import compile_path -from .._pydantic import BaseModel, ValidationError +from .._pydantic import BaseModel, ValidationError, serialize_model_instance from .._types import ModelType from ..response import Response from .base import BasePlugin, Context @@ -24,11 +24,15 @@ def render(self, content) -> bytes: self._model_class = content.__class__ return super().render( [ - (entry.dict() if isinstance(entry, BaseModel) else entry) + ( + serialize_model_instance(entry) + if isinstance(entry, BaseModel) + else entry + ) for entry in content ] if isinstance(content, list) - else content.dict() + else serialize_model_instance(content) ) return _PydanticResponse(content) @@ -129,7 +133,7 @@ async def validate( else: response = func(*args, **kwargs) - if resp and response: + if not skip_validation and resp and response: if ( isinstance(response, JSONResponse) and hasattr(response, "_model_class") diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json index 9dacbf8f..3e521647 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[falcon][full_spec].json @@ -159,6 +159,38 @@ "title": "Resp", "type": "object" }, + "RootResp.7068f62": { + "anyOf": [ + { + "$ref": "#/components/schemas/RootResp.7068f62.JSON" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + } + ], + "title": "RootResp" + }, + "RootResp.7068f62.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "StrDict.7068f62": { "additionalProperties": { "type": "string" @@ -383,6 +415,37 @@ "tags": [] } }, + "/api/return_root": { + "get": { + "description": "", + "operationId": "get__api_return_root", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RootResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "on_get ", + "tags": [] + } + }, "/api/user/{name}": { "get": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json index 0559e6ba..19bd9452 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask][full_spec].json @@ -177,6 +177,38 @@ "title": "Resp", "type": "object" }, + "RootResp.7068f62": { + "anyOf": [ + { + "$ref": "#/components/schemas/RootResp.7068f62.JSON" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + } + ], + "title": "RootResp" + }, + "RootResp.7068f62.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "StrDict.7068f62": { "additionalProperties": { "type": "string" @@ -336,6 +368,37 @@ "tags": [] } }, + "/api/return_root": { + "get": { + "description": "", + "operationId": "get__api_return_root", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RootResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_root ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json index bea03587..81ca061b 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_blueprint][full_spec].json @@ -177,6 +177,38 @@ "title": "Resp", "type": "object" }, + "RootResp.7068f62": { + "anyOf": [ + { + "$ref": "#/components/schemas/RootResp.7068f62.JSON" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + } + ], + "title": "RootResp" + }, + "RootResp.7068f62.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "StrDict.7068f62": { "additionalProperties": { "type": "string" @@ -336,6 +368,37 @@ "tags": [] } }, + "/api/return_root": { + "get": { + "description": "", + "operationId": "get__api_return_root", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RootResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_root ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json index edac4522..15b9cd56 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[flask_view][full_spec].json @@ -177,6 +177,38 @@ "title": "Resp", "type": "object" }, + "RootResp.7068f62": { + "anyOf": [ + { + "$ref": "#/components/schemas/RootResp.7068f62.JSON" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + } + ], + "title": "RootResp" + }, + "RootResp.7068f62.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "StrDict.7068f62": { "additionalProperties": { "type": "string" diff --git a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json index 332a8dc6..6c4cd074 100644 --- a/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json +++ b/tests/__snapshots__/test_plugin/test_plugin_spec[starlette][full_spec].json @@ -159,6 +159,38 @@ "title": "Resp", "type": "object" }, + "RootResp.7068f62": { + "anyOf": [ + { + "$ref": "#/components/schemas/RootResp.7068f62.JSON" + }, + { + "items": { + "type": "integer" + }, + "type": "array" + } + ], + "title": "RootResp" + }, + "RootResp.7068f62.JSON": { + "properties": { + "limit": { + "title": "Limit", + "type": "integer" + }, + "name": { + "title": "Name", + "type": "string" + } + }, + "required": [ + "name", + "limit" + ], + "title": "JSON", + "type": "object" + }, "StrDict.7068f62": { "additionalProperties": { "type": "string" @@ -346,6 +378,37 @@ "tags": [] } }, + "/api/return_root": { + "get": { + "description": "", + "operationId": "get__api_return_root", + "parameters": [], + "responses": { + "200": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/RootResp.7068f62" + } + } + }, + "description": "OK" + }, + "422": { + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/ValidationError.6a07bef" + } + } + }, + "description": "Unprocessable Entity" + } + }, + "summary": "return_root ", + "tags": [] + } + }, "/api/user/{name}": { "post": { "description": "", diff --git a/tests/common.py b/tests/common.py index ec375a3e..a6847e51 100644 --- a/tests/common.py +++ b/tests/common.py @@ -1,5 +1,5 @@ from enum import Enum, IntEnum -from typing import Dict, List +from typing import Any, Dict, List, Union from spectree import BaseFile, ExternalDocs, SecurityScheme, SecuritySchemeData, Tag from spectree._pydantic import BaseModel, Field, root_validator @@ -46,6 +46,10 @@ class Resp(BaseModel): score: List[int] +class RootResp(BaseModel): + __root__: Union[JSON, List[int]] + + class Language(str, Enum): en = "en-US" zh = "zh-CN" @@ -177,3 +181,24 @@ def get_model_path_key(model_path: str) -> str: return model_name return f"{model_name}.{hash_module_path(module_path=model_path)}" + + +def get_root_resp_data(pre_serialize: bool, return_what: str): + assert return_what in ("RootResp_JSON", "RootResp_List", "JSON", "List") + data: Any + if return_what == "RootResp_JSON": + data = RootResp(__root__=JSON(name="user1", limit=1)) + elif return_what == "RootResp_List": + data = RootResp(__root__=[1, 2, 3, 4]) + elif return_what == "JSON": + data = JSON(name="user1", limit=1) + elif return_what == "List": + data = [1, 2, 3, 4] + pre_serialize = False + else: + assert False + if pre_serialize: + data = data.dict() + if "__root__" in data: + data = data["__root__"] + return data diff --git a/tests/flask_imports/dry_plugin_flask.py b/tests/flask_imports/dry_plugin_flask.py index 7146c3f2..bd036a00 100644 --- a/tests/flask_imports/dry_plugin_flask.py +++ b/tests/flask_imports/dry_plugin_flask.py @@ -171,6 +171,21 @@ def test_flask_return_list_request(client, pre_serialize: bool): ] +@pytest.mark.parametrize("pre_serialize", [False, True]) +@pytest.mark.parametrize( + "return_what", ["RootResp_JSON", "RootResp_List", "JSON", "List"] +) +def test_flask_return_root_request(client, pre_serialize: bool, return_what: str): + resp = client.get( + f"/api/return_root?pre_serialize={int(pre_serialize)}&return_what={return_what}" + ) + assert resp.status_code == 200 + if return_what in ("RootResp_JSON", "JSON"): + assert resp.json == {"name": "user1", "limit": 1} + elif return_what in ("RootResp_List", "List"): + assert resp.json == [1, 2, 3, 4] + + def test_flask_upload_file(client): file_content = "abcdef" data = {"file": (io.BytesIO(file_content.encode("utf-8")), "test.txt")} diff --git a/tests/test_plugin_falcon.py b/tests/test_plugin_falcon.py index 9f2890fc..d937c2ed 100644 --- a/tests/test_plugin_falcon.py +++ b/tests/test_plugin_falcon.py @@ -14,8 +14,10 @@ ListJSON, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) @@ -207,6 +209,17 @@ def on_get(self, req, resp): resp.media = [entry.dict() if pre_serialize else entry for entry in data] +class ReturnRootView: + name = "return root request view" + + @api.validate(resp=Response(HTTP_200=RootResp)) + def on_get(self, req, resp): + resp.media = get_root_resp_data( + pre_serialize=bool(int(req.params.get("pre_serialize", 0))), + return_what=req.params.get("return_what", "RootResp"), + ) + + class ViewWithCustomSerializer: name = "view with custom serializer" @@ -234,6 +247,7 @@ def on_post(self, req, resp): app.add_route("/api/file_upload", FileUploadView()) app.add_route("/api/list_json", ListJsonView()) app.add_route("/api/return_list", ReturnListView()) +app.add_route("/api/return_root", ReturnRootView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) api.register(app) @@ -348,6 +362,23 @@ def test_falcon_return_list_request_sync(client, pre_serialize: bool): ] +@pytest.mark.parametrize("pre_serialize", [False, True]) +@pytest.mark.parametrize( + "return_what", ["RootResp_JSON", "RootResp_List", "JSON", "List"] +) +def test_falcon_return_root_request_sync(client, pre_serialize: bool, return_what: str): + resp = client.simulate_request( + "GET", + f"/api/return_root?pre_serialize={int(pre_serialize)}" + f"&return_what={return_what}", + ) + assert resp.status_code == 200 + if return_what in ("RootResp_JSON", "JSON"): + assert resp.json == {"name": "user1", "limit": 1} + elif return_what in ("RootResp_List", "List"): + assert resp.json == [1, 2, 3, 4] + + @pytest.fixture def test_client_and_api(request): api_args = ["falcon"] diff --git a/tests/test_plugin_falcon_asgi.py b/tests/test_plugin_falcon_asgi.py index 57689db4..e0c2cfa2 100644 --- a/tests/test_plugin_falcon_asgi.py +++ b/tests/test_plugin_falcon_asgi.py @@ -15,8 +15,10 @@ ListJSON, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) @@ -132,6 +134,17 @@ async def on_get(self, req, resp): resp.media = [entry.dict() if pre_serialize else entry for entry in data] +class ReturnRootView: + name = "return root request view" + + @api.validate(resp=Response(HTTP_200=RootResp)) + async def on_get(self, req, resp): + resp.media = get_root_resp_data( + pre_serialize=bool(int(req.params.get("pre_serialize", 0))), + return_what=req.params.get("return_what", "RootResp"), + ) + + class FileUploadView: name = "file upload view" @@ -168,6 +181,7 @@ async def on_post(self, req, resp): app.add_route("/api/file_upload", FileUploadView()) app.add_route("/api/list_json", ListJsonView()) app.add_route("/api/return_list", ReturnListView()) +app.add_route("/api/return_root", ReturnRootView()) app.add_route("/api/custom_serializer", ViewWithCustomSerializer()) api.register(app) @@ -213,6 +227,25 @@ def test_falcon_return_list_request_async(client, pre_serialize: bool): ] +@pytest.mark.parametrize("pre_serialize", [False, True]) +@pytest.mark.parametrize( + "return_what", ["RootResp_JSON", "RootResp_List", "JSON", "List"] +) +def test_falcon_return_root_request_async( + client, pre_serialize: bool, return_what: str +): + resp = client.simulate_request( + "GET", + f"/api/return_root?pre_serialize={int(pre_serialize)}" + f"&return_what={return_what}", + ) + assert resp.status_code == 200 + if return_what in ("RootResp_JSON", "JSON"): + assert resp.json == {"name": "user1", "limit": 1} + elif return_what in ("RootResp_List", "List"): + assert resp.json == [1, 2, 3, 4] + + def test_falcon_validate(client): resp = client.simulate_request( "GET", "/ping", headers={"Content-Type": "text/plain"} diff --git a/tests/test_plugin_flask.py b/tests/test_plugin_flask.py index 5d2fa974..9adaf946 100644 --- a/tests/test_plugin_flask.py +++ b/tests/test_plugin_flask.py @@ -17,8 +17,10 @@ Order, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) # import tests to execute @@ -177,6 +179,15 @@ def return_list(): return [entry.dict() if pre_serialize else entry for entry in data] +@app.route("/api/return_root", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=RootResp)) +def return_root(): + return get_root_resp_data( + pre_serialize=bool(int(request.args.get("pre_serialize", default=0))), + return_what=request.args.get("return_what", default="RootResp"), + ) + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_flask_blueprint.py b/tests/test_plugin_flask_blueprint.py index 33c896c6..20d65f24 100644 --- a/tests/test_plugin_flask_blueprint.py +++ b/tests/test_plugin_flask_blueprint.py @@ -16,9 +16,11 @@ Order, Query, Resp, + RootResp, StrDict, api_tag, get_paths, + get_root_resp_data, ) # import tests to execute @@ -166,6 +168,15 @@ def return_list(): return [entry.dict() if pre_serialize else entry for entry in data] +@app.route("/api/return_root", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=RootResp)) +def return_root(): + return get_root_resp_data( + pre_serialize=bool(int(request.args.get("pre_serialize", default=0))), + return_what=request.args.get("return_what", default="RootResp"), + ) + + api.register(app) flask_app = Flask(__name__) diff --git a/tests/test_plugin_flask_view.py b/tests/test_plugin_flask_view.py index 03e645ca..bea69871 100644 --- a/tests/test_plugin_flask_view.py +++ b/tests/test_plugin_flask_view.py @@ -17,8 +17,10 @@ Order, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) # import tests to execute @@ -179,6 +181,15 @@ def get(self): return [entry.dict() if pre_serialize else entry for entry in data] +class ReturnRootView(MethodView): + @api.validate(resp=Response(HTTP_200=RootResp)) + def get(self): + return get_root_resp_data( + pre_serialize=bool(int(request.args.get("pre_serialize", default=0))), + return_what=request.args.get("return_what", default="RootResp"), + ) + + app.add_url_rule("/ping", view_func=Ping.as_view("ping")) app.add_url_rule("/api/user/", view_func=User.as_view("user"), methods=["POST"]) app.add_url_rule( diff --git a/tests/test_plugin_quart.py b/tests/test_plugin_quart.py index 40b51783..6c1dd404 100644 --- a/tests/test_plugin_quart.py +++ b/tests/test_plugin_quart.py @@ -15,8 +15,10 @@ Order, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) @@ -159,6 +161,15 @@ def return_list(): return [entry.dict() if pre_serialize else entry for entry in data] +@app.route("/api/return_root", methods=["GET"]) +@api.validate(resp=Response(HTTP_200=RootResp)) +def return_root(): + return get_root_resp_data( + pre_serialize=bool(int(request.args.get("pre_serialize", default=0))), + return_what=request.args.get("return_what", default="RootResp"), + ) + + # INFO: ensures that spec is calculated and cached _after_ registering # view functions for validations. This enables tests to access `api.spec` # without app_context. diff --git a/tests/test_plugin_starlette.py b/tests/test_plugin_starlette.py index 9fd20fd0..9e7f9a19 100644 --- a/tests/test_plugin_starlette.py +++ b/tests/test_plugin_starlette.py @@ -22,8 +22,10 @@ Order, Query, Resp, + RootResp, StrDict, api_tag, + get_root_resp_data, ) @@ -152,6 +154,18 @@ async def return_list(request): ) +@api.validate(resp=Response(HTTP_200=RootResp)) +async def return_root(request): + return PydanticResponse( + get_root_resp_data( + pre_serialize=bool( + int(request.query_params.get("pre_serialize", default=0)) + ), + return_what=request.query_params.get("return_what", default="RootResp"), + ) + ) + + app = Starlette( routes=[ Route("/ping", Ping), @@ -186,6 +200,7 @@ async def return_list(request): Route("/file_upload", file_upload, methods=["POST"]), Route("/list_json", list_json, methods=["POST"]), Route("/return_list", return_list, methods=["GET"]), + Route("/return_root", return_root, methods=["GET"]), ], ), Mount("/static", app=StaticFiles(directory="docs"), name="static"), @@ -386,7 +401,7 @@ def test_json_list_request(client): @pytest.mark.parametrize("pre_serialize", [False, True]) -def test_return_list_request(client, pre_serialize: bool): +def test_starlette_return_list_request(client, pre_serialize: bool): resp = client.get(f"/api/return_list?pre_serialize={int(pre_serialize)}") assert resp.status_code == 200 assert resp.json() == [ @@ -395,6 +410,19 @@ def test_return_list_request(client, pre_serialize: bool): ] +@pytest.mark.parametrize( + "return_what", ["RootResp_JSON", "RootResp_List", "JSON", "List"] +) +def test_starlette_return_root_request_sync(client, return_what: str): + resp = client.get(f"/api/return_root?pre_serialize=0&return_what={return_what}") + assert resp.status_code == 200 + assert resp.status_code == 200 + if return_what in ("RootResp_JSON", "JSON"): + assert resp.json() == {"name": "user1", "limit": 1} + elif return_what in ("RootResp_List", "List"): + assert resp.json() == [1, 2, 3, 4] + + def test_starlette_upload_file(client): file_content = "abcdef" file_io = io.BytesIO(file_content.encode("utf-8")) diff --git a/tests/test_pydantic.py b/tests/test_pydantic.py new file mode 100644 index 00000000..ae7b595a --- /dev/null +++ b/tests/test_pydantic.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass +from typing import Any, List + +import pytest + +from spectree._pydantic import ( + BaseModel, + is_base_model, + is_base_model_instance, + is_root_model, + is_root_model_instance, + serialize_model_instance, +) + + +class DummyRootModel(BaseModel): + __root__: List[int] + + +class NestedRootModel(BaseModel): + __root__: DummyRootModel + + +class SimpleModel(BaseModel): + user_id: int + + +class Users(BaseModel): + __root__: List[SimpleModel] + + +@dataclass +class RootModelLookalike: + __root__: List[str] + + +@pytest.mark.parametrize( + "value, expected", + [ + (DummyRootModel, True), + (DummyRootModel(__root__=[1, 2, 3]), False), + (NestedRootModel, True), + (NestedRootModel(__root__=DummyRootModel(__root__=[1, 2, 3])), False), + (SimpleModel, False), + (SimpleModel(user_id=1), False), + (RootModelLookalike, False), + (RootModelLookalike(__root__=["False"]), False), + (list, False), + ([1, 2, 3], False), + (str, False), + ("str", False), + (int, False), + (1, False), + ], +) +def test_is_root_model(value: Any, expected: bool): + assert is_root_model(value) is expected + + +@pytest.mark.parametrize( + "value, expected", + [ + (DummyRootModel, False), + (DummyRootModel(__root__=[1, 2, 3]), True), + (NestedRootModel, False), + (NestedRootModel(__root__=DummyRootModel(__root__=[1, 2, 3])), True), + (SimpleModel, False), + (SimpleModel(user_id=1), False), + (RootModelLookalike, False), + (RootModelLookalike(__root__=["False"]), False), + (list, False), + ([1, 2, 3], False), + (str, False), + ("str", False), + (int, False), + (1, False), + ], +) +def test_is_root_model_instance(value, expected): + assert is_root_model_instance(value) is expected + + +@pytest.mark.parametrize( + "value, expected", + [ + (DummyRootModel, True), + (DummyRootModel(__root__=[1, 2, 3]), False), + (NestedRootModel, True), + (NestedRootModel(__root__=DummyRootModel(__root__=[1, 2, 3])), False), + (SimpleModel, True), + (SimpleModel(user_id=1), False), + (RootModelLookalike, False), + (RootModelLookalike(__root__=["False"]), False), + (list, False), + ([1, 2, 3], False), + (str, False), + ("str", False), + (int, False), + (1, False), + ], +) +def test_is_base_model(value, expected): + assert is_base_model(value) is expected + + +@pytest.mark.parametrize( + "value, expected", + [ + (DummyRootModel, False), + (DummyRootModel(__root__=[1, 2, 3]), True), + (NestedRootModel, False), + (NestedRootModel(__root__=DummyRootModel(__root__=[1, 2, 3])), True), + (SimpleModel, False), + (SimpleModel(user_id=1), True), + (RootModelLookalike, False), + (RootModelLookalike(__root__=["False"]), False), + (list, False), + ([1, 2, 3], False), + (str, False), + ("str", False), + (int, False), + (1, False), + ], +) +def test_is_base_model_instance(value, expected): + assert is_base_model_instance(value) is expected + + +@pytest.mark.parametrize( + "value, expected", + [ + (SimpleModel(user_id=1), {"user_id": 1}), + (DummyRootModel(__root__=[1, 2, 3]), [1, 2, 3]), + (NestedRootModel(__root__=DummyRootModel(__root__=[1, 2, 3])), [1, 2, 3]), + ( + Users( + __root__=[ + SimpleModel(user_id=1), + SimpleModel(user_id=2), + ] + ), + [{"user_id": 1}, {"user_id": 2}], + ), + ], +) +def test_serialize_model_instance(value, expected): + assert serialize_model_instance(value) == expected