Skip to content

Commit

Permalink
fix: flask query list items (0b01001001#357)
Browse files Browse the repository at this point in the history
* fix: flask query list items

Signed-off-by: Keming <kemingy94@gmail.com>

* format

Signed-off-by: Keming <kemingy94@gmail.com>

* fix lint

Signed-off-by: Keming <kemingy94@gmail.com>

* fix type lint

Signed-off-by: Keming <kemingy94@gmail.com>

* fix type in py38

Signed-off-by: Keming <kemingy94@gmail.com>

* bump version

Signed-off-by: Keming <kemingy94@gmail.com>

---------

Signed-off-by: Keming <kemingy94@gmail.com>
  • Loading branch information
kemingy authored Oct 27, 2023
1 parent 05441a6 commit d768f70
Show file tree
Hide file tree
Showing 13 changed files with 268 additions and 9 deletions.
25 changes: 25 additions & 0 deletions examples/query_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import List, Union

from flask import Flask, jsonify, request
from pydantic.v1 import BaseModel, Field

from spectree import SpecTree


class SampleQueryParams(BaseModel):
id_list: Union[int, List[int]] = Field(..., description="List of IDs")


app = Flask(__name__)
spec = SpecTree("flask")


@app.route("/api/v1/samples", methods=["GET"])
@spec.validate(query=SampleQueryParams)
def get_samples():
return jsonify(text=f"it works: {request.context.query}")


if __name__ == "__main__":
spec.register(app) # if you don't register in api init step
app.run(port=8000)
59 changes: 59 additions & 0 deletions examples/resp.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
from typing import List, Union

from flask import Flask, make_response
from pydantic import BaseModel

from spectree import Response, SpecTree


class AppError(BaseModel):
message: str


class User(BaseModel):
user_id: int


class UserResponse(BaseModel):
__root__: Union[List[User], AppError]


class StrDict(BaseModel):
__root__: dict[str, str]


spec = SpecTree("flask")
# spec = SpecTree("falcon")
app = Flask(__name__)


@app.route("/ping")
@spec.validate(resp=Response(HTTP_200=StrDict))
def ping():
resp = make_response({"msg": "pong"}, 203)
resp.set_cookie(key="pub", value="abcdefg")
return resp


@app.route("/users")
@spec.validate(resp=Response(HTTP_200=UserResponse))
def get_users():
return [User(user_id=1), User(user_id=2)]


class UserResource:
@spec.validate(resp=Response(HTTP_200=UserResponse))
def on_get(self, req, resp):
resp.media = [User(user_id=0)]


if __name__ == "__main__":
spec.register(app)
app.run()

# app = falcon.App()
# app.add_route("/users", UserResource())
# spec.register(app)

# httpd = simple_server.make_server("localhost", 8000, app)
# httpd.serve_forever()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "spectree"
version = "1.2.6"
version = "1.2.7"
dynamic = []
description = "generate OpenAPI document and validate request&response with Python annotations."
readme = "README.md"
Expand Down
4 changes: 2 additions & 2 deletions spectree/plugins/flask_plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,9 +140,9 @@ def request_validation(self, request, query, json, form, headers, cookies):
req_headers: werkzeug.datastructures.EnvironHeaders
req_cookies: werkzeug.datastructures.ImmutableMultiDict
"""
req_query = get_multidict_items(request.args)
req_query = get_multidict_items(request.args, query)
req_headers = dict(iter(request.headers)) or {}
req_cookies = get_multidict_items(request.cookies) or {}
req_cookies = get_multidict_items(request.cookies)
has_data = request.method not in ("GET", "DELETE")
# flask Request.mimetype is already normalized
use_json = json and has_data and request.mimetype not in self.FORM_MIMETYPE
Expand Down
26 changes: 23 additions & 3 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,13 @@
)

from ._pydantic import BaseModel, ValidationError
from ._types import ModelType, MultiDict, NamingStrategy, NestedNamingStrategy
from ._types import (
ModelType,
MultiDict,
NamingStrategy,
NestedNamingStrategy,
OptionalModelType,
)

# parse HTTP status code to get the code
HTTP_CODE = re.compile(r"^HTTP_(?P<code>\d{3})$")
Expand Down Expand Up @@ -272,20 +278,34 @@ def get_security(security: Union[None, Mapping, Sequence[Any]]) -> List[Any]:
return []


def get_multidict_items(multidict: MultiDict) -> Dict[str, Union[None, str, List[str]]]:
def get_multidict_items(
multidict: MultiDict, model: OptionalModelType = None
) -> Dict[str, Union[None, str, List[str]]]:
"""
return the items of a :class:`werkzeug.datastructures.ImmutableMultiDict`
"""
res: Dict[str, Union[None, str, List[str]]] = {}
for key in multidict:
if len(multidict.getlist(key)) > 1:
if model is not None and is_list_item(key, model):
res[key] = multidict.getlist(key)
elif len(multidict.getlist(key)) > 1:
res[key] = multidict.getlist(key)
else:
res[key] = multidict.get(key)

return res


def is_list_item(key: str, model: OptionalModelType) -> bool:
"""Check if this key is a list item in the model."""
if model is None:
return False
model_filed = model.__fields__.get(key)
if model_filed is None:
return False
return getattr(model_filed.annotation, "__origin__", None) is list


def gen_list_model(model: Type[BaseModel]) -> Type[BaseModel]:
"""
generate the corresponding list[model] class for a given model class
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@
"title": "Order",
"type": "integer"
},
"QueryList.7068f62": {
"properties": {
"ids": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
},
"required": [
"ids"
],
"title": "QueryList",
"type": "object"
},
"Resp.7068f62": {
"properties": {
"name": {
Expand Down Expand Up @@ -358,6 +374,30 @@
"tags": []
}
},
"/api/query_list": {
"get": {
"description": "",
"operationId": "get__api_query_list",
"parameters": [
{
"description": "",
"in": "query",
"name": "ids",
"required": true,
"schema": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
}
],
"responses": {},
"summary": "query_list <GET>",
"tags": []
}
},
"/api/return_list": {
"get": {
"description": "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@
"title": "Order",
"type": "integer"
},
"QueryList.7068f62": {
"properties": {
"ids": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
},
"required": [
"ids"
],
"title": "QueryList",
"type": "object"
},
"Resp.7068f62": {
"properties": {
"name": {
Expand Down Expand Up @@ -358,6 +374,30 @@
"tags": []
}
},
"/api/query_list": {
"get": {
"description": "",
"operationId": "get__api_query_list",
"parameters": [
{
"description": "",
"in": "query",
"name": "ids",
"required": true,
"schema": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
}
],
"responses": {},
"summary": "query_list <GET>",
"tags": []
}
},
"/api/return_list": {
"get": {
"description": "",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,22 @@
"title": "Order",
"type": "integer"
},
"QueryList.7068f62": {
"properties": {
"ids": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
},
"required": [
"ids"
],
"title": "QueryList",
"type": "object"
},
"Resp.7068f62": {
"properties": {
"name": {
Expand Down Expand Up @@ -363,6 +379,30 @@
"tags": []
}
},
"/api/query_list": {
"get": {
"description": "",
"operationId": "get__api_query_list",
"parameters": [
{
"description": "",
"in": "query",
"name": "ids",
"required": true,
"schema": {
"items": {
"type": "integer"
},
"title": "Ids",
"type": "array"
}
}
],
"responses": {},
"summary": "get <GET>",
"tags": []
}
},
"/api/return_list": {
"get": {
"description": "",
Expand Down
4 changes: 4 additions & 0 deletions tests/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@ class Query(BaseModel):
order: Order


class QueryList(BaseModel):
ids: List[int]


class FormFileUpload(BaseModel):
file: BaseFile

Expand Down
5 changes: 5 additions & 0 deletions tests/flask_imports/dry_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,3 +257,8 @@ def test_flask_optional_alias_response(client):
resp = client.get("/api/return_optional_alias")
assert resp.status_code == 200
assert resp.json == {"schema": "test"}, resp.json


def test_flask_query_list(client):
resp = client.get("/api/query_list?ids=1&ids=2&ids=3")
assert resp.status_code == 200
8 changes: 8 additions & 0 deletions tests/test_plugin_flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OptionalAliasResp,
Order,
Query,
QueryList,
Resp,
RootResp,
StrDict,
Expand Down Expand Up @@ -184,6 +185,13 @@ def json_list():
return {}


@app.route("/api/query_list", methods=["GET"])
@api.validate(query=QueryList)
def query_list():
assert request.context.query.ids == [1, 2, 3]
return {}


@app.route("/api/return_list", methods=["GET"])
@api.validate(resp=Response(HTTP_200=List[JSON]))
def return_list():
Expand Down
8 changes: 8 additions & 0 deletions tests/test_plugin_flask_blueprint.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
OptionalAliasResp,
Order,
Query,
QueryList,
Resp,
RootResp,
StrDict,
Expand Down Expand Up @@ -171,6 +172,13 @@ def list_json():
return {}


@app.route("/api/query_list")
@api.validate(query=QueryList)
def query_list():
assert request.context.query.ids == [1, 2, 3]
return {}


@app.route("/api/return_list", methods=["GET"])
@api.validate(resp=Response(HTTP_200=List[JSON]))
def return_list():
Expand Down
Loading

0 comments on commit d768f70

Please sign in to comment.