Skip to content

Commit

Permalink
Refactor the get_model_key util function (0b01001001#302)
Browse files Browse the repository at this point in the history
🔀 refactor(types.py): add line breaks to import statements
The import statements in the _types.py file were too long and difficult to read. Line breaks were added to improve readability.

🔀 refactor(response.py): add naming_strategy parameter to generate_spec method
The generate_spec method in the Response class now accepts a naming_strategy parameter. This allows the user to specify a custom naming strategy for the model names in the response spec.

🔀 refactor(spec.py): add naming_strategy parameter to __init__ and parse_model methods
The __init__ and parse_model methods in the SpecTree class now accept a naming_strategy parameter. This allows the user to specify a custom naming strategy for the model names in the OpenAPI spec.

🔀 refactor(utils.py): add naming_strategy parameter to get_model_schema and parse_resp functions
The get_model_schema and parse_resp functions in the utils.py file now accept a naming_strategy parameter. This allows the user to specify a
  • Loading branch information
AndreasBBS authored Apr 22, 2023
1 parent 78edc00 commit d41dea0
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 26 deletions.
14 changes: 13 additions & 1 deletion spectree/_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,22 @@
from typing import Any, Dict, Iterator, List, Mapping, Optional, Sequence, Type, Union
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Mapping,
Optional,
Sequence,
Type,
Union,
)

from pydantic import BaseModel
from typing_extensions import Protocol

ModelType = Type[BaseModel]
OptionalModelType = Optional[ModelType]
NamingStrategy = Callable[[ModelType], str]


class MultiDict(Protocol):
Expand Down
8 changes: 5 additions & 3 deletions spectree/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from pydantic import BaseModel

from ._types import ModelType, OptionalModelType
from ._types import ModelType, NamingStrategy, OptionalModelType
from .utils import gen_list_model, get_model_key, parse_code


Expand Down Expand Up @@ -121,7 +121,9 @@ def models(self) -> Iterable[ModelType]:
"""
return self.code_models.values()

def generate_spec(self) -> Dict[str, Any]:
def generate_spec(
self, naming_strategy: NamingStrategy = get_model_key
) -> Dict[str, Any]:
"""
generate the spec for responses
Expand All @@ -134,7 +136,7 @@ def generate_spec(self) -> Dict[str, Any]:
}

for code, model in self.code_models.items():
model_name = get_model_key(model=model)
model_name = naming_strategy(model)
responses[parse_code(code)] = {
"description": self.get_code_description(code),
"content": {
Expand Down
12 changes: 8 additions & 4 deletions spectree/spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
get_type_hints,
)

from ._types import FunctionDecorator, ModelType
from ._types import FunctionDecorator, ModelType, NamingStrategy
from .config import Configuration, ModeEnum
from .models import Tag, ValidationError
from .plugins import PLUGINS, BasePlugin
Expand Down Expand Up @@ -62,8 +62,10 @@ def __init__(
after: Callable = default_after_handler,
validation_error_status: int = 422,
validation_error_model: Optional[ModelType] = None,
naming_strategy: NamingStrategy = get_model_key,
**kwargs: Any,
):
self.naming_strategy = naming_strategy
self.before = before
self.after = after
self.validation_error_status = validation_error_status
Expand Down Expand Up @@ -255,8 +257,10 @@ def _add_model(self, model: ModelType) -> str:
unified model processing
"""

model_key = get_model_key(model=model)
self.models[model_key] = deepcopy(get_model_schema(model=model))
model_key = self.naming_strategy(model)
self.models[model_key] = deepcopy(
get_model_schema(model=model, naming_strategy=self.naming_strategy)
)

return model_key

Expand Down Expand Up @@ -297,7 +301,7 @@ def _generate_spec(self) -> Dict[str, Any]:
"description": desc or "",
"tags": [str(x) for x in getattr(func, "tags", ())],
"parameters": parse_params(func, parameters[:], self.models),
"responses": parse_resp(func),
"responses": parse_resp(func, self.naming_strategy),
}

security = getattr(func, "security", None)
Expand Down
36 changes: 18 additions & 18 deletions spectree/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

from pydantic import BaseModel, ValidationError

from ._types import ModelType, MultiDict
from ._types import ModelType, MultiDict, NamingStrategy

# parse HTTP status code to get the code
HTTP_CODE = re.compile(r"^HTTP_(?P<code>\d{3})$")
Expand Down Expand Up @@ -131,21 +131,6 @@ def parse_params(
return params


def parse_resp(func: Any):
"""
get the response spec
If this function does not have explicit ``resp`` but have other models,
a ``422 Validation Error`` will be appended to the response spec, since
this may be triggered in the validation step.
"""
responses = {}
if hasattr(func, "resp"):
responses = func.resp.generate_spec()

return responses


def has_model(func: Any) -> bool:
"""
return True if this function have ``pydantic.BaseModel``
Expand Down Expand Up @@ -244,7 +229,7 @@ def get_model_key(model: ModelType) -> str:
return f"{model.__name__}.{hash_module_path(module_path=model.__module__)}"


def get_model_schema(model: ModelType):
def get_model_schema(model: ModelType, naming_strategy: NamingStrategy = get_model_key):
"""
return a dictionary representing the model as JSON Schema with a hashed
infix in ref to ensure name uniqueness
Expand All @@ -255,7 +240,7 @@ def get_model_schema(model: ModelType):
assert issubclass(model, BaseModel)

return model.schema(
ref_template=f"#/components/schemas/{get_model_key(model)}.{{model}}"
ref_template=f"#/components/schemas/{naming_strategy(model)}.{{model}}"
)


Expand Down Expand Up @@ -334,3 +319,18 @@ def werkzeug_parse_rule(
if ">" in remaining or "<" in remaining:
raise ValueError(f"malformed url rule: {rule!r}")
yield None, None, remaining


def parse_resp(func: Any, naming_strategy: NamingStrategy = get_model_key):
"""
get the response spec
If this function does not have explicit ``resp`` but have other models,
a ``422 Validation Error`` will be appended to the response spec, since
this may be triggered in the validation step.
"""
responses = {}
if hasattr(func, "resp"):
responses = func.resp.generate_spec(naming_strategy)

return responses

0 comments on commit d41dea0

Please sign in to comment.