Skip to content

SpecifierSet.filter should allow pre-release when no final version matches specifiers, same as Specifier.filter #872

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
34 changes: 29 additions & 5 deletions src/packaging/specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -962,7 +962,7 @@ def filter(
>>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.3", Version("1.4")]))
['1.3', <Version('1.4')>]
>>> list(SpecifierSet(">=1.2.3").filter(["1.2", "1.5a1"]))
[]
["1.5a1"]
>>> list(SpecifierSet(">=1.2.3").filter(["1.3", "1.5a1"], prereleases=True))
['1.3', '1.5a1']
>>> list(SpecifierSet(">=1.2.3", prereleases=True).filter(["1.3", "1.5a1"]))
Expand All @@ -980,6 +980,11 @@ def filter(
>>> list(SpecifierSet("").filter(["1.3", "1.5a1"], prereleases=True))
['1.3', '1.5a1']
"""
# Allow a fallback to prereleases=True under the following conditions:
# - prereleases was not passed in this call
# - prereleases was not passed in the constructor
prereleases_fallback = prereleases is None and self._prereleases is None

# Determine if we're forcing a prerelease or not, if we're not forcing
# one for this particular filter call, then we'll use whatever the
# SpecifierSet thinks for whether or not we should support prereleases.
Expand All @@ -990,9 +995,28 @@ def filter(
# filter method for each one, this will act as a logical AND amongst
# each specifier.
if self._specs:
current_iter = iterable
for spec in self._specs:
iterable = spec.filter(iterable, prereleases=bool(prereleases))
return iter(iterable)
current_iter = spec.filter(current_iter, prereleases=bool(prereleases))

# If prereleases is True there is no need for fallback logic.
if not prereleases_fallback or prereleases is True:
yield from current_iter
else:
# If prereleases was not explicitly set, we need to do a similar
# check to Specifier.filter to see if any final releases are yielded.
yielded = False
for version in current_iter:
yield version
yielded = True

# Fall back to prereleases if no final releases were found.
if not yielded:
fallback_iter = iterable
for spec in self._specs:
fallback_iter = spec.filter(fallback_iter, prereleases=True)
yield from fallback_iter

# If we do not have any specifiers, then we need to have a rough filter
# which will filter out any pre-releases, unless there are no final
# releases.
Expand All @@ -1014,6 +1038,6 @@ def filter(
# If we've found no items except for pre-releases, then we'll go
# ahead and use the pre-releases
if not filtered and found_prereleases and prereleases is None:
return iter(found_prereleases)
yield from found_prereleases

return iter(filtered)
yield from filtered
69 changes: 57 additions & 12 deletions tests/test_specifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -559,20 +559,37 @@ def test_specifiers_prereleases(self, specifier, version, expected):
assert version in spec

@pytest.mark.parametrize(
("specifier", "prereleases", "input", "expected"),
("specifier", "specifier_prereleases", "prereleases", "input", "expected"),
[
(">=1.0", None, ["2.0a1"], ["2.0a1"]),
(">=1.0.dev1", None, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0.dev1", False, ["1.0", "2.0a1"], ["1.0"]),
("!=2.0a1", None, ["1.0a2", "1.0", "2.0a1"], ["1.0"]),
("==2.0a1", None, ["2.0a1"], ["2.0a1"]),
(">2.0a1", None, ["2.0a1", "3.0a2", "3.0"], ["3.0a2", "3.0"]),
("<2.0a1", None, ["1.0a2", "1.0", "2.0a1"], ["1.0a2", "1.0"]),
("~=2.0a1", None, ["1.0", "2.0a1", "3.0a2", "3.0"], ["2.0a1"]),
# General test of the filter method
(">=1.0.dev1", None, None, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.2.3", None, None, ["1.2", "1.5a1"], ["1.5a1"]),
(">=1.2.3", None, None, ["1.3", "1.5a1"], ["1.3"]),
(">=1.0", None, None, ["2.0a1"], ["2.0a1"]),
("!=2.0a1", None, None, ["1.0a2", "1.0", "2.0a1"], ["1.0"]),
("==2.0a1", None, None, ["2.0a1"], ["2.0a1"]),
(">2.0a1", None, None, ["2.0a1", "3.0a2", "3.0"], ["3.0a2", "3.0"]),
("<2.0a1", None, None, ["1.0a2", "1.0", "2.0a1"], ["1.0a2", "1.0"]),
("~=2.0a1", None, None, ["1.0", "2.0a1", "3.0a2", "3.0"], ["2.0a1"]),
# Test overriding with the prereleases parameter on filter
(">=1.0.dev1", None, False, ["1.0", "2.0a1"], ["1.0"]),
# Test overriding with the overall specifier
(">=1.0.dev1", True, None, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0.dev1", False, None, ["1.0", "2.0a1"], ["1.0"]),
# Test when both specifier and filter have prerelease value
(">=1.0", True, False, ["1.0", "2.0a1"], ["1.0"]),
(">=1.0", False, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", True, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", False, False, ["1.0", "2.0a1"], ["1.0"]),
],
)
def test_specifier_filter(self, specifier, prereleases, input, expected):
spec = Specifier(specifier)
def test_specifier_filter(
self, specifier, specifier_prereleases, prereleases, input, expected
):
if specifier_prereleases is None:
spec = Specifier(specifier)
else:
spec = Specifier(specifier, prereleases=specifier_prereleases)

kwargs = {"prereleases": prereleases} if prereleases is not None else {}

Expand Down Expand Up @@ -717,7 +734,15 @@ def test_specifier_contains_installed_prereleases(self):
("", None, None, ["1.0", "2.0a1"], ["1.0"]),
(">=1.0.dev1", None, None, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
("", None, None, ["1.0a1"], ["1.0a1"]),
(">=1.2.3", None, None, ["1.2", "1.5a1"], ["1.5a1"]),
(">=1.2.3", None, None, ["1.3", "1.5a1"], ["1.3"]),
("", None, None, ["1.0", Version("2.0")], ["1.0", Version("2.0")]),
(">=1.0", None, None, ["2.0a1"], ["2.0a1"]),
("!=2.0a1", None, None, ["1.0a2", "1.0", "2.0a1"], ["1.0"]),
("==2.0a1", None, None, ["2.0a1"], ["2.0a1"]),
(">2.0a1", None, None, ["2.0a1", "3.0a2", "3.0"], ["3.0a2", "3.0"]),
("<2.0a1", None, None, ["1.0a2", "1.0", "2.0a1"], ["1.0a2", "1.0"]),
("~=2.0a1", None, None, ["1.0", "2.0a1", "3.0a2", "3.0"], ["2.0a1"]),
# Test overriding with the prereleases parameter on filter
("", None, False, ["1.0a1"], []),
(">=1.0.dev1", None, False, ["1.0", "2.0a1"], ["1.0"]),
Expand All @@ -729,10 +754,30 @@ def test_specifier_contains_installed_prereleases(self):
(">=1.0.dev1", False, None, ["1.0", "2.0a1"], ["1.0"]),
("", True, None, ["1.0a1"], ["1.0a1"]),
("", False, None, ["1.0a1"], []),
# Test when both specifier and filter have prerelease value
(">=1.0", True, False, ["1.0", "2.0a1"], ["1.0"]),
(">=1.0", False, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", True, True, ["1.0", "2.0a1"], ["1.0", "2.0a1"]),
(">=1.0", False, False, ["1.0", "2.0a1"], ["1.0"]),
# Test when there are multiple specifiers
(">=1.0,<=2.0", None, None, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", None, None, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0", True, None, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0", False, None, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", False, None, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", True, None, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0", None, False, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0", None, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0dev", None, False, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", None, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0", True, False, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
(">=1.0,<=2.0dev", True, False, ["1.0", "1.5a1"], ["1.0"]),
(">=1.0,<=2.0dev", False, True, ["1.0", "1.5a1"], ["1.0", "1.5a1"]),
],
)
def test_specifier_filter(
self, specifier_prereleases, specifier, prereleases, input, expected
self, specifier, specifier_prereleases, prereleases, input, expected
):
if specifier_prereleases is None:
spec = SpecifierSet(specifier)
Expand Down
Loading