diff --git a/spectree/_types.py b/spectree/_types.py index b7579e50..5c36ee4d 100644 --- a/spectree/_types.py +++ b/spectree/_types.py @@ -17,6 +17,7 @@ ModelType = Type[BaseModel] OptionalModelType = Optional[ModelType] NamingStrategy = Callable[[ModelType], str] +NestedNamingStrategy = Callable[[str, str], str] class MultiDict(Protocol): diff --git a/spectree/spec.py b/spectree/spec.py index aacebf84..66836bbc 100644 --- a/spectree/spec.py +++ b/spectree/spec.py @@ -13,7 +13,7 @@ get_type_hints, ) -from ._types import FunctionDecorator, ModelType, NamingStrategy +from ._types import FunctionDecorator, ModelType, NamingStrategy, NestedNamingStrategy from .config import Configuration, ModeEnum from .models import Tag, ValidationError from .plugins import PLUGINS, BasePlugin @@ -23,6 +23,7 @@ default_before_handler, get_model_key, get_model_schema, + get_nested_key, get_security, parse_comments, parse_name, @@ -65,9 +66,11 @@ def __init__( validation_error_status: int = 422, validation_error_model: Optional[ModelType] = None, naming_strategy: NamingStrategy = get_model_key, + nested_naming_strategy: NestedNamingStrategy = get_nested_key, **kwargs: Any, ): self.naming_strategy = naming_strategy + self.nested_naming_strategy = nested_naming_strategy self.before = before self.after = after self.validation_error_status = validation_error_status @@ -261,7 +264,11 @@ def _add_model(self, model: ModelType) -> str: model_key = self.naming_strategy(model) self.models[model_key] = deepcopy( - get_model_schema(model=model, naming_strategy=self.naming_strategy) + get_model_schema( + model=model, + naming_strategy=self.naming_strategy, + nested_naming_strategy=self.nested_naming_strategy, + ) ) return model_key @@ -350,7 +357,9 @@ def _get_model_definitions(self) -> Dict[str, Any]: for name, schema in self.models.items(): if "definitions" in schema: for key, value in schema["definitions"].items(): - definitions[f"{name}.{key}"] = value + composed_key = self.nested_naming_strategy(name, key) + if composed_key not in definitions: + definitions[composed_key] = value del schema["definitions"] return definitions diff --git a/spectree/utils.py b/spectree/utils.py index 615a1a8e..52af7c74 100644 --- a/spectree/utils.py +++ b/spectree/utils.py @@ -18,7 +18,7 @@ from pydantic import BaseModel, ValidationError -from ._types import ModelType, MultiDict, NamingStrategy +from ._types import ModelType, MultiDict, NamingStrategy, NestedNamingStrategy # parse HTTP status code to get the code HTTP_CODE = re.compile(r"^HTTP_(?P\d{3})$") @@ -229,7 +229,22 @@ def get_model_key(model: ModelType) -> str: return f"{model.__name__}.{hash_module_path(module_path=model.__module__)}" -def get_model_schema(model: ModelType, naming_strategy: NamingStrategy = get_model_key): +def get_nested_key(parent: str, child: str) -> str: + """ + generate nested model reference name suffixed by parent model name + + :param parent: string of parent name + :param child: string of child name + """ + + return f"{parent}.{child}" + + +def get_model_schema( + model: ModelType, + naming_strategy: NamingStrategy = get_model_key, + nested_naming_strategy: NestedNamingStrategy = get_nested_key, +): """ return a dictionary representing the model as JSON Schema with a hashed infix in ref to ensure name uniqueness @@ -239,9 +254,9 @@ def get_model_schema(model: ModelType, naming_strategy: NamingStrategy = get_mod """ assert issubclass(model, BaseModel) - return model.schema( - ref_template=f"#/components/schemas/{naming_strategy(model)}.{{model}}" - ) + nested_key = nested_naming_strategy(naming_strategy(model), "{model}") + + return model.schema(ref_template=f"#/components/schemas/{nested_key}") def get_security(security: Union[None, Mapping, Sequence[Any]]) -> List[Any]: