Skip to content

Commit

Permalink
Merge branch 'main' into discriminated-unions-msgspec-support
Browse files Browse the repository at this point in the history
  • Loading branch information
koxudaxi authored Nov 10, 2024
2 parents e6c97b4 + 6c3e114 commit c56c569
Show file tree
Hide file tree
Showing 13 changed files with 295 additions and 71 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: 'v0.6.7'
rev: 'v0.7.2'
hooks:
- id: ruff
files: "^datamodel_code_generator|^tests"
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -455,7 +455,7 @@ Model customization:
--output-datetime-class {datetime,AwareDatetime,NaiveDatetime}
Choose Datetime class between AwareDatetime, NaiveDatetime or
datetime. Each output model has its default mapping, and only
pydantic and dataclass support this override"
pydantic, dataclass, and msgspec support this override"
--reuse-model Reuse models on the field when a module has the model with the same
content
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12}
Expand Down
2 changes: 1 addition & 1 deletion datamodel_code_generator/model/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ def get_data_model_types(
data_model=msgspec.Struct,
root_model=msgspec.RootModel,
field_model=msgspec.DataModelField,
data_type_manager=DataTypeManager,
data_type_manager=msgspec.DataTypeManager,
dump_resolve_reference_action=None,
known_third_party=['msgspec'],
)
Expand Down
60 changes: 58 additions & 2 deletions datamodel_code_generator/model/msgspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
Dict,
List,
Optional,
Sequence,
Set,
Tuple,
Type,
Expand All @@ -15,7 +16,14 @@

from pydantic import Field

from datamodel_code_generator.imports import Import
from datamodel_code_generator import DatetimeClassType, PythonVersion
from datamodel_code_generator.imports import (
IMPORT_DATE,
IMPORT_DATETIME,
IMPORT_TIME,
IMPORT_TIMEDELTA,
Import,
)
from datamodel_code_generator.model import DataModel, DataModelFieldBase
from datamodel_code_generator.model.base import UNDEFINED
from datamodel_code_generator.model.imports import (
Expand All @@ -29,8 +37,16 @@
Constraints as _Constraints,
)
from datamodel_code_generator.model.rootmodel import RootModel as _RootModel
from datamodel_code_generator.model.types import DataTypeManager as _DataTypeManager
from datamodel_code_generator.model.types import type_map_factory
from datamodel_code_generator.reference import Reference
from datamodel_code_generator.types import chain_as_tuple, get_optional_type
from datamodel_code_generator.types import (
DataType,
StrictTypes,
Types,
chain_as_tuple,
get_optional_type,
)


def _has_field_assignment(field: DataModelFieldBase) -> bool:
Expand Down Expand Up @@ -279,3 +295,43 @@ def _get_default_as_struct_model(self) -> Optional[str]:
elif data_type.reference and isinstance(data_type.reference.source, Struct):
return f'lambda: {self._PARSE_METHOD}({repr(self.default)}, type={data_type.alias or data_type.reference.source.class_name})'
return None


class DataTypeManager(_DataTypeManager):
    """msgspec-specific data-type manager.

    Builds the base manager's state via ``super().__init__`` and then
    rebuilds ``self.type_map`` from ``type_map_factory``, layering datetime
    imports (``time``/``date``/``datetime``/``timedelta``) on top when the
    target datetime class is ``DatetimeClassType.Datetime``.
    """

    def __init__(
        self,
        python_version: PythonVersion = PythonVersion.PY_38,
        use_standard_collections: bool = False,
        use_generic_container_types: bool = False,
        strict_types: Optional[Sequence[StrictTypes]] = None,
        use_non_positive_negative_number_constrained_types: bool = False,
        use_union_operator: bool = False,
        use_pendulum: bool = False,
        target_datetime_class: DatetimeClassType = DatetimeClassType.Datetime,
    ):
        super().__init__(
            python_version,
            use_standard_collections,
            use_generic_container_types,
            strict_types,
            use_non_positive_negative_number_constrained_types,
            use_union_operator,
            use_pendulum,
            target_datetime_class,
        )

        # Temporal overrides merged on top of the factory defaults; only
        # populated for the plain ``Datetime`` target class, otherwise the
        # factory mapping is used unchanged.
        temporal_overrides: Dict[Types, DataType] = {}
        if target_datetime_class is DatetimeClassType.Datetime:
            temporal_overrides = {
                Types.time: self.data_type.from_import(IMPORT_TIME),
                Types.date: self.data_type.from_import(IMPORT_DATE),
                Types.date_time: self.data_type.from_import(IMPORT_DATETIME),
                Types.timedelta: self.data_type.from_import(IMPORT_TIMEDELTA),
            }

        # Later keys win in a dict merge, so the overrides take precedence.
        self.type_map: Dict[Types, DataType] = {
            **type_map_factory(self.data_type),
            **temporal_overrides,
        }
1 change: 1 addition & 0 deletions datamodel_code_generator/model/pydantic_v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ class ConfigDict(_BaseModel):
arbitrary_types_allowed: Optional[bool] = None
protected_namespaces: Optional[Tuple[str, ...]] = None
regex_engine: Optional[str] = None
use_enum_values: Optional[bool] = None


__all__ = [
Expand Down
53 changes: 36 additions & 17 deletions datamodel_code_generator/parser/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -804,6 +804,7 @@ def __apply_discriminator_type(
if not data_type.reference: # pragma: no cover
continue
discriminator_model = data_type.reference.source

if not isinstance( # pragma: no cover
discriminator_model,
(
Expand All @@ -814,26 +815,43 @@ def __apply_discriminator_type(
),
):
continue # pragma: no cover
type_names = []
if mapping:

type_names: List[str] = []

def check_paths(
model: Union[
pydantic_model.BaseModel,
pydantic_model_v2.BaseModel,
Reference,
],
mapping: Dict[str, str],
type_names: List[str] = type_names,
) -> None:
"""Helper function to validate paths for a given model."""
for name, path in mapping.items():
if (
discriminator_model.path.split('#/')[-1]
!= path.split('#/')[-1]
model.path.split('#/')[-1] != path.split('#/')[-1]
) and (
path.startswith('#/')
or model.path[:-1] != path.split('/')[-1]
):
if (
path.startswith('#/')
or discriminator_model.path[:-1]
!= path.split('/')[-1]
):
t_path = path[str(path).find('/') + 1 :]
t_disc = discriminator_model.path[
: str(discriminator_model.path).find('#')
].lstrip('../')
t_disc_2 = '/'.join(t_disc.split('/')[1:])
if t_path != t_disc and t_path != t_disc_2:
continue
t_path = path[str(path).find('/') + 1 :]
t_disc = model.path[: str(model.path).find('#')].lstrip(
'../'
)
t_disc_2 = '/'.join(t_disc.split('/')[1:])
if t_path != t_disc and t_path != t_disc_2:
continue
type_names.append(name)

# Check the main discriminator model path
if mapping:
check_paths(discriminator_model, mapping)

# Check the base_classes if they exist
if len(type_names) == 0:
for base_class in discriminator_model.base_classes:
check_paths(base_class.reference, mapping)
else:
type_names = [discriminator_model.path.split('/')[-1]]
if not type_names: # pragma: no cover
Expand Down Expand Up @@ -891,7 +909,8 @@ def __apply_discriminator_type(
else IMPORT_LITERAL_BACKPORT
)
has_imported_literal = any(
literal == import_ for import_ in imports
literal == import_ # type: ignore [comparison-overlap]
for import_ in imports
)
if has_imported_literal: # pragma: no cover
imports.append(literal)
Expand Down
2 changes: 1 addition & 1 deletion docs/index.md
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ Model customization:
--output-datetime-class {datetime,AwareDatetime,NaiveDatetime}
Choose Datetime class between AwareDatetime, NaiveDatetime or
datetime. Each output model has its default mapping, and only
pydantic and dataclass support this override"
pydantic, dataclass, and msgspec support this override"
--reuse-model Reuse models on the field when a module has the model with the same
content
--target-python-version {3.6,3.7,3.8,3.9,3.10,3.11,3.12}
Expand Down
Loading

0 comments on commit c56c569

Please sign in to comment.