Skip to content

Commit 6a444d8

Browse files
authored
Merge pull request #50 from codectl/fix/#47
register HTTPResponse model response if registry being used
2 parents 76a5816 + cb7fbc6 commit 6a444d8

File tree

12 files changed

+517
-127
lines changed

12 files changed

+517
-127
lines changed

README.rst

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ Example Usage
5757
from typing import Optional
5858
5959
from apispec import APISpec
60-
from apispec_plugins.base.registry import RegistryMixin
60+
from apispec_plugins.base.mixin import RegistryMixin
6161
from apispec_plugins.ext.pydantic import PydanticPlugin
6262
from apispec_plugins.webframeworks.flask import FlaskPlugin
6363
from flask import Flask
@@ -68,7 +68,7 @@ Example Usage
6868
spec = APISpec(
6969
title="Pet Store",
7070
version="1.0.0",
71-
openapi_version="3.1.0",
71+
openapi_version="3.0.3",
7272
info=dict(description="A minimal pet store API"),
7373
plugins=(FlaskPlugin(), PydanticPlugin()),
7474
)

poetry.lock

Lines changed: 311 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,15 +43,16 @@ classifiers = [
4343
[tool.poetry.dependencies]
4444
apispec = { extras = ["yaml"], version = "^6.3.0" }
4545
flask = { version = "^2.1.3", optional = true }
46-
pydantic = { version = "^1.10.6", optional = true }
46+
pydantic = { version = "^1.10.7", optional = true }
4747
python = "^3.8"
4848

4949
[tool.poetry.dev-dependencies]
5050
coverage = "^7.2.2"
5151
flask = "^2.1.2"
5252
pre-commit = "^3.1.1"
53-
pydantic = "^1.10.5"
53+
pydantic = "^1.10.7"
5454
pytest = "^7.2.2"
55+
pytest-mock = "^3.10.0"
5556

5657
[tool.poetry.extras]
5758
flask = ["Flask"]

src/apispec_plugins/base/mixin.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from dataclasses import MISSING, fields
2+
from typing import get_args
3+
4+
from apispec import APISpec
5+
from apispec.ext.marshmallow.openapi import OpenAPIConverter, marshmallow as ma
6+
from apispec_plugins.base.registry import Registry
7+
8+
9+
class RegistryMixin(Registry):
10+
def __init_subclass__(cls, **kwargs):
11+
super().__init_subclass__(**kwargs)
12+
cls.register(cls)
13+
14+
15+
class DataclassSchemaMixin:
16+
@classmethod
17+
def schema(cls):
18+
19+
# resolve pydantic schema
20+
model = getattr(cls, "__pydantic_model__", None)
21+
if model:
22+
Registry.register(model)
23+
return model.schema()
24+
25+
# or fallback to marshmallow resolver
26+
return cls.dataclass_schema()
27+
28+
@classmethod
29+
def dataclass_schema(cls, openapi_version="2.0"):
30+
openapi_converter = OpenAPIConverter(
31+
openapi_version=openapi_version,
32+
schema_name_resolver=lambda f: None,
33+
spec=APISpec("", "", openapi_version),
34+
)
35+
36+
def schema_type(t):
37+
return ma.Schema.TYPE_MAPPING[next(iter(get_args(t)), t)]
38+
39+
schema_dict = {
40+
f.name: schema_type(f.type)(data_key=f.name, required=f.default is MISSING)
41+
for f in fields(cls)
42+
}
43+
schema = ma.Schema.from_dict(schema_dict)
44+
return openapi_converter.schema2jsonschema(schema)

src/apispec_plugins/base/registry.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,20 @@
33
from typing import TypeVar
44

55
__all__ = (
6-
"RegistryMixin",
6+
"Registry",
77
"RegistryError",
88
)
99

1010
T = TypeVar("T")
1111

1212

13-
class RegistryMixin:
14-
13+
class Registry:
1514
_registry: dict[str, T] = {}
1615

17-
def __init_subclass__(cls, **kwargs):
18-
super().__init_subclass__(**kwargs)
19-
cls._registry[cls.__name__] = cls
16+
@classmethod
17+
def register(cls, record: T):
18+
if record.__name__ not in cls._registry:
19+
cls._registry[record.__name__] = record
2020

2121
@classmethod
2222
def get_registry(cls):

src/apispec_plugins/base/types.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
1-
from __future__ import annotations
1+
try:
2+
from pydantic.dataclasses import dataclass
3+
except ImportError:
4+
from dataclasses import dataclass
5+
from typing import Optional
26

3-
from dataclasses import dataclass
7+
from apispec_plugins.base.mixin import DataclassSchemaMixin
48

59

610
__all__ = (
711
"AuthSchemes",
8-
"HTTPResponse",
912
"Server",
1013
"Tag",
14+
"HTTPResponse",
1115
)
1216

1317

@@ -18,19 +22,19 @@ class BasicAuth:
1822
scheme: str = "basic"
1923

2024

21-
@dataclass
22-
class HTTPResponse:
23-
code: int
24-
description: str | None = None
25-
26-
2725
@dataclass
2826
class Server:
2927
url: str
30-
description: str | None = None
28+
description: Optional[str] = None
3129

3230

3331
@dataclass
3432
class Tag:
3533
name: str
36-
description: str | None = None
34+
description: Optional[str] = None
35+
36+
37+
@dataclass
38+
class HTTPResponse(DataclassSchemaMixin):
39+
code: int
40+
description: Optional[str] = None

src/apispec_plugins/ext/pydantic.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from apispec import BasePlugin, APISpec
77
from apispec.exceptions import APISpecError, DuplicateComponentNameError
8-
from apispec_plugins.base import registry
8+
from apispec_plugins.base.registry import Registry
99
from pydantic import BaseModel
1010

1111

@@ -152,7 +152,7 @@ def resolve_schema_instance(
152152
elif isinstance(schema, BaseModel):
153153
return schema.__class__
154154
elif isinstance(schema, str):
155-
return registry.RegistryMixin.get_cls(schema)
155+
return Registry.get_cls(schema)
156156
return None
157157

158158
@classmethod

src/apispec_plugins/utils.py

Lines changed: 2 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -2,19 +2,17 @@
22
import re
33
import typing
44
import urllib.parse
5-
from collections.abc import Sequence
6-
from dataclasses import MISSING, asdict, fields
5+
from dataclasses import asdict
76

87
from apispec import yaml_utils
9-
108
from apispec_plugins.base import types
119

10+
1211
__all__ = (
1312
"spec_from",
1413
"load_method_specs",
1514
"load_specs_from_docstring",
1615
"path_parser",
17-
"dataclass_schema_resolver",
1816
"base_template",
1917
)
2018

@@ -81,31 +79,6 @@ def path_parser(path, **kwargs):
8179
return parsed
8280

8381

84-
def dataclass_schema_resolver(schema):
85-
"""A schema resolver for dataclasses."""
86-
87-
def _resolve_field_type(f):
88-
if f.type == str:
89-
return "string"
90-
if f.type == int:
91-
return "integer"
92-
if f.type == float:
93-
return "number"
94-
if f.type == bool:
95-
return "boolean"
96-
elif isinstance(field.type, Sequence):
97-
return "array"
98-
return "object"
99-
100-
definition = {"type": "object", "properties": {}, "required": []}
101-
for field in fields(schema):
102-
name = field.name
103-
definition["properties"][name] = {"type": _resolve_field_type(field)}
104-
if field.default == MISSING and field.default_factory == MISSING:
105-
definition["required"].append(name)
106-
return definition
107-
108-
10982
def base_template(
11083
openapi_version: str,
11184
info: dict = None,

src/apispec_plugins/webframeworks/flask.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,7 @@ def operation_helper(self, path=None, operations=None, **kwargs):
7777
if http_schema_name not in self.spec.components.schemas:
7878
self.spec.components.schema(
7979
component_id=http_schema_name,
80-
component=spec_utils.dataclass_schema_resolver(
81-
types.HTTPResponse
82-
),
80+
component=types.HTTPResponse.schema(),
8381
)
8482

8583
if schema_name not in self.spec.components.responses:

tests/conftest.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
from typing import Optional
2+
3+
from apispec_plugins.base.mixin import RegistryMixin
4+
from pydantic import BaseModel
5+
6+
7+
class Pet(BaseModel, RegistryMixin):
8+
id: Optional[int]
9+
name: str

0 commit comments

Comments
 (0)