Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ Changelog

Features:

- Typing: Improve typings in `marshmallow.fields` (:pr:`2723`).
- Typing: Replace type comments with inline typings (:pr:`2718`).

Bug fixes:
Expand Down
4 changes: 0 additions & 4 deletions src/marshmallow/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,6 @@
class FieldABC(ABC):
"""Abstract base class from which all Field classes inherit."""

parent = None
name = None
root = None

@abstractmethod
def serialize(self, attr, obj, accessor=None):
pass
Expand Down
102 changes: 63 additions & 39 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,6 @@
"Pluck",
]

_T = typing.TypeVar("_T")


class Field(FieldABC):
"""Basic field from which other fields should extend. It applies no
Expand Down Expand Up @@ -132,7 +130,7 @@ class Field(FieldABC):
#: Default error messages for various kinds of errors. The keys in this dictionary
#: are passed to `Field.make_error`. The values are error messages passed to
#: :exc:`marshmallow.exceptions.ValidationError`.
default_error_messages = {
default_error_messages: dict[str, str] = {
"required": "Missing data for required field.",
"null": "Field may not be null.",
"validator_failed": "Invalid value.",
Expand Down Expand Up @@ -224,6 +222,10 @@ def __init__(
messages.update(error_messages or {})
self.error_messages = messages

self.parent: Field | Schema | None = None
self.name: str | None = None
self.root: Schema | None = None

def __repr__(self) -> str:
return (
f"<fields.{self.__class__.__name__}(dump_default={self.dump_default!r}, "
Expand All @@ -237,7 +239,15 @@ def __repr__(self) -> str:
def __deepcopy__(self, memo):
return copy.copy(self)

def get_value(self, obj, attr, accessor=None, default=missing_):
def get_value(
self,
obj: typing.Any,
attr: str,
accessor: (
typing.Callable[[typing.Any, str, typing.Any], typing.Any] | None
) = None,
default: typing.Any = missing_,
):
"""Return the value for a given key from an object.

:param object obj: The object to get the value from.
Expand All @@ -249,14 +259,14 @@ def get_value(self, obj, attr, accessor=None, default=missing_):
check_key = attr if self.attribute is None else self.attribute
return accessor_func(obj, check_key, default)

def _validate(self, value):
def _validate(self, value: typing.Any):
"""Perform validation on ``value``. Raise a :exc:`ValidationError` if validation
does not succeed.
"""
self._validate_all(value)

@property
def _validate_all(self):
def _validate_all(self) -> typing.Callable[[typing.Any], None]:
return And(*self.validators, error=self.error_messages["validator_failed"])

def make_error(self, key: str, **kwargs) -> ValidationError:
Expand Down Expand Up @@ -290,7 +300,7 @@ def fail(self, key: str, **kwargs):
)
raise self.make_error(key=key, **kwargs)

def _validate_missing(self, value):
def _validate_missing(self, value: typing.Any) -> None:
"""Validate missing values. Raise a :exc:`ValidationError` if
`value` should be considered missing.
"""
Expand Down Expand Up @@ -357,7 +367,7 @@ def deserialize(

# Methods for concrete classes to override.

def _bind_to_schema(self, field_name, schema):
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
"""Update field with values from its parent schema. Called by
:meth:`Schema._bind_field <marshmallow.Schema._bind_field>`.

Expand All @@ -372,7 +382,7 @@ def _bind_to_schema(self, field_name, schema):

def _serialize(
self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs
):
) -> typing.Any:
"""Serializes ``value`` to a basic Python datatype. Noop by default.
Concrete :class:`Field` classes should implement this method.

Expand All @@ -398,7 +408,7 @@ def _deserialize(
attr: str | None,
data: typing.Mapping[str, typing.Any] | None,
**kwargs,
):
) -> typing.Any:
"""Deserialize value. Concrete :class:`Field` classes should implement this method.

:param value: The value to be deserialized.
Expand All @@ -416,9 +426,11 @@ def _deserialize(
# Properties

@property
def context(self):
def context(self) -> dict | None:
"""The context dictionary for the parent :class:`Schema`."""
return self.parent.context
if self.parent:
return self.parent.context
return None

# the default and missing properties are provided for compatibility and
# emit warnings when they are accessed and set
Expand Down Expand Up @@ -630,12 +642,14 @@ def _serialize(self, nested_obj, attr, obj, **kwargs):
many = schema.many or self.many
return schema.dump(nested_obj, many=many)

def _test_collection(self, value):
def _test_collection(self, value: typing.Any) -> None:
many = self.schema.many or self.many
if many and not utils.is_collection(value):
raise self.make_error("type", input=value, type=value.__class__.__name__)

def _load(self, value, data, partial=None):
def _load(
self, value: typing.Any, partial: bool | types.StrSequenceOrSet | None = None
):
try:
valid_data = self.schema.load(value, unknown=self.unknown, partial=partial)
except ValidationError as error:
Expand All @@ -644,7 +658,14 @@ def _load(self, value, data, partial=None):
) from error
return valid_data

def _deserialize(self, value, attr, data, partial=None, **kwargs):
def _deserialize(
self,
value: typing.Any,
attr: str | None,
data: typing.Mapping[str, typing.Any] | None = None,
partial: bool | types.StrSequenceOrSet | None = None,
**kwargs,
) -> typing.Any:
"""Same as :meth:`Field._deserialize` with additional ``partial`` argument.

:param bool|tuple partial: For nested schemas, the ``partial``
Expand All @@ -654,7 +675,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
Add ``partial`` parameter.
"""
self._test_collection(value)
return self._load(value, data, partial=partial)
return self._load(value, partial=partial)


class Pluck(Nested):
Expand Down Expand Up @@ -694,7 +715,7 @@ def __init__(
self.field_name = field_name

@property
def _field_data_key(self):
def _field_data_key(self) -> str:
only_field = self.schema.fields[self.field_name]
return only_field.data_key or self.field_name

Expand All @@ -712,7 +733,7 @@ def _deserialize(self, value, attr, data, partial=None, **kwargs):
value = [{self._field_data_key: v} for v in value]
else:
value = {self._field_data_key: value}
return self._load(value, data, partial=partial)
return self._load(value, partial=partial)


class List(Field):
Expand Down Expand Up @@ -746,7 +767,7 @@ def __init__(self, cls_or_instance: Field | type[Field], **kwargs):
self.only = self.inner.only
self.exclude = self.inner.exclude

def _bind_to_schema(self, field_name, schema):
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
super()._bind_to_schema(field_name, schema)
self.inner = copy.deepcopy(self.inner)
self.inner._bind_to_schema(field_name, self)
Expand Down Expand Up @@ -790,7 +811,7 @@ class Tuple(Field):
`typing.NamedTuple`, using a Schema within a Nested field for them is
more appropriate than using a `Tuple` field.

:param Iterable[Field] tuple_fields: An iterable of field classes or
:param tuple_fields: An iterable of field classes or
instances.
:param kwargs: The same keyword arguments that :class:`Field` receives.

Expand All @@ -800,7 +821,7 @@ class Tuple(Field):
#: Default error messages.
default_error_messages = {"invalid": "Not a valid tuple."}

def __init__(self, tuple_fields, *args, **kwargs):
def __init__(self, tuple_fields: typing.Iterable[Field], *args, **kwargs):
super().__init__(*args, **kwargs)
if not utils.is_collection(tuple_fields):
raise ValueError(
Expand All @@ -820,7 +841,7 @@ def __init__(self, tuple_fields, *args, **kwargs):

self.validate_length = Length(equal=len(self.tuple_fields))

def _bind_to_schema(self, field_name, schema):
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
super()._bind_to_schema(field_name, schema)
new_tuple_fields = []
for field in self.tuple_fields:
Expand Down Expand Up @@ -910,7 +931,10 @@ def _deserialize(self, value, attr, data, **kwargs) -> uuid.UUID | None:
return self._validated(value)


class Number(Field):
_NumType = typing.TypeVar("_NumType")


class Number(Field, typing.Generic[_NumType]):
"""Base class for number fields.

:param bool as_string: If `True`, format the serialized value as a string.
Expand All @@ -929,14 +953,12 @@ def __init__(self, *, as_string: bool = False, **kwargs):
self.as_string = as_string
super().__init__(**kwargs)

def _format_num(self, value) -> typing.Any:
def _format_num(self, value) -> _NumType:
"""Return the number value for value, given this field's `num_type`."""
return self.num_type(value)

def _validated(self, value) -> _T | None:
def _validated(self, value: typing.Any) -> _NumType:
"""Format the value or raise a :exc:`ValidationError` if an error occurs."""
if value is None:
return None
# (value is True or value is False) is ~5x faster than isinstance(value, bool)
if value is True or value is False:
raise self.make_error("invalid", input=value)
Expand All @@ -947,21 +969,21 @@ def _validated(self, value) -> _T | None:
except OverflowError as error:
raise self.make_error("too_large", input=value) from error

def _to_string(self, value) -> str:
def _to_string(self, value: _NumType) -> str:
return str(value)

def _serialize(self, value, attr, obj, **kwargs) -> str | _T | None:
def _serialize(self, value, attr, obj, **kwargs) -> str | _NumType | None:
"""Return a string if `self.as_string=True`, otherwise return this field's `num_type`."""
if value is None:
return None
ret: _T = self._format_num(value)
ret: _NumType = self._format_num(value)
return self._to_string(ret) if self.as_string else ret

def _deserialize(self, value, attr, data, **kwargs) -> _T | None:
def _deserialize(self, value, attr, data, **kwargs) -> _NumType | None:
return self._validated(value)


class Integer(Number):
class Integer(Number[int]):
"""An integer field.

:param strict: If `True`, only integer types are valid.
Expand All @@ -979,13 +1001,13 @@ def __init__(self, *, strict: bool = False, **kwargs):
super().__init__(**kwargs)

# override Number
def _validated(self, value):
def _validated(self, value: typing.Any) -> int:
if self.strict and not isinstance(value, numbers.Integral):
raise self.make_error("invalid", input=value)
return super()._validated(value)


class Float(Number):
class Float(Number[float]):
"""A double as an IEEE-754 double precision string.

:param bool allow_nan: If `True`, `NaN`, `Infinity` and `-Infinity` are allowed,
Expand All @@ -1005,15 +1027,15 @@ def __init__(self, *, allow_nan: bool = False, as_string: bool = False, **kwargs
self.allow_nan = allow_nan
super().__init__(as_string=as_string, **kwargs)

def _validated(self, value):
def _validated(self, value: typing.Any) -> float:
num = super()._validated(value)
if self.allow_nan is False:
if math.isnan(num) or num == float("inf") or num == float("-inf"):
raise self.make_error("special")
return num


class Decimal(Number):
class Decimal(Number[decimal.Decimal]):
"""A field that (de)serializes to the Python ``decimal.Decimal`` type.
It's safe to use when dealing with money values, percentages, ratios
or other numbers where precision is critical.
Expand Down Expand Up @@ -1084,7 +1106,7 @@ def _format_num(self, value):
return num

# override Number
def _validated(self, value):
def _validated(self, value: typing.Any) -> decimal.Decimal:
try:
num = super()._validated(value)
except decimal.InvalidOperation as error:
Expand All @@ -1094,7 +1116,7 @@ def _validated(self, value):
return num

# override Number
def _to_string(self, value):
def _to_string(self, value: decimal.Decimal) -> str:
return format(value, "f")


Expand Down Expand Up @@ -1168,7 +1190,9 @@ def __init__(
if falsy is not None:
self.falsy = set(falsy)

def _serialize(self, value, attr, obj, **kwargs):
def _serialize(
self, value: typing.Any, attr: str | None, obj: typing.Any, **kwargs
):
if value is None:
return None

Expand Down
Loading