Skip to content

Commit ffdf391

Browse files
authored
Merge pull request #4263 from tybug/next
Optimizations
2 parents 7cb0989 + d598f6a commit ffdf391

File tree

9 files changed

+98
-110
lines changed

9 files changed

+98
-110
lines changed

hypothesis-python/RELEASE.rst

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
RELEASE_TYPE: patch
2+
3+
Optimize performance (improves speed by ~5%) and clarify the wording in an error message.

hypothesis-python/src/hypothesis/core.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import warnings
2424
import zlib
2525
from collections import defaultdict
26-
from collections.abc import Coroutine, Generator, Hashable, Sequence
26+
from collections.abc import Coroutine, Generator, Hashable, Iterable, Sequence
2727
from functools import partial
2828
from random import Random
2929
from typing import (
@@ -321,7 +321,7 @@ def accept(test):
321321
return accept
322322

323323

324-
def encode_failure(choices):
324+
def encode_failure(choices: Iterable[ChoiceT]) -> bytes:
325325
blob = choices_to_bytes(choices)
326326
compressed = zlib.compress(blob)
327327
if len(compressed) < len(blob):
@@ -687,7 +687,7 @@ def skip_exceptions_to_reraise():
687687
return tuple(sorted(exceptions, key=str))
688688

689689

690-
def failure_exceptions_to_catch():
690+
def failure_exceptions_to_catch() -> tuple[type[BaseException], ...]:
691691
"""Return a tuple of exceptions meaning 'this test has failed', to catch.
692692
693693
This is intended to cover most common test runners; if you would

hypothesis-python/src/hypothesis/internal/conjecture/data.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -492,8 +492,7 @@ def __len__(self) -> int:
492492
return self.__length
493493

494494
def __getitem__(self, i: int) -> Example:
495-
assert isinstance(i, int)
496-
n = len(self)
495+
n = self.__length
497496
if i < -n or i >= n:
498497
raise IndexError(f"Index {i} out of range [-{n}, {n})")
499498
if i < 0:

hypothesis-python/src/hypothesis/internal/conjecture/junkdrawer.py

Lines changed: 10 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -17,18 +17,9 @@
1717
import sys
1818
import time
1919
import warnings
20+
from array import ArrayType
2021
from collections.abc import Iterable, Iterator, Sequence
21-
from typing import (
22-
Any,
23-
Callable,
24-
Generic,
25-
List,
26-
Literal,
27-
Optional,
28-
TypeVar,
29-
Union,
30-
overload,
31-
)
22+
from typing import Any, Callable, Generic, Literal, Optional, TypeVar, Union, overload
3223

3324
from sortedcontainers import SortedList
3425

@@ -41,7 +32,7 @@
4132

4233
def array_or_list(
4334
code: str, contents: Iterable[int]
44-
) -> "Union[List[int], array.ArrayType[int]]":
35+
) -> Union[list[int], "ArrayType[int]"]:
4536
if code == "O":
4637
return list(contents)
4738
return array.array(code, contents)
@@ -82,7 +73,7 @@ class IntList(Sequence[int]):
8273

8374
__slots__ = ("__underlying",)
8475

85-
__underlying: "Union[List[int], array.ArrayType[int]]"
76+
__underlying: Union[list[int], "ArrayType[int]"]
8677

8778
def __init__(self, values: Sequence[int] = ()):
8879
for code in ARRAY_CODES:
@@ -116,11 +107,13 @@ def __len__(self) -> int:
116107
def __getitem__(self, i: int) -> int: ... # pragma: no cover
117108

118109
@overload
119-
def __getitem__(self, i: slice) -> "IntList": ... # pragma: no cover
110+
def __getitem__(
111+
self, i: slice
112+
) -> Union[list[int], "ArrayType[int]"]: ... # pragma: no cover
120113

121-
def __getitem__(self, i: Union[int, slice]) -> "Union[int, IntList]":
122-
if isinstance(i, slice):
123-
return IntList(self.__underlying[i])
114+
def __getitem__(
115+
self, i: Union[int, slice]
116+
) -> Union[int, list[int], "ArrayType[int]"]:
124117
return self.__underlying[i]
125118

126119
def __delitem__(self, i: Union[int, slice]) -> None:

hypothesis-python/src/hypothesis/internal/conjecture/utils.py

Lines changed: 70 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from collections import OrderedDict, abc
1717
from collections.abc import Sequence
1818
from functools import lru_cache
19-
from typing import TYPE_CHECKING, List, Optional, TypeVar, Union
19+
from typing import TYPE_CHECKING, Optional, TypeVar, Union
2020

2121
from hypothesis.errors import InvalidArgument
2222
from hypothesis.internal.compat import int_from_bytes
@@ -72,7 +72,7 @@ def check_sample(
7272
)
7373
elif not isinstance(values, (OrderedDict, abc.Sequence, enum.EnumMeta)):
7474
raise InvalidArgument(
75-
f"Cannot sample from {values!r}, not an ordered collection. "
75+
f"Cannot sample from {values!r} because it is not an ordered collection. "
7676
f"Hypothesis goes to some length to ensure that the {strategy_name} "
7777
"strategy has stable results between runs. To replay a saved "
7878
"example, the sampled values must have the same iteration order "
@@ -87,6 +87,73 @@ def check_sample(
8787
return tuple(values)
8888

8989

90+
@lru_cache(64)
91+
def compute_sampler_table(weights: tuple[float, ...]) -> list[tuple[int, int, float]]:
92+
n = len(weights)
93+
table: list[list[int | float | None]] = [[i, None, None] for i in range(n)]
94+
total = sum(weights)
95+
num_type = type(total)
96+
97+
zero = num_type(0) # type: ignore
98+
one = num_type(1) # type: ignore
99+
100+
small: list[int] = []
101+
large: list[int] = []
102+
103+
probabilities = [w / total for w in weights]
104+
scaled_probabilities: list[float] = []
105+
106+
for i, alternate_chance in enumerate(probabilities):
107+
scaled = alternate_chance * n
108+
scaled_probabilities.append(scaled)
109+
if scaled == 1:
110+
table[i][2] = zero
111+
elif scaled < 1:
112+
small.append(i)
113+
else:
114+
large.append(i)
115+
heapq.heapify(small)
116+
heapq.heapify(large)
117+
118+
while small and large:
119+
lo = heapq.heappop(small)
120+
hi = heapq.heappop(large)
121+
122+
assert lo != hi
123+
assert scaled_probabilities[hi] > one
124+
assert table[lo][1] is None
125+
table[lo][1] = hi
126+
table[lo][2] = one - scaled_probabilities[lo]
127+
scaled_probabilities[hi] = (
128+
scaled_probabilities[hi] + scaled_probabilities[lo]
129+
) - one
130+
131+
if scaled_probabilities[hi] < 1:
132+
heapq.heappush(small, hi)
133+
elif scaled_probabilities[hi] == 1:
134+
table[hi][2] = zero
135+
else:
136+
heapq.heappush(large, hi)
137+
while large:
138+
table[large.pop()][2] = zero
139+
while small:
140+
table[small.pop()][2] = zero
141+
142+
new_table: list[tuple[int, int, float]] = []
143+
for base, alternate, alternate_chance in table:
144+
assert isinstance(base, int)
145+
assert isinstance(alternate, int) or alternate is None
146+
assert alternate_chance is not None
147+
if alternate is None:
148+
new_table.append((base, base, alternate_chance))
149+
elif alternate < base:
150+
new_table.append((alternate, base, one - alternate_chance))
151+
else:
152+
new_table.append((base, alternate, alternate_chance))
153+
new_table.sort()
154+
return new_table
155+
156+
90157
class Sampler:
91158
"""Sampler based on Vose's algorithm for the alias method. See
92159
http://www.keithschwarz.com/darts-dice-coins/ for a good explanation.
@@ -109,69 +176,7 @@ class Sampler:
109176

110177
def __init__(self, weights: Sequence[float], *, observe: bool = True):
111178
self.observe = observe
112-
113-
n = len(weights)
114-
table: "list[list[int | float | None]]" = [[i, None, None] for i in range(n)]
115-
total = sum(weights)
116-
num_type = type(total)
117-
118-
zero = num_type(0) # type: ignore
119-
one = num_type(1) # type: ignore
120-
121-
small: "List[int]" = []
122-
large: "List[int]" = []
123-
124-
probabilities = [w / total for w in weights]
125-
scaled_probabilities: "List[float]" = []
126-
127-
for i, alternate_chance in enumerate(probabilities):
128-
scaled = alternate_chance * n
129-
scaled_probabilities.append(scaled)
130-
if scaled == 1:
131-
table[i][2] = zero
132-
elif scaled < 1:
133-
small.append(i)
134-
else:
135-
large.append(i)
136-
heapq.heapify(small)
137-
heapq.heapify(large)
138-
139-
while small and large:
140-
lo = heapq.heappop(small)
141-
hi = heapq.heappop(large)
142-
143-
assert lo != hi
144-
assert scaled_probabilities[hi] > one
145-
assert table[lo][1] is None
146-
table[lo][1] = hi
147-
table[lo][2] = one - scaled_probabilities[lo]
148-
scaled_probabilities[hi] = (
149-
scaled_probabilities[hi] + scaled_probabilities[lo]
150-
) - one
151-
152-
if scaled_probabilities[hi] < 1:
153-
heapq.heappush(small, hi)
154-
elif scaled_probabilities[hi] == 1:
155-
table[hi][2] = zero
156-
else:
157-
heapq.heappush(large, hi)
158-
while large:
159-
table[large.pop()][2] = zero
160-
while small:
161-
table[small.pop()][2] = zero
162-
163-
self.table: "list[tuple[int, int, float]]" = []
164-
for base, alternate, alternate_chance in table:
165-
assert isinstance(base, int)
166-
assert isinstance(alternate, int) or alternate is None
167-
assert alternate_chance is not None
168-
if alternate is None:
169-
self.table.append((base, base, alternate_chance))
170-
elif alternate < base:
171-
self.table.append((alternate, base, one - alternate_chance))
172-
else:
173-
self.table.append((base, alternate, alternate_chance))
174-
self.table.sort()
179+
self.table = compute_sampler_table(tuple(weights))
175180

176181
def sample(
177182
self,

hypothesis-python/src/hypothesis/internal/escalation.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
from functools import partial
1717
from inspect import getframeinfo
1818
from pathlib import Path
19-
from types import ModuleType
19+
from types import ModuleType, TracebackType
2020
from typing import Callable, NamedTuple, Optional
2121

2222
import hypothesis
@@ -57,7 +57,9 @@ def accept(filepath: str) -> bool:
5757
is_hypothesis_file = belongs_to(hypothesis)
5858

5959

60-
def get_trimmed_traceback(exception=None):
60+
def get_trimmed_traceback(
61+
exception: Optional[BaseException] = None,
62+
) -> Optional[TracebackType]:
6163
"""Return the current traceback, minus any frames added by Hypothesis."""
6264
if exception is None:
6365
_, exception, tb = sys.exc_info()
@@ -67,9 +69,10 @@ def get_trimmed_traceback(exception=None):
6769
# was raised inside Hypothesis. Additionally, the environment variable
6870
# HYPOTHESIS_NO_TRACEBACK_TRIM is respected if nonempty, because verbose
6971
# mode is prohibitively slow when debugging strategy recursion errors.
72+
assert hypothesis.settings.default is not None
7073
if (
7174
tb is None
72-
or os.environ.get("HYPOTHESIS_NO_TRACEBACK_TRIM", None)
75+
or os.environ.get("HYPOTHESIS_NO_TRACEBACK_TRIM")
7376
or hypothesis.settings.default.verbosity >= hypothesis.Verbosity.debug
7477
or (
7578
is_hypothesis_file(traceback.extract_tb(tb)[-1][0])

hypothesis-python/src/hypothesis/internal/reflection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def _clean_source(src: str) -> bytes:
8282
return "\n".join(x.rstrip() for x in src.splitlines() if x.rstrip()).encode()
8383

8484

85-
def function_digest(function):
85+
def function_digest(function: Any) -> bytes:
8686
"""Returns a string that is stable across multiple invocations across
8787
multiple processes and is prone to changing significantly in response to
8888
minor changes to the function.

hypothesis-python/tests/conjecture/test_junkdrawer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,9 @@ def test_int_list_extend():
169169

170170
def test_int_list_slice():
171171
x = IntList([1, 2])
172-
assert x[:1] == IntList([1])
173-
assert x[0:2] == IntList([1, 2])
174-
assert x[1:] == IntList([2])
172+
assert list(x[:1]) == [1]
173+
assert list(x[0:2]) == [1, 2]
174+
assert list(x[1:]) == [2]
175175

176176

177177
def test_int_list_del():

hypothesis-python/tests/nocover/test_conjecture_int_list.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,6 @@ def valid_index(draw):
2323
return draw(st.integers(0, len(machine.model) - 1))
2424

2525

26-
@st.composite
27-
def valid_slice(draw):
28-
machine = draw(st.runner())
29-
result = [
30-
draw(st.integers(0, max(3, len(machine.model) * 2 - 1))) for _ in range(2)
31-
]
32-
result.sort()
33-
return slice(*result)
34-
35-
3626
class IntListRules(RuleBasedStateMachine):
3727
@initialize(ls=st.lists(INTEGERS))
3828
def starting_lists(self, ls):
@@ -52,16 +42,11 @@ def append(self, n):
5242
self.model.append(n)
5343
self.target.append(n)
5444

55-
@rule(i=valid_index() | valid_slice())
45+
@rule(i=valid_index())
5646
def delete(self, i):
5747
del self.model[i]
5848
del self.target[i]
5949

60-
@rule(sl=valid_slice())
61-
def slice(self, sl):
62-
self.model = self.model[sl]
63-
self.target = self.target[sl]
64-
6550
@rule(i=valid_index())
6651
def agree_on_values(self, i):
6752
assert self.model[i] == self.target[i]

0 commit comments

Comments
 (0)