Skip to content

Commit edbc1a0

Browse files
committed
Make super fast benchmarks bigger
1 parent 104bf1c commit edbc1a0

File tree

1 file changed

+12
-1
lines changed

1 file changed

+12
-1
lines changed

tests/test_stats.py

+12-1
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,19 @@ def test_stats_benchmark(
207207
axis: Literal[0, 1, None],
208208
dtype: type[np.float32 | np.float64],
209209
) -> None:
210+
density = 0.1 if func is stats.is_constant and is_along_major_axis(array_type, axis) else 0.01
211+
210212
shape = (10_000, 10_000) if "sparse" in array_type.mod else (1000, 1000)
211-
arr = array_type.random(shape, dtype=dtype)
213+
arr = array_type.random(shape, dtype=dtype, density=density)
212214

213215
func(arr, axis=axis) # warmup: numba compile
214216
benchmark(func, arr, axis=axis)
217+
218+
219+
def is_along_major_axis(array_type: ArrayType[Any], axis: Literal[0, 1, None]) -> bool:
220+
if axis is None or not (array_type.flags & Flags.Sparse):
221+
return False
222+
cls = array_type.inner.cls if array_type.inner else array_type.cls
223+
if not issubclass(cls, types.CSBase):
224+
return False
225+
return (axis == 1) is (cls.format == "csr")

0 commit comments

Comments
 (0)