Skip to content

Commit 218721a

Browse files
authored
Merge pull request #20 from beda-software/pydantic-v2
Prepare for pydanticv2
2 parents b899df3 + 9c3946d commit 218721a

File tree

6 files changed

+556
-381
lines changed

6 files changed

+556
-381
lines changed

fhir_py_types/ast.py

Lines changed: 36 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,11 @@
11
import ast
2-
import functools
32
import itertools
43
import keyword
54
import logging
65
from collections.abc import Iterable
76
from dataclasses import replace
87
from enum import Enum, auto
9-
from typing import Literal, cast
8+
from typing import Literal
109

1110
from fhir_py_types import (
1211
StructureDefinition,
@@ -79,11 +78,13 @@ def format_identifier(
7978
def uppercamelcase(s: str) -> str:
8079
return s[:1].upper() + s[1:]
8180

82-
return (
83-
identifier + uppercamelcase(type_.code)
84-
if is_polymorphic(definition)
85-
else identifier
86-
)
81+
if is_polymorphic(definition):
82+
# TODO: it's fast hack
83+
if type_.code[0].islower():
84+
return identifier + uppercamelcase(clear_primitive_id(type_.code))
85+
return identifier + uppercamelcase(type_.code)
86+
87+
return identifier
8788

8889

8990
def remap_type(
@@ -94,8 +95,8 @@ def remap_type(
9495
case "Resource":
9596
# Different contexts use 'Resource' type to refer to any
9697
# resource differentiated by its 'resourceType' (tagged union).
97-
# 'AnyResource' is not defined by the spec but rather
98-
# generated as a union of all defined resource types.
98+
# 'AnyResource' is defined in header as a special type
99+
# that dynamically replaced with a right type in run-time
99100
type_ = replace(type_, code="AnyResource")
100101

101102
if is_polymorphic(definition):
@@ -106,6 +107,12 @@ def remap_type(
106107
# with a custom validator that will enforce single required property rule.
107108
type_ = replace(type_, required=False)
108109

110+
if is_primitive_type(type_):
111+
# Primitive types defined from the small letter (like code)
112+
# and it might overlap with model fields
113+
# e.g. QuestionnaireItem has attribute code and linkId has type code
114+
type_ = replace(type_, code=make_primitive_id(type_.code))
115+
109116
return type_
110117

111118

@@ -115,8 +122,7 @@ def zip_identifier_type(
115122
result = []
116123

117124
for t in [remap_type(definition, t) for t in definition.type]:
118-
name = format_identifier(definition, identifier, t)
119-
result.append((name, t))
125+
result.append((format_identifier(definition, identifier, t), t))
120126
if definition.kind != StructureDefinitionKind.PRIMITIVE and is_primitive_type(
121127
t
122128
):
@@ -185,10 +191,16 @@ def order_type_overriding_properties(
185191
def define_class_object(
186192
definition: StructureDefinition,
187193
) -> Iterable[ast.stmt | ast.expr]:
194+
bases: list[ast.expr] = []
195+
if definition.kind == StructureDefinitionKind.RESOURCE:
196+
bases.append(ast.Name("AnyResource"))
197+
# BaseModel should be the last, because it overrides `extra`
198+
bases.append(ast.Name("BaseModel"))
199+
188200
return [
189201
ast.ClassDef(
190202
definition.id,
191-
bases=[ast.Name("BaseModel")],
203+
bases=bases,
192204
body=[
193205
ast.Expr(value=ast.Constant(definition.docstring)),
194206
*itertools.chain.from_iterable(
@@ -202,11 +214,6 @@ def define_class_object(
202214
keywords=[],
203215
type_params=[],
204216
),
205-
ast.Call(
206-
ast.Attribute(value=ast.Name(definition.id), attr="update_forward_refs"),
207-
args=[],
208-
keywords=[],
209-
),
210217
]
211218

212219

@@ -215,48 +222,22 @@ def define_class(definition: StructureDefinition) -> Iterable[ast.stmt | ast.exp
215222

216223

217224
def define_alias(definition: StructureDefinition) -> Iterable[ast.stmt]:
218-
return type_annotate(definition, definition.id, AnnotationForm.TypeAlias)
219-
220-
221-
def define_tagged_union(
222-
name: str, components: Iterable[StructureDefinition], distinct_by: str
223-
) -> ast.stmt:
224-
annotation = functools.reduce(
225-
lambda acc, n: ast.BinOp(left=acc, right=n, op=ast.BitOr()),
226-
(cast(ast.expr, ast.Name(d.id)) for d in components),
225+
# Primitive types are renamed to another name to avoid overlapping with model fields
226+
return type_annotate(
227+
definition, make_primitive_id(definition.id), AnnotationForm.TypeAlias
227228
)
228229

229-
return ast.Assign(
230-
targets=[ast.Name(name)],
231-
value=ast.Subscript(
232-
value=ast.Name("Annotated_"),
233-
slice=ast.Tuple(
234-
elts=[
235-
annotation,
236-
ast.Call(
237-
ast.Name("Field"),
238-
args=[ast.Constant(...)],
239-
keywords=[
240-
ast.keyword(
241-
arg="discriminator", value=ast.Constant(distinct_by)
242-
),
243-
],
244-
),
245-
]
246-
),
247-
),
248-
)
249230

231+
def make_primitive_id(name: str) -> str:
232+
if name in ("str", "int", "float", "bool"):
233+
return name
234+
return f"{name}Type"
250235

251-
def select_tagged_resources(
252-
definitions: Iterable[StructureDefinition], key: str
253-
) -> Iterable[StructureDefinition]:
254-
return (
255-
definition
256-
for definition in definitions
257-
if definition.kind == StructureDefinitionKind.RESOURCE
258-
and key in definition.elements
259-
)
236+
237+
def clear_primitive_id(name: str) -> str:
238+
if name.endswith("Type"):
239+
return name[:-4]
240+
return name
260241

261242

262243
def select_nested_definitions(
@@ -302,14 +283,6 @@ def build_ast(
302283
f"Unsupported definition {definition.id} of kind {definition.kind}, skipping"
303284
)
304285

305-
resources = list(select_tagged_resources(structure_definitions, key="resourceType"))
306-
if resources:
307-
typedefinitions.append(
308-
define_tagged_union(
309-
name="AnyResource", components=resources, distinct_by="resourceType"
310-
)
311-
)
312-
313286
return sorted(
314287
typedefinitions,
315288
# Defer any postprocessing until after the structure tree is defined.

fhir_py_types/header.py.tpl

Lines changed: 157 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,162 @@
1-
from typing import List as List_, Optional as Optional_, Literal as Literal_, Annotated as Annotated_
1+
from typing import (
2+
List as List_,
3+
Optional as Optional_,
4+
Literal as Literal_,
5+
Any as Any_,
6+
)
27

3-
from pydantic import BaseModel as BaseModel_, Field, Extra
8+
from pydantic import (
9+
BaseModel as BaseModel_,
10+
ConfigDict,
11+
Field,
12+
SerializationInfo,
13+
field_validator,
14+
field_serializer,
15+
ValidationError,
16+
)
17+
from pydantic.main import IncEx
18+
from pydantic_core import PydanticCustomError
19+
20+
21+
class AnyResource(BaseModel_):
22+
model_config = ConfigDict(extra="allow")
23+
24+
resourceType: str
425

526

627

728
class BaseModel(BaseModel_):
8-
class Config:
9-
extra = Extra.forbid
10-
validate_assignment = True
11-
allow_population_by_field_name = True
12-
13-
def dict(self, *args, **kwargs):
14-
by_alias = kwargs.pop('by_alias', True)
15-
return super().dict(*args, **kwargs, by_alias=by_alias)
29+
model_config = ConfigDict(
30+
# Extra attributes are disabled because fhir does not allow it
31+
extra="forbid",
32+
# Validation are applied while mutating the resource
33+
validate_assignment=True,
34+
# It's important for reserved keywords population in constructor (e.g. for_)
35+
populate_by_name=True,
36+
# Speed up initial load by lazy build
37+
defer_build=True,
38+
# It does not break anything, just for convinience
39+
coerce_numbers_to_str=True,
40+
)
41+
42+
def model_dump(
43+
self,
44+
*,
45+
mode: Literal_["json", "python"] | str = "python",
46+
include: IncEx = None,
47+
exclude: IncEx = None,
48+
context: Any_ | None = None,
49+
by_alias: bool = True,
50+
exclude_unset: bool = False,
51+
exclude_defaults: bool = False,
52+
exclude_none: bool = True,
53+
round_trip: bool = False,
54+
warnings: bool | Literal_["none", "warn", "error"] = True,
55+
serialize_as_any: bool = False,
56+
):
57+
# Override default parameters for by_alias and exclude_none preserving function declaration
58+
return super().model_dump(
59+
mode=mode,
60+
include=include,
61+
exclude=exclude,
62+
context=context,
63+
by_alias=by_alias,
64+
exclude_unset=exclude_unset,
65+
exclude_defaults=exclude_defaults,
66+
exclude_none=exclude_none,
67+
round_trip=round_trip,
68+
warnings=warnings,
69+
serialize_as_any=serialize_as_any,
70+
)
71+
72+
@field_serializer("*")
73+
@classmethod
74+
def serialize_all_fields(cls, value: Any_, info: SerializationInfo):
75+
if isinstance(value, list):
76+
return [_serialize(v, info) for v in value]
77+
78+
return _serialize(value, info)
79+
80+
@field_validator("*")
81+
@classmethod
82+
def validate_all_fields(cls, value: Any_):
83+
if isinstance(value, list):
84+
return [_validate(v, index=index) for index, v in enumerate(value)]
85+
return _validate(value)
86+
87+
88+
def _serialize(value: Any_, info: SerializationInfo):
89+
# Custom serializer for AnyResource fields
90+
kwargs = {
91+
"mode": info.mode,
92+
"include": info.include,
93+
"exclude": info.exclude,
94+
"context": info.context,
95+
"by_alias": info.by_alias,
96+
"exclude_unset": info.exclude_unset,
97+
"exclude_defaults": info.exclude_defaults,
98+
"exclude_none": info.exclude_none,
99+
"round_trip": info.round_trip,
100+
"serialize_as_any": info.serialize_as_any,
101+
}
102+
if isinstance(value, AnyResource):
103+
return value.model_dump(**kwargs) # type: ignore
104+
if isinstance(value, BaseModel_):
105+
return value.model_dump(**kwargs) # type: ignore
106+
return value
107+
108+
109+
def _validate(value: Any_, index: int | None = None):
110+
# Custom validator for AnyResource fields
111+
if isinstance(value, AnyResource):
112+
try:
113+
klass = globals()[value.resourceType]
114+
except KeyError as exc:
115+
raise ValidationError.from_exception_data(
116+
"ImportError",
117+
[
118+
{
119+
"loc": (index, "resourceType")
120+
if index is not None
121+
else ("resourceType",),
122+
"type": "value_error",
123+
"input": [value],
124+
"ctx": {"error": f"{value.resourceType} resource is not found"},
125+
}
126+
],
127+
) from exc
128+
129+
if not issubclass(klass, BaseModel) or "resourceType" not in klass.model_fields:
130+
raise ValidationError.from_exception_data(
131+
"ImportError",
132+
[
133+
{
134+
"loc": (index, "resourceType")
135+
if index is not None
136+
else ("resourceType",),
137+
"type": "value_error",
138+
"input": [value],
139+
"ctx": {"error": f"{value.resourceType} is not a resource"},
140+
}
141+
],
142+
)
143+
144+
try:
145+
return klass(**value.model_dump())
146+
except ValidationError as exc:
147+
raise ValidationError.from_exception_data(
148+
exc.title,
149+
[
150+
{
151+
"loc": (index, *error["loc"])
152+
if index is not None
153+
else error["loc"],
154+
"type": error["type"],
155+
"input": error["input"],
156+
"ctx": error["ctx"]
157+
}
158+
for error in exc.errors()
159+
],
160+
) from exc
161+
162+
return value

0 commit comments

Comments
 (0)