Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: msgspec discriminated unions #2081

Merged
merged 6 commits into from
Nov 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions datamodel_code_generator/model/imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

IMPORT_DATACLASS = Import.from_full_path('dataclasses.dataclass')
IMPORT_FIELD = Import.from_full_path('dataclasses.field')
IMPORT_CLASSVAR = Import.from_full_path('typing.ClassVar')
IMPORT_TYPED_DICT = Import.from_full_path('typing.TypedDict')
IMPORT_TYPED_DICT_BACKPORT = Import.from_full_path('typing_extensions.TypedDict')
IMPORT_NOT_REQUIRED = Import.from_full_path('typing.NotRequired')
Expand Down
16 changes: 14 additions & 2 deletions datamodel_code_generator/model/msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import (
IMPORT_CLASSVAR,
IMPORT_MSGSPEC_CONVERT,
IMPORT_MSGSPEC_FIELD,
IMPORT_MSGSPEC_META,
Expand Down Expand Up @@ -72,6 +73,8 @@ def new_imports(self: DataModelFieldBaseT) -> Tuple[Import, ...]:
extra_imports.append(IMPORT_MSGSPEC_CONVERT)
if self.annotated:
extra_imports.append(IMPORT_MSGSPEC_META)
if self.extras.get('is_classvar'):
extra_imports.append(IMPORT_CLASSVAR)
return chain_as_tuple(original_imports.fget(self), extra_imports) # type: ignore

setattr(cls, 'imports', property(new_imports))
Expand Down Expand Up @@ -119,6 +122,10 @@ def __init__(
nullable=nullable,
keyword_only=keyword_only,
)
self.extra_template_data.setdefault('base_class_kwargs', {})

def add_base_class_kwarg(self, name: str, value):
self.extra_template_data['base_class_kwargs'][name] = value


class Constraints(_Constraints):
Expand Down Expand Up @@ -257,11 +264,16 @@ def annotated(self) -> Optional[str]:

meta = f'Meta({", ".join(meta_arguments)})'

if not self.required:
if not self.required and not self.extras.get('is_classvar'):
type_hint = self.data_type.type_hint
annotated_type = f'Annotated[{type_hint}, {meta}]'
return get_optional_type(annotated_type, self.data_type.use_union_operator)
return f'Annotated[{self.type_hint}, {meta}]'

annotated_type = f'Annotated[{self.type_hint}, {meta}]'
if self.extras.get('is_classvar'):
annotated_type = f'ClassVar[{annotated_type}]'

return annotated_type

def _get_default_as_struct_model(self) -> Optional[str]:
for data_type in self.data_type.data_types or (self.data_type,):
Expand Down
4 changes: 3 additions & 1 deletion datamodel_code_generator/model/template/msgspec.jinja2
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
{{ decorator }}
{% endfor -%}
{%- if base_class %}
class {{ class_name }}({{ base_class }}):
class {{ class_name }}({{ base_class }}{%- for key, value in (base_class_kwargs|default({})).items() -%}
, {{ key }}={{ value }}
{%- endfor -%}):
{%- else %}
class {{ class_name }}:
{%- endif %}
Expand Down
12 changes: 12 additions & 0 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
Imports,
)
from datamodel_code_generator.model import dataclass as dataclass_model
from datamodel_code_generator.model import msgspec as msgspec_model
from datamodel_code_generator.model import pydantic as pydantic_model
from datamodel_code_generator.model import pydantic_v2 as pydantic_model_v2
from datamodel_code_generator.model.base import (
Expand Down Expand Up @@ -810,6 +811,7 @@ def __apply_discriminator_type(
pydantic_model.BaseModel,
pydantic_model_v2.BaseModel,
dataclass_model.DataClass,
msgspec_model.Struct,
),
):
continue # pragma: no cover
Expand Down Expand Up @@ -870,6 +872,16 @@ def check_paths(
else None
):
has_one_literal = True
if isinstance(
discriminator_model, msgspec_model.Struct
): # pragma: no cover
discriminator_model.add_base_class_kwarg(
'tag_field', f"'{property_name}'"
)
discriminator_model.add_base_class_kwarg(
'tag', discriminator_field.represented_default
)
discriminator_field.extras['is_classvar'] = True
continue
for (
field_data_type
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# generated by datamodel-codegen:
# filename: discriminator_literals.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type1(Struct, tag_field='type_', tag='a'):
type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a'


class Type2(Struct, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'


class Response(Struct):
inner: Annotated[Union[Type1, Type2], Meta(title='Inner')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# generated by datamodel-codegen:
# filename: schema.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Optional, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type1(Struct, tag_field='type_', tag='a'):
type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a'


class Type2(Struct, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'
ref_type: Optional[Annotated[Type1, Meta(description='A referenced type.')]] = None


class Type4(Struct, tag_field='type_', tag='d'):
type_: ClassVar[Annotated[Literal['d'], Meta(title='Type ')]] = 'd'


class Type5(Struct, tag_field='type_', tag='e'):
type_: ClassVar[Annotated[Literal['e'], Meta(title='Type ')]] = 'e'


class Type3(Struct, tag_field='type_', tag='c'):
type_: ClassVar[Annotated[Literal['c'], Meta(title='Type ')]] = 'c'


class Response(Struct):
inner: Annotated[Union[Type1, Type2, Type3, Type4, Type5], Meta(title='Inner')]
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# generated by datamodel-codegen:
# filename: discriminator_with_external_reference
# timestamp: 2019-07-26T00:00:00+00:00
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# generated by datamodel-codegen:
# filename: discriminator_with_external_reference
# timestamp: 2019-07-26T00:00:00+00:00
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# generated by datamodel-codegen:
# filename: discriminator_with_external_reference
# timestamp: 2019-07-26T00:00:00+00:00
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# generated by datamodel-codegen:
# filename: inner_folder/artificial_folder/type-1.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type1(Struct, tag_field='type_', tag='a'):
type_: ClassVar[Annotated[Literal['a'], Meta(title='Type ')]] = 'a'
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
# generated by datamodel-codegen:
# filename: inner_folder/schema.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Union

from msgspec import Meta, Struct
from typing_extensions import Annotated

from .. import type_4
from ..subfolder import type_5
from . import type_2
from .artificial_folder import type_1


class Type3(Struct, tag_field='type_', tag='c'):
type_: ClassVar[Annotated[Literal['c'], Meta(title='Type ')]] = 'c'


class Response(Struct):
inner: Annotated[
Union[type_1.Type1, type_2.Type2, Type3, type_4.Type4, type_5.Type5],
Meta(title='Inner'),
]
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# generated by datamodel-codegen:
# filename: inner_folder/type-2.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal, Optional

from msgspec import Meta, Struct
from typing_extensions import Annotated

from .artificial_folder import type_1


class Type2(Struct, tag_field='type_', tag='b'):
type_: ClassVar[Annotated[Literal['b'], Meta(title='Type ')]] = 'b'
ref_type: Optional[
Annotated[type_1.Type1, Meta(description='A referenced type.')]
] = None
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# generated by datamodel-codegen:
# filename: discriminator_with_external_reference
# timestamp: 2019-07-26T00:00:00+00:00
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# generated by datamodel-codegen:
# filename: subfolder/type-5.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type5(Struct, tag_field='type_', tag='e'):
type_: ClassVar[Annotated[Literal['e'], Meta(title='Type ')]] = 'e'
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# generated by datamodel-codegen:
# filename: type-4.json
# timestamp: 2019-07-26T00:00:00+00:00

from __future__ import annotations

from typing import ClassVar, Literal

from msgspec import Meta, Struct
from typing_extensions import Annotated


class Type4(Struct, tag_field='type_', tag='d'):
type_: ClassVar[Annotated[Literal['d'], Meta(title='Type ')]] = 'd'
71 changes: 57 additions & 14 deletions tests/main/jsonschema/test_main_jsonschema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3149,8 +3149,21 @@ def test_main_typed_dict_not_required_nullable():
)


@pytest.mark.parametrize(
'output_model,expected_output',
[
(
'pydantic_v2.BaseModel',
'discriminator_literals.py',
),
(
'msgspec.Struct',
'discriminator_literals_msgspec.py',
),
],
)
@freeze_time('2019-07-26')
def test_main_jsonschema_discriminator_literals():
def test_main_jsonschema_discriminator_literals(output_model, expected_output):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
Expand All @@ -3160,18 +3173,33 @@ def test_main_jsonschema_discriminator_literals():
'--output',
str(output_file),
'--output-model-type',
'pydantic_v2.BaseModel',
output_model,
'--target-python',
'3.8',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (EXPECTED_JSON_SCHEMA_PATH / 'discriminator_literals.py').read_text()
== (EXPECTED_JSON_SCHEMA_PATH / expected_output).read_text()
)


@pytest.mark.parametrize(
'output_model,expected_output',
[
(
'pydantic_v2.BaseModel',
'discriminator_with_external_reference.py',
),
(
'msgspec.Struct',
'discriminator_with_external_reference_msgspec.py',
),
],
)
@freeze_time('2019-07-26')
def test_main_jsonschema_external_discriminator():
def test_main_jsonschema_external_discriminator(output_model, expected_output):
with TemporaryDirectory() as output_dir:
output_file: Path = Path(output_dir) / 'output.py'
return_code: Exit = main(
Expand All @@ -3186,20 +3214,33 @@ def test_main_jsonschema_external_discriminator():
'--output',
str(output_file),
'--output-model-type',
'pydantic_v2.BaseModel',
output_model,
'--target-python',
'3.8',
]
)
assert return_code == Exit.OK
assert (
output_file.read_text()
== (
EXPECTED_JSON_SCHEMA_PATH / 'discriminator_with_external_reference.py'
).read_text()
)
== (EXPECTED_JSON_SCHEMA_PATH / expected_output).read_text()
), EXPECTED_JSON_SCHEMA_PATH / expected_output


@pytest.mark.parametrize(
'output_model,expected_output',
[
(
'pydantic.BaseModel',
'discriminator_with_external_references_folder',
),
(
'msgspec.Struct',
'discriminator_with_external_references_folder_msgspec',
),
],
)
@freeze_time('2019-07-26')
def test_main_jsonschema_external_discriminator_folder():
def test_main_jsonschema_external_discriminator_folder(output_model, expected_output):
with TemporaryDirectory() as output_dir:
output_path: Path = Path(output_dir)
return_code: Exit = main(
Expand All @@ -3208,17 +3249,19 @@ def test_main_jsonschema_external_discriminator_folder():
str(JSON_SCHEMA_DATA_PATH / 'discriminator_with_external_reference'),
'--output',
str(output_path),
'--output-model-type',
output_model,
'--target-python',
'3.8',
]
)
assert return_code == Exit.OK
main_modular_dir = (
EXPECTED_JSON_SCHEMA_PATH / 'discriminator_with_external_references_folder'
)
main_modular_dir = EXPECTED_JSON_SCHEMA_PATH / expected_output
for path in main_modular_dir.rglob('*.py'):
result = output_path.joinpath(
path.relative_to(main_modular_dir)
).read_text()
assert result == path.read_text()
assert result == path.read_text(), path


@freeze_time('2019-07-26')
Expand Down
Loading