Skip to content

Commit a5caa39

Browse files
committed
fix tests
1 parent da8abff commit a5caa39

File tree

5 files changed

+23
-22
lines changed

5 files changed

+23
-22
lines changed

src/scanpy/_utils/__init__.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,8 @@
6060
_MemoryArray = NDArray | CSBase
6161
_SupportedArray = _MemoryArray | DaskArray
6262

63+
_SA = TypeVar("_SA", bound=_SupportedArray)
64+
6365
_ForT = TypeVar("_ForT", bound=Callable | type)
6466

6567

@@ -271,9 +273,7 @@ def get_igraph_from_adjacency(adjacency: CSBase, *, directed: bool = False) -> G
271273
import igraph as ig
272274

273275
sources, targets = adjacency.nonzero()
274-
weights = adjacency[sources, targets]
275-
if isinstance(weights, np.matrix):
276-
weights = weights.A1
276+
weights = dematrix(adjacency[sources, targets]).ravel()
277277
g = ig.Graph(directed=directed)
278278
g.add_vertices(adjacency.shape[0]) # this adds adjacency.shape[0] vertices
279279
g.add_edges(list(zip(sources, targets, strict=True)))
@@ -750,6 +750,14 @@ def _check_nonnegative_integers_dask(X: DaskArray) -> DaskArray:
750750
return X.map_blocks(check_nonnegative_integers, dtype=bool, drop_axis=(0, 1))
751751

752752

753+
def dematrix(x: _SA | np.matrix) -> _SA:
754+
if isinstance(x, np.matrix):
755+
return x.A
756+
if isinstance(x, DaskArray) and isinstance(x._meta, np.matrix):
757+
return x.map_blocks(np.asarray, meta=np.array([], dtype=x.dtype))
758+
return x
759+
760+
753761
def select_groups(
754762
adata: AnnData,
755763
groups_order_subset: Iterable[str] | Literal["all"] = "all",

src/scanpy/preprocessing/_deprecated/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from scipy import sparse
55

66
from ..._compat import CSBase, old_positionals
7+
from ..._utils import dematrix
78

89

910
@old_positionals("max_fraction", "mult_with_mean")
@@ -39,15 +40,11 @@ def normalize_per_cell_weinreb16_deprecated(
3940
msg = "Choose max_fraction between 0 and 1."
4041
raise ValueError(msg)
4142

42-
counts_per_cell = x.sum(1)
43-
if isinstance(counts_per_cell, np.matrix):
44-
counts_per_cell = counts_per_cell.A1
45-
gene_subset = np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0)
46-
if isinstance(gene_subset, np.matrix):
47-
gene_subset = gene_subset.A1
48-
tc_include = x[:, gene_subset].sum(1)
49-
if isinstance(tc_include, np.matrix):
50-
tc_include = tc_include.A1
43+
counts_per_cell = dematrix(x.sum(1)).ravel()
44+
gene_subset = dematrix(
45+
np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0)
46+
).ravel()
47+
tc_include = dematrix(x[:, gene_subset].sum(1)).ravel()
5148

5249
x_norm = (
5350
x.multiply(sparse.csr_matrix(1 / tc_include[:, None])) # noqa: TID251

src/scanpy/preprocessing/_normalization.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .. import logging as logg
1111
from .._compat import CSBase, DaskArray, old_positionals
12-
from .._utils import axis_mul_or_truediv, view_to_actual
12+
from .._utils import axis_mul_or_truediv, dematrix, view_to_actual
1313
from ..get import _get_obs_rep, _set_obs_rep
1414

1515
try:
@@ -224,11 +224,7 @@ def normalize_total( # noqa: PLR0912, PLR0915
224224
counts_per_cell = stats.sum(x, axis=1)
225225
if exclude_highly_expressed:
226226
# at least one cell as more than max_fraction of counts per cell
227-
hi_exp = x > counts_per_cell[:, None] * max_fraction
228-
if isinstance(hi_exp, np.matrix):
229-
hi_exp = hi_exp.A
230-
elif isinstance(hi_exp, DaskArray) and isinstance(hi_exp._meta, np.matrix):
231-
hi_exp = hi_exp.map_blocks(np.asarray, meta=np.array([], dtype=x.dtype))
227+
hi_exp = dematrix(x > counts_per_cell[:, None] * max_fraction)
232228
gene_subset = stats.sum(hi_exp, axis=0) == 0
233229

234230
msg += (

src/scanpy/preprocessing/_scale.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
from .._utils import (
1616
_check_array_function_arguments,
1717
axis_mul_or_truediv,
18+
dematrix,
1819
raise_not_implemented_error_if_backed_type,
1920
renamed_arg,
2021
view_to_actual,
@@ -202,6 +203,7 @@ def scale_array(
202203
msg = "zero-centering a sparse array/matrix densifies it."
203204
warnings.warn(msg, UserWarning, stacklevel=2)
204205
x -= mean
206+
x = dematrix(x)
205207

206208
x = axis_mul_or_truediv(
207209
x,

src/scanpy/tools/_louvain.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
from .. import _utils
1313
from .. import logging as logg
1414
from .._compat import old_positionals
15-
from .._utils import _choose_graph
15+
from .._utils import _choose_graph, dematrix
1616
from ._utils_clustering import rename_groups, restrict_adjacency
1717

1818
if TYPE_CHECKING:
@@ -203,9 +203,7 @@ def louvain( # noqa: PLR0912, PLR0913, PLR0915
203203
indices = cudf.Series(adjacency.indices)
204204
if use_weights:
205205
sources, targets = adjacency.nonzero()
206-
weights = adjacency[sources, targets]
207-
if isinstance(weights, np.matrix):
208-
weights = weights.A1
206+
weights = dematrix(adjacency[sources, targets]).ravel()
209207
weights = cudf.Series(weights)
210208
else:
211209
weights = None

0 commit comments

Comments
 (0)