Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions hypothesis-python/RELEASE.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
RELEASE_TYPE: patch

Improve the thread-safety of strategy validation.

Before this release, Hypothesis did not require that ``super().__init__()`` be called in ``SearchStrategy`` subclasses. Subclassing ``SearchStrategy`` is not supported or part of the public API, but if you are subclassing it anyway, you will need to make sure to call ``super().__init__()`` after this release.
1 change: 1 addition & 0 deletions hypothesis-python/src/hypothesis/extra/_array_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,6 +632,7 @@ def __init__(
allow_newaxis,
allow_fewer_indices_than_dims,
):
super().__init__()
self.shape = shape
self.min_dims = min_dims
self.max_dims = max_dims
Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/src/hypothesis/extra/array_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,7 @@ class ArrayStrategy(st.SearchStrategy):
def __init__(
self, *, xp, api_version, elements_strategy, dtype, shape, fill, unique
):
super().__init__()
self.xp = xp
self.elements_strategy = elements_strategy
self.dtype = dtype
Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/src/hypothesis/extra/lark.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def __init__(
explicit: dict[str, st.SearchStrategy[str]],
alphabet: st.SearchStrategy[str],
) -> None:
super().__init__()
assert isinstance(grammar, lark.lark.Lark)
start: list[str] = grammar.options.start if start is None else [start]

Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/src/hypothesis/extra/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,7 @@ def compat_kw(*args, **kw):

class ArrayStrategy(st.SearchStrategy):
def __init__(self, element_strategy, shape, dtype, fill, unique):
super().__init__()
self.shape = tuple(shape)
self.fill = fill
self.array_size = int(np.prod(shape))
Expand Down
2 changes: 2 additions & 0 deletions hypothesis-python/src/hypothesis/stateful.py
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,7 @@ def __hash__(self):

class BundleReferenceStrategy(SearchStrategy):
def __init__(self, name: str, *, consume: bool = False):
super().__init__()
self.name = name
self.consume = consume

Expand Down Expand Up @@ -582,6 +583,7 @@ class MyStateMachine(RuleBasedStateMachine):
def __init__(
self, name: str, *, consume: bool = False, draw_references: bool = True
) -> None:
super().__init__()
self.name = name
self.__reference_strategy = BundleReferenceStrategy(name, consume=consume)
self.draw_references = draw_references
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -352,6 +352,7 @@ def __init__(
*,
optional: Optional[dict[Any, SearchStrategy[Any]]],
):
super().__init__()
dict_type = type(mapping)
self.mapping = mapping
keys = tuple(mapping.keys())
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1029,6 +1029,7 @@ def __init__(
args: tuple[SearchStrategy[Any], ...],
kwargs: dict[str, SearchStrategy[Any]],
):
super().__init__()
self.target = target
self.args = args
self.kwargs = kwargs
Expand Down Expand Up @@ -1070,7 +1071,7 @@ def do_draw(self, data: ConjectureData) -> Ex:
current_build_context().record_call(obj, self.target, args, kwargs)
return obj

def validate(self) -> None:
def do_validate(self) -> None:
tuples(*self.args).validate()
fixed_dictionaries(self.kwargs).validate()

Expand Down Expand Up @@ -1798,6 +1799,7 @@ def recursive(

class PermutationStrategy(SearchStrategy):
def __init__(self, values):
super().__init__()
self.values = values

def do_draw(self, data):
Expand Down Expand Up @@ -1828,6 +1830,7 @@ def permutations(values: Sequence[T]) -> SearchStrategy[list[T]]:

class CompositeStrategy(SearchStrategy):
def __init__(self, definition, args, kwargs):
super().__init__()
self.definition = definition
self.args = args
self.kwargs = kwargs
Expand Down Expand Up @@ -2173,6 +2176,7 @@ def uuids(

class RunnerStrategy(SearchStrategy):
def __init__(self, default):
super().__init__()
self.default = default

def do_draw(self, data):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ def draw_capped_multipart(

class DatetimeStrategy(SearchStrategy):
def __init__(self, min_value, max_value, timezones_strat, allow_imaginary):
super().__init__()
assert isinstance(min_value, dt.datetime)
assert isinstance(max_value, dt.datetime)
assert min_value.tzinfo is None
Expand Down Expand Up @@ -219,6 +220,7 @@ def datetimes(

class TimeStrategy(SearchStrategy):
def __init__(self, min_value, max_value, timezones_strat):
super().__init__()
self.min_value = min_value
self.max_value = max_value
self.tz_strat = timezones_strat
Expand Down Expand Up @@ -257,6 +259,7 @@ def times(

class DateStrategy(SearchStrategy):
def __init__(self, min_value, max_value):
super().__init__()
assert isinstance(min_value, dt.date)
assert isinstance(max_value, dt.date)
assert min_value < max_value
Expand Down Expand Up @@ -320,6 +323,7 @@ def dates(

class TimedeltaStrategy(SearchStrategy):
def __init__(self, min_value, max_value):
super().__init__()
assert isinstance(min_value, dt.timedelta)
assert isinstance(max_value, dt.timedelta)
assert min_value < max_value
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@

class IntegersStrategy(SearchStrategy[int]):
def __init__(self, start: Optional[int], end: Optional[int]) -> None:
super().__init__()
assert isinstance(start, int) or start is None
assert isinstance(end, int) or end is None
assert start is None or end is None or start <= end
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def setstate(self, state):

class RandomStrategy(SearchStrategy[HypothesisRandom]):
def __init__(self, *, note_method_calls: bool, use_true_random: bool) -> None:
super().__init__()
self.__note_method_calls = note_method_calls
self.__use_true_random = use_true_random

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,7 @@ def capped(self, max_templates):

class RecursiveStrategy(SearchStrategy):
def __init__(self, base, extend, max_leaves):
super().__init__()
self.max_leaves = max_leaves
self.base = base
self.limited_base = LimitedStrategy(base)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@

class SharedStrategy(SearchStrategy[Ex]):
def __init__(self, base: SearchStrategy[Ex], key: Optional[Hashable] = None):
super().__init__()
self.key = key
self.base = base

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# obtain one at https://mozilla.org/MPL/2.0/.

import sys
import threading
import warnings
from collections import abc, defaultdict
from collections.abc import Sequence
Expand Down Expand Up @@ -228,9 +229,15 @@ class SearchStrategy(Generic[Ex]):
``builds(Foo, ...)``. Do not inherit from or directly instantiate this class.
"""

validate_called: bool = False
__label: Union[int, UniqueIdentifier, None] = None
__module__: str = "hypothesis.strategies"
LABELS: ClassVar[dict[type, int]] = {}
# triggers `assert isinstance(label, int)` under threading when setting this
# in init instead of a classvar. I'm not sure why, init should be safe. But
# this works so I'm not looking into it further atm.
__label: Union[int, UniqueIdentifier, None] = None

def __init__(self):
self.validate_called: dict[int, bool] = {}

def _available(self, data: ConjectureData) -> bool:
"""Returns whether this strategy can *currently* draw any
Expand Down Expand Up @@ -477,21 +484,39 @@ def __bool__(self) -> bool:
def validate(self) -> None:
"""Throw an exception if the strategy is not valid.

This can happen due to lazy construction
Strategies should implement ``do_validate``, which is called by this
method. They should not override ``validate``.

This can happen due to invalid arguments, or lazy construction.
"""
if self.validate_called:
thread_id = threading.get_ident()
if self.validate_called.get(thread_id, False):
return
# we need to set validate_called before calling do_validate, for
# recursive / deferred strategies. But if a thread switches after
# validate_called but before do_validate, we might have a strategy
# which does weird things like drawing when do_validate would error but
# its params are technically valid (e.g. a param was passed as 1.0
# instead of 1) and get into weird internal states.
#
# There are two ways to fix this.
# (1) The first is a per-strategy lock around do_validate. Even though we
# expect near-zero lock contention, this still adds the lock overhead.
# (2) The second is allowing concurrent .validate calls. Since validation
# is (assumed to be) deterministic, both threads will produce the same
# end state, so the validation order or race conditions does not matter.
#
# In order to avoid the lock overhead of (1), we use (2) here. See also
# discussion in https://github.com/HypothesisWorks/hypothesis/pull/4473.
try:
self.validate_called = True
self.validate_called[thread_id] = True
self.do_validate()
self.is_empty
self.has_reusable_values
except Exception:
self.validate_called = False
self.validate_called[thread_id] = False
raise

LABELS: ClassVar[dict[type, int]] = {}

@property
def class_label(self) -> int:
cls = self.__class__
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class OneCharStringStrategy(SearchStrategy[str]):
def __init__(
self, intervals: IntervalSet, force_repr: Optional[str] = None
) -> None:
super().__init__()
assert isinstance(intervals, IntervalSet)
self.intervals = intervals
self._force_repr = force_repr
Expand Down Expand Up @@ -349,6 +350,7 @@ def _identifier_characters():

class BytesStrategy(SearchStrategy):
def __init__(self, min_size: int, max_size: Optional[int]):
super().__init__()
self.min_size = min_size
self.max_size = (
max_size if max_size is not None else COLLECTION_DEFAULT_MAX_SIZE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -997,6 +997,7 @@ def resolve_Match(thing):

class GeneratorStrategy(st.SearchStrategy):
def __init__(self, yields, returns):
super().__init__()
assert isinstance(yields, st.SearchStrategy)
assert isinstance(returns, st.SearchStrategy)
self.yields = yields
Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/tests/common/strategies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ def do_draw(self, data):

class HardToShrink(SearchStrategy):
def __init__(self):
super().__init__()
self.__last = None
self.accepted = set()

Expand Down
1 change: 1 addition & 0 deletions hypothesis-python/tests/nocover/test_duplication.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

class Blocks(SearchStrategy):
def __init__(self, n):
super().__init__()
self.n = n

def do_draw(self, data):
Expand Down
27 changes: 27 additions & 0 deletions hypothesis-python/tests/nocover/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,12 @@
import pytest

from hypothesis import given, settings, strategies as st
from hypothesis.errors import InvalidArgument
from hypothesis.internal.conjecture.junkdrawer import ensure_free_stackframes
from hypothesis.stateful import RuleBasedStateMachine, invariant, rule

from tests.common.debug import check_can_generate_examples

pytestmark = pytest.mark.skipif(
settings._current_profile == "crosshair", reason="crosshair is not thread safe"
)
Expand Down Expand Up @@ -125,3 +128,27 @@ def test(n):
thread.join()

assert sys.getrecursionlimit() == original_recursionlimit


@pytest.mark.parametrize(
"strategy",
[
st.recursive(st.none(), st.lists, max_leaves=-1),
st.recursive(st.none(), st.lists, max_leaves=0),
st.recursive(st.none(), st.lists, max_leaves=1.0),
],
)
def test_handles_invalid_args_cleanly(strategy):
# we previously had a race in SearchStrategy.validate, where one thread would
# set `validate_called = True` (which it has to do first for recursive
# strategies), then another thread would try to generate before the validation
# finished and errored, and would get into weird technically-valid states
# like interpreting 1.0 as 1. I saw FlakyStrategyDefinition here because the
# validating + errored thread drew zero choices, but the other thread drew
# 1 choice, for the same shared strategy.

def check():
with pytest.raises(InvalidArgument):
check_can_generate_examples(strategy)

run_concurrently(check, n=4)
3 changes: 3 additions & 0 deletions website/content/2016-12-10-how-hypothesis-works.md
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,7 @@ For example, suppose we tried to implement lists as follows:
```python
class ListStrategy(SearchStrategy):
def __init__(self, elements):
super().__init__()
self.elements = elements

def do_draw(self, data):
Expand Down Expand Up @@ -191,6 +192,7 @@ this as follows:
```python
class ListStrategy(SearchStrategy):
def __init__(self, elements):
super().__init__()
self.elements = elements

def do_draw(self, data):
Expand Down Expand Up @@ -388,6 +390,7 @@ class SearchStrategy:

class FlatmappedStrategy(SearchStrategy):
def __init__(self, base, bind):
super().__init__()
self.base = base
self.bind = bind

Expand Down