Skip to content

Commit c7325e3

Browse files
committed
Fix for old scipy
1 parent eea1248 commit c7325e3

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

src/fast_array_utils/stats/_generic_ops.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,8 @@ def _generic_op_cs(
9090
# convert to array so dimensions collapse as expected
9191
x = (sp.csr_array if x.format == "csr" else sp.csc_array)(x, **_dtype_kw(dtype, op)) # type: ignore[arg-type]
9292
rv = cast("NDArray[Any] | types.coo_array | np.number[Any]", getattr(x, op)(axis=axis))
93-
return rv.toarray() if isinstance(rv, types.coo_array) else rv
93+
# old scipy versions’ sparray.{max,min}() return a 1×n/n×1 sparray here, so we squeeze
94+
return rv.toarray().squeeze() if isinstance(rv, types.coo_array) else rv
9495

9596

9697
@generic_op.register(types.DaskArray)

src/fast_array_utils/stats/_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ def _get_shape(a: NDArray[Any] | np.number[Any] | types.CupyArray, *, axis: Lite
110110
assert axis is not None
111111
return (1, a.size) if axis == 0 else (a.size, 1)
112112
case _: # pragma: no cover
113-
msg = f"{keepdims=}, {type(a)}"
113+
msg = f"{keepdims=}, {a.ndim=}, {type(a)=}"
114114
raise AssertionError(msg)
115115

116116

0 commit comments

Comments
 (0)