Skip to content

Commit 6357e5b

Browse files
committed
Fix type annotations for xml_field_validator/serializer
There are 2 issues with the annotations: 1. Function types are contravariant rather than covariant in their argument types. This means that `Callable[[BaseXmlModel], Any]` is a subtype rather than a supertype of `Callable[[MyModel], Any]` and so using `Callable[[BaseXmlModel], Any]` as a lower bound of TypeVars doesn't work. Since Python doesn't support upper bounds on TypeVars, the solution is to push the TypeVar inside. To see the issue run mypy on the example from the documentation[1]. Currently, you get these errors: ``` t.py:13: error: Value of type variable "ValidatorFuncT" of function cannot be "Callable[[type[Plot], XmlElementReader, str], list[float]]" [type-var] t.py:21: error: Value of type variable "SerializerFuncT" of function cannot be "Callable[[Plot, XmlElementWriter, list[float], str], None]" [type-var] ``` This PR resolves these (however, validate_space_separated_list needs to be decorated with @classmethod). 2. A recent commit[2] has a typo where ValidatorFuncT was bound by SerializerFunc instead of ValidatorFunc. [1] https://pydantic-xml.readthedocs.io/en/latest/pages/misc.html#custom-xml-serialization [2] ec4b547
1 parent e5f21f5 commit 6357e5b

File tree

4 files changed

+30
-17
lines changed

4 files changed

+30
-17
lines changed

examples/xml-serialization-decorator/model.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@ class Plot(BaseXmlModel):
1111
y: List[float] = element()
1212

1313
@xml_field_validator('x', 'y')
14+
@classmethod
1415
def validate_space_separated_list(cls, element: XmlElementReader, field_name: str) -> List[float]:
15-
if element := element.pop_element(field_name, search_mode=cls.__xml_search_mode__):
16-
return list(map(float, element.pop_text().split()))
16+
if (sub_element := element.pop_element(field_name, search_mode=cls.__xml_search_mode__)) and (
17+
text := sub_element.pop_text()
18+
):
19+
return list(map(float, text.split()))
1720

1821
return []
1922

pydantic_xml/fields.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import dataclasses as dc
22
import typing
3-
from typing import Any, Callable, Optional, TypeVar, Union
3+
from typing import Any, Callable, Optional, Union
44

55
import pydantic as pd
66
import pydantic_core as pdc
@@ -292,36 +292,38 @@ def computed_element(
292292
return computed_entity(EntityLocation.ELEMENT, prop, path=tag, ns=ns, nsmap=nsmap, nillable=nillable, **kwargs)
293293

294294

295-
ValidatorFuncT = TypeVar('ValidatorFuncT', bound='model.SerializerFunc')
296-
297-
298-
def xml_field_validator(field: str, /, *fields: str) -> Callable[[ValidatorFuncT], ValidatorFuncT]:
295+
def xml_field_validator(
296+
field: str, /, *fields: str
297+
) -> 'Callable[[model.ValidatorFuncT[model.ModelT]], model.ValidatorFuncT[model.ModelT]]':
299298
"""
300299
Marks the method as a field xml validator.
301300
302301
:param field: field to be validated
303302
:param fields: fields to be validated
304303
"""
305304

306-
def wrapper(func: ValidatorFuncT) -> ValidatorFuncT:
305+
def wrapper(func: model.ValidatorFuncT[model.ModelT]) -> model.ValidatorFuncT[model.ModelT]:
306+
if isinstance(func, (classmethod, staticmethod)):
307+
func = func.__func__
307308
setattr(func, '__xml_field_validator__', (field, *fields))
308309
return func
309310

310311
return wrapper
311312

312313

313-
SerializerFuncT = TypeVar('SerializerFuncT', bound='model.SerializerFunc')
314-
315-
316-
def xml_field_serializer(field: str, /, *fields: str) -> Callable[[SerializerFuncT], SerializerFuncT]:
314+
def xml_field_serializer(
315+
field: str, /, *fields: str
316+
) -> 'Callable[[model.SerializerFuncT[model.ModelT]], model.SerializerFuncT[model.ModelT]]':
317317
"""
318318
Marks the method as a field xml serializer.
319319
320320
:param field: field to be serialized
321321
:param fields: fields to be serialized
322322
"""
323323

324-
def wrapper(func: SerializerFuncT) -> SerializerFuncT:
324+
def wrapper(func: model.SerializerFuncT[model.ModelT]) -> model.SerializerFuncT[model.ModelT]:
325+
if isinstance(func, (classmethod, staticmethod)):
326+
func = func.__func__
325327
setattr(func, '__xml_field_serializer__', (field, *fields))
326328
return func
327329

pydantic_xml/model.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,12 @@
2020
__all__ = (
2121
'BaseXmlModel',
2222
'create_model',
23+
'ModelT',
2324
'RootXmlModel',
2425
'SerializerFunc',
26+
'SerializerFuncT',
2527
'ValidatorFunc',
28+
'ValidatorFuncT',
2629
'XmlModelMeta',
2730
)
2831

@@ -137,9 +140,11 @@ def _collect_xml_field_serializers_validators(mcls, cls: Type['BaseXmlModel']) -
137140
cls.__xml_field_validators__[field] = func
138141

139142

140-
ValidatorFunc = Callable[[Type['BaseXmlModel'], XmlElementReader, str], Any]
141-
SerializerFunc = Callable[['BaseXmlModel', XmlElementWriter, Any, str], Any]
142143
ModelT = TypeVar('ModelT', bound='BaseXmlModel')
144+
ValidatorFuncT = Callable[[Type[ModelT], XmlElementReader, str], Any]
145+
ValidatorFunc = ValidatorFuncT['BaseXmlModel']
146+
SerializerFuncT = Callable[[ModelT, XmlElementWriter, Any, str], Any]
147+
SerializerFunc = SerializerFuncT['BaseXmlModel']
143148

144149

145150
class BaseXmlModel(BaseModel, __xml_abstract__=True, metaclass=XmlModelMeta):

tests/test_preprocessors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,12 @@ class TestModel(BaseXmlModel, tag='model1'):
1313
element1: List[int] = element()
1414

1515
@xml_field_validator('element1')
16+
@classmethod
1617
def validate_element(cls, element: XmlElementReader, field_name: str) -> List[int]:
17-
if element := element.pop_element(field_name, search_mode=cls.__xml_search_mode__):
18-
return list(map(int, element.pop_text().split()))
18+
if (sub_element := element.pop_element(field_name, search_mode=cls.__xml_search_mode__)) and (
19+
text := sub_element.pop_text()
20+
):
21+
return list(map(float, text.split()))
1922

2023
return []
2124

0 commit comments

Comments
 (0)