Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor the get_model_key util function #302

Merged
merged 1 commit into from
Apr 22, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 13 additions & 1 deletion spectree/_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)

from pydantic import BaseModel
from typing_extensions import Protocol

ModelType = Type[BaseModel]
OptionalModelType = Optional[ModelType]
NamingStrategy = Callable[[ModelType], str]


class MultiDict(Protocol):
Expand Down
8 changes: 5 additions & 3 deletions spectree/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel

from ._types import ModelType, OptionalModelType
from ._types import ModelType, NamingStrategy, OptionalModelType
from .utils import gen_list_model, get_model_key, parse_code


Expand Down Expand Up @@ -121,7 +121,9 @@ def models(self) -> Iterable[ModelType]:
"""
return self.code_models.values()

def generate_spec(self) -> Dict[str, Any]:
def generate_spec(
self, naming_strategy: NamingStrategy = get_model_key
) -> Dict[str, Any]:
"""
generate the spec for responses

Expand All @@ -134,7 +136,7 @@ def generate_spec(self) -> Dict[str, Any]:
}

for code, model in self.code_models.items():
model_name = get_model_key(model=model)
model_name = naming_strategy(model)
responses[parse_code(code)] = {
"description": self.get_code_description(code),
"content": {
Expand Down
12 changes: 8 additions & 4 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
get_type_hints,
)

from ._types import FunctionDecorator, ModelType
from ._types import FunctionDecorator, ModelType, NamingStrategy
from .config import Configuration, ModeEnum
from .models import Tag, ValidationError
from .plugins import PLUGINS, BasePlugin
Expand Down Expand Up @@ -62,8 +62,10 @@ def __init__(
after: Callable = default_after_handler,
validation_error_status: int = 422,
validation_error_model: Optional[ModelType] = None,
naming_strategy: NamingStrategy = get_model_key,
**kwargs: Any,
):
self.naming_strategy = naming_strategy
self.before = before
self.after = after
self.validation_error_status = validation_error_status
Expand Down Expand Up @@ -255,8 +257,10 @@ def _add_model(self, model: ModelType) -> str:
unified model processing
"""

model_key = get_model_key(model=model)
self.models[model_key] = deepcopy(get_model_schema(model=model))
model_key = self.naming_strategy(model)
self.models[model_key] = deepcopy(
get_model_schema(model=model, naming_strategy=self.naming_strategy)
)

return model_key

Expand Down Expand Up @@ -297,7 +301,7 @@ def _generate_spec(self) -> Dict[str, Any]:
"description": desc or "",
"tags": [str(x) for x in getattr(func, "tags", ())],
"parameters": parse_params(func, parameters[:], self.models),
"responses": parse_resp(func),
"responses": parse_resp(func, self.naming_strategy),
}

security = getattr(func, "security", None)
Expand Down
36 changes: 18 additions & 18 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pydantic import BaseModel, ValidationError

from ._types import ModelType, MultiDict
from ._types import ModelType, MultiDict, NamingStrategy

# parse HTTP status code to get the code
HTTP_CODE = re.compile(r"^HTTP_(?P<code>\d{3})$")
Expand Down Expand Up @@ -131,21 +131,6 @@ def parse_params(
return params


def parse_resp(func: Any):
"""
get the response spec

If this function does not have explicit ``resp`` but have other models,
a ``422 Validation Error`` will be appended to the response spec, since
this may be triggered in the validation step.
"""
responses = {}
if hasattr(func, "resp"):
responses = func.resp.generate_spec()

return responses


def has_model(func: Any) -> bool:
"""
return True if this function have ``pydantic.BaseModel``
Expand Down Expand Up @@ -244,7 +229,7 @@ def get_model_key(model: ModelType) -> str:
return f"{model.__name__}.{hash_module_path(module_path=model.__module__)}"


def get_model_schema(model: ModelType):
def get_model_schema(model: ModelType, naming_strategy: NamingStrategy = get_model_key):
"""
return a dictionary representing the model as JSON Schema with a hashed
infix in ref to ensure name uniqueness
Expand All @@ -255,7 +240,7 @@ def get_model_schema(model: ModelType):
assert issubclass(model, BaseModel)

return model.schema(
ref_template=f"#/components/schemas/{get_model_key(model)}.{{model}}"
ref_template=f"#/components/schemas/{naming_strategy(model)}.{{model}}"
)


Expand Down Expand Up @@ -334,3 +319,18 @@ def werkzeug_parse_rule(
if ">" in remaining or "<" in remaining:
raise ValueError(f"malformed url rule: {rule!r}")
yield None, None, remaining


def parse_resp(func: Any, naming_strategy: NamingStrategy = get_model_key):
"""
get the response spec

If this function does not have explicit ``resp`` but have other models,
a ``422 Validation Error`` will be appended to the response spec, since
this may be triggered in the validation step.
"""
responses = {}
if hasattr(func, "resp"):
responses = func.resp.generate_spec(naming_strategy)

return responses