Skip to content

Commit

Permalink
Support returning lists of unserialized models (0b01001001#335)
Browse files Browse the repository at this point in the history
* Fully support list responses

* Skip validation if all list entries have the expected type

* Implement common solution for Falcon sync/async response validation
  • Loading branch information
jean-edouard-boulanger authored Aug 16, 2023
1 parent bce9963 commit d7e19da
Show file tree
Hide file tree
Showing 23 changed files with 536 additions and 45 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ format:
lint:
isort --check --diff --project=spectree ${SOURCE_FILES}
black --check --diff ${SOURCE_FILES}
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=15
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=17
mypy --install-types --non-interactive ${MYPY_SOURCE_FILES}

.PHONY: test doc
4 changes: 3 additions & 1 deletion spectree/_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@
Optional,
Sequence,
Type,
TypeVar,
Union,
)

from typing_extensions import Protocol

from ._pydantic import BaseModel

ModelType = Type[BaseModel]
BaseModelSubclassType = TypeVar("BaseModelSubclassType", bound=BaseModel)
ModelType = Type[BaseModelSubclassType]
OptionalModelType = Optional[ModelType]
NamingStrategy = Callable[[ModelType], str]
NestedNamingStrategy = Callable[[str, str], str]
Expand Down
84 changes: 50 additions & 34 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,10 @@
from typing import Any, Callable, Dict, List, Mapping, Optional, get_type_hints

from falcon import HTTP_400, HTTP_415, HTTPError
from falcon import Response as FalconResponse
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN

from .._pydantic import ValidationError
from .._pydantic import BaseModel, ValidationError
from .._types import ModelType
from ..response import Response
from .base import BasePlugin
Expand Down Expand Up @@ -188,6 +189,34 @@ def request_validation(self, req, query, json, form, headers, cookies):
req_form = {x.name: x.stream.read() for x in req.get_media()}
req.context.form = form.parse_obj(req_form)

def response_validation(
self,
response_spec: Optional[Response],
falcon_response: FalconResponse,
skip_validation: bool,
) -> None:
if response_spec and response_spec.has_model():
model = falcon_response.media
status = int(falcon_response.status[:3])
expect_model = response_spec.find_model(status)
if response_spec.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = response_spec.get_expected_list_item_type(
status
)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
falcon_response.media = [
(entry.dict() if isinstance(entry, BaseModel) else entry)
for entry in model
]
elif expect_model and isinstance(falcon_response.media, expect_model):
falcon_response.media = model.dict()
skip_validation = True
if self._data_set_manually(falcon_response):
skip_validation = True
if expect_model and not skip_validation:
expect_model.parse_obj(falcon_response.media)

def validate(
self,
func: Callable,
Expand Down Expand Up @@ -226,22 +255,16 @@ def validate(

func(*args, **kwargs)

if resp and resp.has_model():
model = resp.find_model(_resp.status[:3])
if model and isinstance(_resp.media, model):
_resp.media = _resp.media.dict()
skip_validation = True

if self._data_set_manually(_resp):
skip_validation = True

if model and not skip_validation:
try:
model.parse_obj(_resp.media)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
try:
self.response_validation(
response_spec=resp,
falcon_response=_resp,
skip_validation=skip_validation,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)

Expand Down Expand Up @@ -328,22 +351,15 @@ async def validate(

await func(*args, **kwargs)

if resp and resp.has_model():
model = resp.find_model(_resp.status[:3])
if model and isinstance(_resp.media, model):
_resp.media = _resp.media.dict()
skip_validation = True

if self._data_set_manually(_resp):
skip_validation = True

model = resp.find_model(_resp.status[:3])
if model and not skip_validation:
try:
model.parse_obj(_resp.media)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()
try:
self.response_validation(
response_spec=resp,
falcon_response=_resp,
skip_validation=skip_validation,
)
except ValidationError as err:
resp_validation_error = err
_resp.status = HTTP_500
_resp.media = err.errors()

after(_req, _resp, resp_validation_error, _self)
14 changes: 13 additions & 1 deletion spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,19 @@ def validate(

if resp:
expect_model = resp.find_model(status)
if expect_model and isinstance(model, expect_model):
if resp.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = resp.get_expected_list_item_type(status)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
result = (
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
for entry in model
],
status,
*rest,
)
elif expect_model and isinstance(model, expect_model):
skip_validation = True
result = (model.dict(), status, *rest)

Expand Down
14 changes: 13 additions & 1 deletion spectree/plugins/quart_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,19 @@ async def validate(

if resp:
expect_model = resp.find_model(status)
if expect_model and isinstance(model, expect_model):
if resp.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = resp.get_expected_list_item_type(status)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
result = (
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
for entry in model
],
status,
*rest,
)
elif expect_model and isinstance(model, expect_model):
skip_validation = True
result = (model.dict(), status, *rest)

Expand Down
11 changes: 9 additions & 2 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import compile_path

from .._pydantic import ValidationError
from .._pydantic import BaseModel, ValidationError
from .._types import ModelType
from ..response import Response
from .base import BasePlugin, Context
Expand All @@ -22,7 +22,14 @@ def PydanticResponse(content):
class _PydanticResponse(JSONResponse):
def render(self, content) -> bytes:
self._model_class = content.__class__
return super().render(content.dict())
return super().render(
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
for entry in content
]
if isinstance(content, list)
else content.dict()
)

return _PydanticResponse(content)

Expand Down
35 changes: 30 additions & 5 deletions spectree/response.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from http import HTTPStatus
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union

from ._pydantic import BaseModel
from ._types import ModelType, NamingStrategy, OptionalModelType
from ._types import BaseModelSubclassType, ModelType, NamingStrategy, OptionalModelType
from .utils import gen_list_model, get_model_key, parse_code

# according to https://tools.ietf.org/html/rfc2616#section-10
Expand Down Expand Up @@ -30,22 +30,30 @@ class Response:
examples:
>>> from typing import List
>>> from spectree.response import Response
>>> from pydantic import BaseModel
...
>>> class User(BaseModel):
... id: int
...
>>> response = Response(HTTP_200)
>>> response = Response("HTTP_200")
>>> response = Response(HTTP_200=None)
>>> response = Response(HTTP_200=User)
>>> response = Response(HTTP_200=(User, "status code description"))
>>> response = Response(HTTP_200=List[User])
>>> response = Response(HTTP_200=(List[User], "status code description"))
"""

def __init__(
self,
*codes: str,
**code_models: Union[OptionalModelType, Tuple[OptionalModelType, str]],
**code_models: Union[
OptionalModelType,
Tuple[OptionalModelType, str],
Type[List[BaseModelSubclassType]],
Tuple[Type[List[BaseModelSubclassType]], str],
],
) -> None:
self.codes: List[str] = []

Expand All @@ -55,6 +63,7 @@ def __init__(

self.code_models: Dict[str, ModelType] = {}
self.code_descriptions: Dict[str, Optional[str]] = {}
self.code_list_item_types: Dict[str, ModelType] = {}
for code, model_and_description in code_models.items():
assert code in DEFAULT_CODE_DESC, "invalid HTTP status code"
description: Optional[str] = None
Expand All @@ -72,7 +81,9 @@ def __init__(
origin_type = getattr(model, "__origin__", None)
if origin_type is list or origin_type is List:
# type is List[BaseModel]
model = gen_list_model(getattr(model, "__args__")[0])
list_item_type = getattr(model, "__args__")[0]
model = gen_list_model(list_item_type)
self.code_list_item_types[code] = list_item_type
assert issubclass(model, BaseModel), "invalid `pydantic.BaseModel`"
assert description is None or isinstance(
description, str
Expand Down Expand Up @@ -119,6 +130,20 @@ def find_model(self, code: int) -> OptionalModelType:
"""
return self.code_models.get(f"HTTP_{code}")

def expect_list_result(self, code: int) -> bool:
"""Check whether a specific HTTP code expects a list result.
:param code: Status code (example: 200)
"""
return f"HTTP_{code}" in self.code_list_item_types

def get_expected_list_item_type(self, code: int) -> ModelType:
"""Get the expected list result item type.
:param code: Status code (example: 200)
"""
return self.code_list_item_types[f"HTTP_{code}"]

def get_code_description(self, code: str) -> str:
"""Get the description of the given status code.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@
"title": "JSON",
"type": "object"
},
"JSONList.a9993e3": {
"items": {
"$ref": "#/components/schemas/JSONList.a9993e3.JSON"
},
"title": "JSONList",
"type": "array"
},
"JSONList.a9993e3.JSON": {
"properties": {
"limit": {
"title": "Limit",
"type": "integer"
},
"name": {
"title": "Name",
"type": "string"
}
},
"required": [
"name",
"limit"
],
"title": "JSON",
"type": "object"
},
"ListJSON.7068f62": {
"items": {
"$ref": "#/components/schemas/ListJSON.7068f62.JSON"
Expand Down Expand Up @@ -327,6 +352,37 @@
"tags": []
}
},
"/api/return_list": {
"get": {
"description": "",
"operationId": "get__api_return_list",
"parameters": [],
"responses": {
"200": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/JSONList.a9993e3"
}
}
},
"description": "OK"
},
"422": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ValidationError.6a07bef"
}
}
},
"description": "Unprocessable Entity"
}
},
"summary": "on_get <GET>",
"tags": []
}
},
"/api/user/{name}": {
"get": {
"description": "",
Expand Down
Loading

0 comments on commit d7e19da

Please sign in to comment.