Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 20 additions & 10 deletions src/packaging/specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@
from typing import Callable, Final, Iterable, Iterator, TypeVar, Union

from .utils import canonicalize_version
from .version import Version
from .version import InvalidVersion, Version

UnparsedVersion = Union[Version, str]
UnparsedVersionVar = TypeVar("UnparsedVersionVar", bound=UnparsedVersion)
CallableOperator = Callable[[Version, str], bool]


def _coerce_version(version: UnparsedVersion) -> Version:
def _coerce_version(version: UnparsedVersion) -> Version | None:
if not isinstance(version, Version):
version = Version(version)
try:
version = Version(version)
except InvalidVersion:
return None
return version


Expand Down Expand Up @@ -581,6 +584,8 @@ def filter(
# Filter versions
for version in iterable:
parsed_version = _coerce_version(version)
if parsed_version is None:
continue

if operator_callable(parsed_version, self.version):
# If it's not a prerelease or prereleases are allowed, yield it directly
Expand Down Expand Up @@ -894,14 +899,14 @@ def contains(
>>> SpecifierSet(">=1.0.0,!=1.0.1").contains("1.3.0a1", prereleases=True)
True
"""
# Ensure that our item is a Version instance.
if not isinstance(item, Version):
item = Version(item)
version = _coerce_version(item)
if version is None:
return False

if installed and item.is_prerelease:
if installed and version.is_prerelease:
prereleases = True

return bool(list(self.filter([item], prereleases=prereleases)))
return bool(list(self.filter([version], prereleases=prereleases)))

def filter(
self, iterable: Iterable[UnparsedVersionVar], prereleases: bool | None = None
Expand Down Expand Up @@ -959,7 +964,7 @@ def filter(

if prereleases is not None:
# If we have a forced prereleases value,
# we can immediately return he iterator.
# we can immediately return the iterator.
return iter(iterable)
else:
# Handle empty SpecifierSet cases where prereleases is not None.
Expand All @@ -968,7 +973,10 @@ def filter(

if prereleases is False:
return (
item for item in iterable if not _coerce_version(item).is_prerelease
item
for item in iterable
if (version := _coerce_version(item)) is not None
and not version.is_prerelease
)

# Finally if prereleases is None, apply PEP 440 logic:
Expand All @@ -978,6 +986,8 @@ def filter(

for item in iterable:
parsed_version = _coerce_version(item)
if parsed_version is None:
continue
if parsed_version.is_prerelease:
found_prereleases.append(item)
else:
Expand Down
31 changes: 31 additions & 0 deletions tests/test_specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,6 +494,17 @@ def test_specifiers(self, version, spec, expected):
assert Version(version) not in spec
assert not spec.contains(Version(version))

@pytest.mark.parametrize(
("spec", "version"),
[
("==1.0", "not a valid version"),
("===invalid", "invalid"),
],
)
def test_invalid_spec(self, spec, version):
spec = Specifier(spec, prereleases=True)
assert not spec.contains(version)

@pytest.mark.parametrize(
(
"specifier",
Expand Down Expand Up @@ -645,6 +656,9 @@ def test_specifiers_prereleases(
(">=1.0", False, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", True, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", False, False, ["1.0", "2.0a1"], ["1.0"]),
# Test that invalid versions are discarded
(">=1.0", None, None, ["not a valid version"], []),
(">=1.0", None, None, ["1.0", "not a valid version"], ["1.0"]),
],
)
def test_specifier_filter(
Expand Down Expand Up @@ -960,6 +974,13 @@ def test_specifier_contains_installed_prereleases(
(">=1.0,<=2.0", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0dev", True, False, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
# Test that invalid versions are discarded
("", None, None, ["invalid version"], []),
("", None, False, ["invalid version"], []),
("", False, None, ["invalid version"], []),
("", None, None, ["1.0", "invalid version"], ["1.0"]),
("", None, False, ["1.0", "invalid version"], ["1.0"]),
("", False, None, ["1.0", "invalid version"], ["1.0"]),
],
)
def test_specifier_filter(
Expand Down Expand Up @@ -1332,6 +1353,16 @@ def test_contains_exclusionary_bridges(
kwargs = {"prereleases": prereleases} if prereleases is not None else {}
assert spec.contains(version, **kwargs) == expected

@pytest.mark.parametrize(
("specifier", "input"),
[
(">=1.0", "not a valid version"),
],
)
def test_contains_rejects_invalid_specifier(self, specifier, input):
spec = SpecifierSet(specifier, prereleases=True)
assert not spec.contains(input)

@pytest.mark.parametrize(
("specifier", "expected"),
[
Expand Down
Loading