Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

load accepts Sequence rather than Iterable (rejects generators) #2795

Open
wants to merge 4 commits into
base: dev
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Use TypeGuards and fix type issues
  • Loading branch information
sloria committed Jan 20, 2025
commit 592ac20407df356a94a8d9d9b6de9f08967534b2
11 changes: 5 additions & 6 deletions src/marshmallow/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@
ValidationError,
_FieldInstanceResolutionError,
)
from marshmallow.utils import is_aware, is_collection
from marshmallow.validate import And, Length

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -501,9 +500,9 @@ def __init__(
**kwargs: Unpack[_BaseFieldKwargs],
):
# Raise error if only or exclude is passed as string, not list of strings
if only is not None and not is_collection(only):
if only is not None and not utils.is_sequence_but_not_string(only):
raise StringNotCollectionError('"only" should be a collection of strings.')
if not is_collection(exclude):
if not utils.is_sequence_but_not_string(exclude):
raise StringNotCollectionError(
'"exclude" should be a collection of strings.'
)
Expand Down Expand Up @@ -818,7 +817,7 @@ def _deserialize(
data: typing.Mapping[str, typing.Any] | None,
**kwargs,
) -> tuple:
if not utils.is_collection(value):
if not utils.is_sequence_but_not_string(value):
raise self.make_error("invalid")

self.validate_length(value)
Expand Down Expand Up @@ -1322,7 +1321,7 @@ def __init__(

def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
ret = super()._deserialize(value, attr, data, **kwargs)
if is_aware(ret):
if utils.is_aware(ret):
if self.timezone is None:
raise self.make_error(
"invalid_awareness",
Expand Down Expand Up @@ -1359,7 +1358,7 @@ def __init__(

def _deserialize(self, value, attr, data, **kwargs) -> dt.datetime:
ret = super()._deserialize(value, attr, data, **kwargs)
if not is_aware(ret):
if not utils.is_aware(ret):
if self.default_timezone is None:
raise self.make_error(
"invalid_awareness",
Expand Down
8 changes: 3 additions & 5 deletions src/marshmallow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@

from marshmallow.constants import missing

# TODO: Use TypeGuard for predicate functions


def is_generator(obj) -> bool:
def is_generator(obj) -> TypeGuard[typing.Generator]:
"""Return True if ``obj`` is a generator"""
return inspect.isgeneratorfunction(obj) or inspect.isgenerator(obj)


def is_iterable_but_not_string(obj) -> bool:
def is_iterable_but_not_string(obj) -> TypeGuard[typing.Iterable]:
"""Return True if ``obj`` is an iterable object that isn't a string."""
return (hasattr(obj, "__iter__") and not hasattr(obj, "strip")) or is_generator(obj)

Expand All @@ -34,7 +32,7 @@ def is_sequence_but_not_string(obj) -> TypeGuard[Sequence]:
return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))


def is_collection(obj) -> bool:
def is_collection(obj) -> TypeGuard[typing.Iterable]:
"""Return True if ``obj`` is a collection type, e.g list, tuple, queryset."""
return is_iterable_but_not_string(obj) and not isinstance(obj, Mapping)

Expand Down
Loading