Skip to content

Commit e2efd31

Browse files
committed
fix tests
1 parent 5a52650 commit e2efd31

File tree

4 files changed

+55
-30
lines changed

4 files changed

+55
-30
lines changed

flask_pydantic_spec/types.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import re
2-
from typing import Optional, Type, Iterable, Mapping, Any, Dict, List
2+
from typing import Optional, Type, Iterable, Mapping, Any, Dict, List, NamedTuple
33

44
from pydantic.v1 import BaseModel
55

@@ -47,11 +47,23 @@ def __init__(self, *args: Any, **kwargs: Any):
4747
else:
4848
assert key in DEFAULT_CODE_DESC, "invalid HTTP status code"
4949
if value:
50-
assert issubclass(value, BaseModel), "invalid `pydantic.BaseModel`"
51-
self.code_models[key] = value
50+
if self.is_list_type(value):
51+
assert issubclass(
52+
value.__args__[0], BaseModel
53+
), "invalid `pydantic.BaseModel`"
54+
self.code_models[key] = value.__args__[0]
55+
else:
56+
assert issubclass(
57+
value, BaseModel
58+
), "invalid `pydantic.BaseModel`"
59+
self.code_models[key] = value
5260
else:
5361
self.codes.append(key)
5462

63+
@staticmethod
64+
def is_list_type(value: Any) -> bool:
65+
return hasattr(value, "__origin__") and value.__origin__ is list
66+
5567
def has_model(self) -> bool:
5668
"""
5769
:returns: boolean -- does this response has models or not

flask_pydantic_spec/utils.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,15 @@ def parse_request(func: Callable) -> Mapping[str, Any]:
5050
request_body = getattr(func, "body")
5151
if isinstance(request_body, RequestBase):
5252
result: Mapping[str, Any] = request_body.generate_spec()
53-
elif issubclass(request_body, BaseModel):
54-
result = Request(request_body).generate_spec()
5553
else:
56-
result = {}
54+
try:
55+
if issubclass(request_body, BaseModel):
56+
result = Request(request_body).generate_spec()
57+
else:
58+
result = {}
59+
except TypeError:
60+
# request_body is not a class (e.g., generic type)
61+
result = {}
5762
return result
5863
return {}
5964

requirements/development.txt

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,8 @@
11
-r production.txt
2-
openapi-spec-validator>=0.2.9, <0.3
3-
pytest==7.1.2
4-
flake8==7.3.0
5-
flask==2.2.5
6-
werkzeug==2.2.3
7-
requests==2.31.0
8-
black==24.4.2
9-
mypy==1.17.1
2+
openapi-spec-validator == 0.7.1
3+
pytest>=8.3.5, <9
4+
flake8 == 7.1.1
5+
flask>=2.0.2, <3
6+
requests>=2.24.0, <3
7+
black>=20.8b1
8+
mypy==1.10.0

tests/test_spec.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
import pytest
55
from flask import Flask
66
from typing import List
7-
from openapi_spec_validator import validate_v3_spec
8-
from pydantic.v1 import BaseModel, Field, StrictFloat
7+
from openapi_spec_validator import OpenAPIV30SpecValidator
8+
from pydantic.v1 import BaseModel, StrictFloat, Field
99

1010
from flask_pydantic_spec import Response
1111
from flask_pydantic_spec.flask_backend import FlaskBackend
@@ -94,49 +94,57 @@ def test_spec_generate(name, app):
9494
def create_app():
9595
app = Flask(__name__)
9696

97-
@app.route("/foo")
98-
@api.validate()
97+
@app.get("/foo")
98+
@api.validate(resp=Response(HTTP_200=ExampleModel))
9999
def foo():
100100
pass
101101

102-
@app.route("/bar")
103-
@api_strict.validate()
102+
@app.get("/bar")
103+
@api_strict.validate(resp=Response(HTTP_200=ExampleModel))
104104
def bar():
105105
pass
106106

107-
@app.route("/lone", methods=["GET"])
107+
@app.get("/lone")
108+
@api.validate(
109+
resp=Response(HTTP_200=ExampleNestedList, HTTP_400=ExampleNestedModel),
110+
tags=["lone"],
111+
)
108112
def lone_get():
109113
pass
110114

111-
@app.route("/lone", methods=["POST"])
115+
@app.post("/lone")
112116
@api.validate(
113117
body=Request(ExampleModel),
114-
resp=Response(HTTP_200=ExampleNestedList, HTTP_400=ExampleNestedModel),
118+
resp=Response(HTTP_200=List[ExampleModel], HTTP_400=ExampleNestedModel),
115119
tags=["lone"],
116120
deprecated=True,
117121
)
118122
def lone_post():
119123
pass
120124

121-
@app.route("/query", methods=["GET"])
122-
@api.validate(query=ExampleQuery)
125+
@app.get("/query")
126+
@api.validate(
127+
query=ExampleQuery,
128+
resp=Response(HTTP_200=List[ExampleModel]),
129+
tags=["alpha"],
130+
)
123131
def get_query():
124132
pass
125133

126-
@app.route("/file")
134+
@app.get("/file")
127135
@api.validate(resp=FileResponse())
128136
def get_file():
129137
pass
130138

131-
@app.route("/file", methods=["POST"])
139+
@app.post("/file")
132140
@api.validate(
133141
body=Request(content_type="application/octet-stream"),
134-
resp=Response(HTTP_200=None),
142+
resp=FileResponse(),
135143
)
136144
def post_file():
137145
pass
138146

139-
@app.route("/multipart-file", methods=["POST"])
147+
@app.post("/multipart-file")
140148
@api.validate(
141149
body=MultipartFormRequest(ExampleModel), resp=Response(HTTP_200=ExampleModel)
142150
)
@@ -197,7 +205,8 @@ def test_valid_openapi_spec():
197205
app = create_app()
198206
api.register(app)
199207
spec = api.spec
200-
validate_v3_spec(spec)
208+
OpenAPIV30SpecValidator(spec).validate()
209+
assert OpenAPIV30SpecValidator(spec).is_valid()
201210

202211

203212
def test_openapi_tags():

0 commit comments

Comments
 (0)