Skip to content

Commit

Permalink
Fix discriminator mapping resolution in schemas with parent-child hie…
Browse files Browse the repository at this point in the history
…rarchy (#2145)

* Add discriminator with properties test

* Patch __apply_discriminator_type

* Remove unnecessary reference.path check

* Fix typing for python3.8

* Remove unnecessary base_class.reference check

---------

Co-authored-by: Koudai Aono <koxudaxi@gmail.com>
  • Loading branch information
sternakt and koxudaxi authored Nov 10, 2024
1 parent cb0462a commit 6c3e114
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 17 deletions.
53 changes: 36 additions & 17 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -803,6 +803,7 @@ def __apply_discriminator_type(
if not data_type.reference: # pragma: no cover
continue
discriminator_model = data_type.reference.source

if not isinstance( # pragma: no cover
discriminator_model,
(
Expand All @@ -812,26 +813,43 @@ def __apply_discriminator_type(
),
):
continue # pragma: no cover
type_names = []
if mapping:

type_names: List[str] = []

def check_paths(
model: Union[
pydantic_model.BaseModel,
pydantic_model_v2.BaseModel,
Reference,
],
mapping: Dict[str, str],
type_names: List[str] = type_names,
) -> None:
"""Helper function to validate paths for a given model."""
for name, path in mapping.items():
if (
discriminator_model.path.split('#/')[-1]
!= path.split('#/')[-1]
model.path.split('#/')[-1] != path.split('#/')[-1]
) and (
path.startswith('#/')
or model.path[:-1] != path.split('/')[-1]
):
if (
path.startswith('#/')
or discriminator_model.path[:-1]
!= path.split('/')[-1]
):
t_path = path[str(path).find('/') + 1 :]
t_disc = discriminator_model.path[
: str(discriminator_model.path).find('#')
].lstrip('../')
t_disc_2 = '/'.join(t_disc.split('/')[1:])
if t_path != t_disc and t_path != t_disc_2:
continue
t_path = path[str(path).find('/') + 1 :]
t_disc = model.path[: str(model.path).find('#')].lstrip(
'../'
)
t_disc_2 = '/'.join(t_disc.split('/')[1:])
if t_path != t_disc and t_path != t_disc_2:
continue
type_names.append(name)

# Check the main discriminator model path
if mapping:
check_paths(discriminator_model, mapping)

# Check the base_classes if they exist
if len(type_names) == 0:
for base_class in discriminator_model.base_classes:
check_paths(base_class.reference, mapping)
else:
type_names = [discriminator_model.path.split('/')[-1]]
if not type_names: # pragma: no cover
Expand Down Expand Up @@ -879,7 +897,8 @@ def __apply_discriminator_type(
else IMPORT_LITERAL_BACKPORT
)
has_imported_literal = any(
literal == import_ for import_ in imports
literal == import_ # type: ignore [comparison-overlap]
for import_ in imports
)
if has_imported_literal: # pragma: no cover
imports.append(literal)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# generated by datamodel-codegen:
# filename: discriminator_with_properties.yaml
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import Literal, Optional, Union

from pydantic import BaseModel, Field, RootModel


class UserContextVariable(BaseModel):
accountId: str = Field(..., description='The account ID of the user.')
type: str = Field(..., description='Type of custom context variable.')


class IssueContextVariable(BaseModel):
id: Optional[int] = Field(None, description='The issue ID.')
key: Optional[str] = Field(None, description='The issue key.')
type: str = Field(..., description='Type of custom context variable.')


class CustomContextVariable1(UserContextVariable):
type: Literal['user'] = Field(..., description='Type of custom context variable.')


class CustomContextVariable2(IssueContextVariable):
type: Literal['issue'] = Field(..., description='Type of custom context variable.')


class CustomContextVariable(
RootModel[Union[CustomContextVariable1, CustomContextVariable2]]
):
root: Union[CustomContextVariable1, CustomContextVariable2] = Field(
..., discriminator='type'
)
46 changes: 46 additions & 0 deletions tests/data/openapi/discriminator_with_properties.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
openapi: 3.0.1
components:
schemas:
CustomContextVariable:
oneOf:
- $ref: '#/components/schemas/UserContextVariable'
- $ref: '#/components/schemas/IssueContextVariable'
properties:
type:
description: Type of custom context variable.
type: string
discriminator:
mapping:
user: '#/components/schemas/UserContextVariable'
issue: '#/components/schemas/IssueContextVariable'
propertyName: type
required:
- type
type: object
UserContextVariable:
properties:
accountId:
description: The account ID of the user.
type: string
type:
description: Type of custom context variable.
type: string
required:
- accountId
- type
type: object
IssueContextVariable:
properties:
id:
description: The issue ID.
format: int64
type: integer
key:
description: The issue key.
type: string
type:
description: Type of custom context variable.
type: string
required:
- type
type: object
28 changes: 28 additions & 0 deletions tests/main/openapi/test_main_openapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,34 @@ def test_main_openapi_discriminator_enum_duplicate():
)


@freeze_time('2019-07-26')
def test_main_openapi_discriminator_with_properties():
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
[
'--input',
str(OPEN_API_DATA_PATH / 'discriminator_with_properties.yaml'),
'--output',
str(output_file),
'--output-model-type',
'pydantic_v2.BaseModel',
]
)
assert return_code == Exit.OK

print(output_file.read_text())

assert (
output_file.read_text()
== (
EXPECTED_OPENAPI_PATH
/ 'discriminator'
/ 'discriminator_with_properties.py'
).read_text()
)


@freeze_time('2019-07-26')
def test_main_pydantic_basemodel():
with TemporaryDirectory() as output_dir:
Expand Down

0 comments on commit 6c3e114

Please sign in to comment.