Skip to content

Commit

Permalink
[Proposal] Support Pydantic root model responses (0b01001001#338)
Browse files Browse the repository at this point in the history
* Support Pydantic root model responses

* kemingy's suggestions

* Add additional tests for _pydantic

* Also skip second validation if response validation error was previously found (flask)

* Fix typing issue in test_pydantic.py

* Update 'ping' flask blueprint endpoint

* Don't run first validation on Flask Response objects
  • Loading branch information
jean-edouard-boulanger authored Sep 19, 2023
1 parent 79d0d32 commit 3cac8f3
Show file tree
Hide file tree
Showing 21 changed files with 755 additions and 27 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ format:
lint:
isort --check --diff --project=spectree ${SOURCE_FILES}
black --check --diff ${SOURCE_FILES}
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=17
flake8 ${SOURCE_FILES} --count --show-source --statistics --ignore=D203,E203,W503 --max-line-length=88 --max-complexity=21
mypy --install-types --non-interactive ${MYPY_SOURCE_FILES}

.PHONY: test doc
42 changes: 42 additions & 0 deletions spectree/_pydantic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
from typing import Any

from pydantic.version import VERSION as PYDANTIC_VERSION

PYDANTIC2 = PYDANTIC_VERSION.startswith("2")
ROOT_FIELD = "__root__"


__all__ = [
"BaseModel",
Expand All @@ -11,6 +15,11 @@
"BaseSettings",
"EmailStr",
"validator",
"is_root_model",
"is_root_model_instance",
"serialize_model_instance",
"is_base_model",
"is_base_model_instance",
]

if PYDANTIC2:
Expand All @@ -35,3 +44,36 @@
root_validator,
validator,
)


def is_base_model(t: Any) -> bool:
"""Check whether a type is a Pydantic BaseModel"""
try:
return issubclass(t, BaseModel)
except TypeError:
return False


def is_base_model_instance(value: Any) -> bool:
"""Check whether a value is a Pydantic BaseModel instance."""
return is_base_model(type(value))


def is_root_model(t: Any) -> bool:
"""Check whether a type is a Pydantic RootModel."""
return is_base_model(t) and ROOT_FIELD in t.__fields__


def is_root_model_instance(value: Any):
"""Check whether a value is a Pydantic RootModel instance."""
return is_root_model(type(value))


def serialize_model_instance(value: BaseModel):
"""Serialize a Pydantic BaseModel (equivalent of calling `.dict()` on a BaseModel,
but additionally takes care of stripping __root__ for root models.
"""
serialized = value.dict()
if is_root_model_instance(value) and ROOT_FIELD in serialized:
return serialized[ROOT_FIELD]
return serialized
32 changes: 28 additions & 4 deletions spectree/plugins/falcon_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,12 @@
from falcon import Response as FalconResponse
from falcon.routing.compiled import _FIELD_PATTERN as FALCON_FIELD_PATTERN

from .._pydantic import BaseModel, ValidationError
from .._pydantic import (
BaseModel,
ValidationError,
is_root_model,
serialize_model_instance,
)
from .._types import ModelType
from ..response import Response
from .base import BasePlugin
Expand Down Expand Up @@ -195,7 +200,7 @@ def response_validation(
falcon_response: FalconResponse,
skip_validation: bool,
) -> None:
if response_spec and response_spec.has_model():
if not skip_validation and response_spec and response_spec.has_model():
model = falcon_response.media
status = int(falcon_response.status[:3])
expect_model = response_spec.find_model(status)
Expand All @@ -206,12 +211,31 @@ def response_validation(
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
falcon_response.media = [
(entry.dict() if isinstance(entry, BaseModel) else entry)
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in model
]
elif (
expect_model
and is_root_model(expect_model)
and not isinstance(model, expect_model)
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
model = expect_model(__root__=model)
except ValidationError:
raise
else:
falcon_response.media = serialize_model_instance(model)
skip_validation = True
elif expect_model and isinstance(falcon_response.media, expect_model):
falcon_response.media = model.dict()
falcon_response.media = serialize_model_instance(model)
skip_validation = True

if self._data_set_manually(falcon_response):
skip_validation = True
if expect_model and not skip_validation:
Expand Down
42 changes: 34 additions & 8 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from typing import Any, Callable, Mapping, Optional, Tuple, get_type_hints

import flask
from flask import Blueprint, abort, current_app, jsonify, make_response, request
from werkzeug.routing import parse_converter_args

from .._pydantic import BaseModel, ValidationError
from .._pydantic import (
BaseModel,
ValidationError,
is_root_model,
serialize_model_instance,
)
from .._types import ModelType
from ..response import Response
from ..utils import get_multidict_items, werkzeug_parse_rule
Expand Down Expand Up @@ -208,36 +214,56 @@ def validate(
else:
model = result

if resp:
if not skip_validation and resp and not isinstance(result, flask.Response):
expect_model = resp.find_model(status)
if resp.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = resp.get_expected_list_item_type(status)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
result = (
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in model
],
status,
*rest,
)
elif (
expect_model
and is_root_model(expect_model)
and not isinstance(model, expect_model)
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
model = expect_model(__root__=model)
except ValidationError as err:
resp_validation_error = err
else:
skip_validation = True
result = (serialize_model_instance(model), status, *rest)
elif expect_model and isinstance(model, expect_model):
skip_validation = True
result = (model.dict(), status, *rest)
result = (serialize_model_instance(model), status, *rest)

response = make_response(result)

if resp and resp.has_model():
if resp and resp.has_model() and not resp_validation_error:
model = resp.find_model(response.status_code)
if model and not skip_validation:
try:
model.parse_obj(response.get_json())
except ValidationError as err:
resp_validation_error = err
response = make_response(
jsonify({"message": "response validation error"}), 500
)

if resp_validation_error:
response = make_response(
jsonify({"message": "response validation error"}), 500
)

after(request, response, resp_validation_error, None)

Expand Down
41 changes: 33 additions & 8 deletions spectree/plugins/quart_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,12 @@
from quart import Blueprint, abort, current_app, jsonify, make_response, request
from werkzeug.routing import parse_converter_args

from .._pydantic import BaseModel, ValidationError
from .._pydantic import (
BaseModel,
ValidationError,
is_root_model,
serialize_model_instance,
)
from .._types import ModelType
from ..response import Response
from ..utils import get_multidict_items, werkzeug_parse_rule
Expand Down Expand Up @@ -220,36 +225,56 @@ async def validate(
else:
model = result

if resp:
if not skip_validation and resp:
expect_model = resp.find_model(status)
if resp.expect_list_result(status) and isinstance(model, list):
expected_list_item_type = resp.get_expected_list_item_type(status)
if all(isinstance(entry, expected_list_item_type) for entry in model):
skip_validation = True
result = (
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in model
],
status,
*rest,
)
elif (
expect_model
and is_root_model(expect_model)
and not isinstance(model, expect_model)
):
# Make it possible to return an instance of the model __root__ type
# (i.e. not the root model itself).
try:
model = expect_model(__root__=model)
except ValidationError as err:
resp_validation_error = err
else:
skip_validation = True
result = (serialize_model_instance(model), status, *rest)
elif expect_model and isinstance(model, expect_model):
skip_validation = True
result = (model.dict(), status, *rest)
result = (serialize_model_instance(model), status, *rest)

response = await make_response(result)

if resp and resp.has_model():
if resp and resp.has_model() and not resp_validation_error:
model = resp.find_model(response.status_code)
if model and not skip_validation:
try:
model.parse_obj(await response.get_json())
except ValidationError as err:
resp_validation_error = err
response = await make_response(
jsonify({"message": "response validation error"}), 500
)

if resp_validation_error:
response = await make_response(
jsonify({"message": "response validation error"}), 500
)

after(request, response, resp_validation_error, None)

Expand Down
12 changes: 8 additions & 4 deletions spectree/plugins/starlette_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from starlette.responses import HTMLResponse, JSONResponse
from starlette.routing import compile_path

from .._pydantic import BaseModel, ValidationError
from .._pydantic import BaseModel, ValidationError, serialize_model_instance
from .._types import ModelType
from ..response import Response
from .base import BasePlugin, Context
Expand All @@ -24,11 +24,15 @@ def render(self, content) -> bytes:
self._model_class = content.__class__
return super().render(
[
(entry.dict() if isinstance(entry, BaseModel) else entry)
(
serialize_model_instance(entry)
if isinstance(entry, BaseModel)
else entry
)
for entry in content
]
if isinstance(content, list)
else content.dict()
else serialize_model_instance(content)
)

return _PydanticResponse(content)
Expand Down Expand Up @@ -129,7 +133,7 @@ async def validate(
else:
response = func(*args, **kwargs)

if resp and response:
if not skip_validation and resp and response:
if (
isinstance(response, JSONResponse)
and hasattr(response, "_model_class")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,38 @@
"title": "Resp",
"type": "object"
},
"RootResp.7068f62": {
"anyOf": [
{
"$ref": "#/components/schemas/RootResp.7068f62.JSON"
},
{
"items": {
"type": "integer"
},
"type": "array"
}
],
"title": "RootResp"
},
"RootResp.7068f62.JSON": {
"properties": {
"limit": {
"title": "Limit",
"type": "integer"
},
"name": {
"title": "Name",
"type": "string"
}
},
"required": [
"name",
"limit"
],
"title": "JSON",
"type": "object"
},
"StrDict.7068f62": {
"additionalProperties": {
"type": "string"
Expand Down Expand Up @@ -383,6 +415,37 @@
"tags": []
}
},
"/api/return_root": {
"get": {
"description": "",
"operationId": "get__api_return_root",
"parameters": [],
"responses": {
"200": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/RootResp.7068f62"
}
}
},
"description": "OK"
},
"422": {
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/ValidationError.6a07bef"
}
}
},
"description": "Unprocessable Entity"
}
},
"summary": "on_get <GET>",
"tags": []
}
},
"/api/user/{name}": {
"get": {
"description": "",
Expand Down
Loading

0 comments on commit 3cac8f3

Please sign in to comment.