Skip to content

Simplify scale #3351

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 33 commits into from
May 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
20b5f14
Avoid parallel numba within dask
flying-sheep Oct 24, 2024
4ba3b21
restore zappy compat
flying-sheep Oct 24, 2024
7fdeda1
only do it in tests
flying-sheep Oct 25, 2024
726b28f
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 7, 2024
f5eaa12
relnote
flying-sheep Nov 7, 2024
ebada2f
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 8, 2024
ae8bd60
Merge branch 'main' into fix-clip-dask-sparse
flying-sheep Nov 11, 2024
bbbf3f4
Simplify scale implementation
flying-sheep Nov 11, 2024
4008a28
Merge branch 'main' into simplify-scale
flying-sheep Nov 11, 2024
f57a9b0
Fix merge
flying-sheep Nov 11, 2024
9e9f63f
Merge branch 'main' into simplify-scale
flying-sheep Feb 17, 2025
cdb1c87
Rename 3317.bugfix.md to 3351.bugfix.md
flying-sheep Feb 17, 2025
3f59f8b
reintroduce numba helper
flying-sheep Feb 17, 2025
07e8af2
oops
flying-sheep Feb 17, 2025
18a58d7
use typevar
flying-sheep Feb 17, 2025
208bec6
Merge branch 'main' into simplify-scale
flying-sheep Mar 14, 2025
e253f8a
Merge branch 'main' into simplify-scale
flying-sheep Mar 17, 2025
1d71b1b
add mask for csr
Intron7 Apr 8, 2025
d7c21a3
make this into a suggestion
Intron7 Apr 8, 2025
bd122d1
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
04815d2
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
08112b9
add default mask for CSR
flying-sheep Apr 10, 2025
a8d13d7
do tests properly
flying-sheep Apr 10, 2025
b0716e2
can only use shortcut when not zero-centering
flying-sheep Apr 10, 2025
66ca664
msg
flying-sheep Apr 10, 2025
efd1130
fix test
flying-sheep Apr 10, 2025
04ff943
Merge branch 'main' into simplify-scale
flying-sheep Apr 10, 2025
a428274
Merge branch 'main' into simplify-scale
flying-sheep Apr 14, 2025
9ccabe8
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Apr 14, 2025
233d8a3
oops
flying-sheep Apr 14, 2025
bea7c01
Merge branch 'main' into simplify-scale
flying-sheep May 13, 2025
da8abff
Merge branch 'main' into simplify-scale
flying-sheep May 15, 2025
a5caa39
fix tests
flying-sheep May 15, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/release-notes/3351.bugfix.md
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fix zappy compatibility for clip_array {smaller}`P Angerer`
14 changes: 11 additions & 3 deletions src/scanpy/_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,8 @@
_MemoryArray = NDArray | CSBase
_SupportedArray = _MemoryArray | DaskArray

_SA = TypeVar("_SA", bound=_SupportedArray)

_ForT = TypeVar("_ForT", bound=Callable | type)


Expand Down Expand Up @@ -271,9 +273,7 @@ def get_igraph_from_adjacency(adjacency: CSBase, *, directed: bool = False) -> G
import igraph as ig

sources, targets = adjacency.nonzero()
weights = adjacency[sources, targets]
if isinstance(weights, np.matrix):
weights = weights.A1
weights = dematrix(adjacency[sources, targets]).ravel()
g = ig.Graph(directed=directed)
g.add_vertices(adjacency.shape[0]) # this adds adjacency.shape[0] vertices
g.add_edges(list(zip(sources, targets, strict=True)))
Expand Down Expand Up @@ -750,6 +750,14 @@ def _check_nonnegative_integers_dask(X: DaskArray) -> DaskArray:
return X.map_blocks(check_nonnegative_integers, dtype=bool, drop_axis=(0, 1))


def dematrix(x: _SA | np.matrix) -> _SA:
if isinstance(x, np.matrix):
return x.A
if isinstance(x, DaskArray) and isinstance(x._meta, np.matrix):
return x.map_blocks(np.asarray, meta=np.array([], dtype=x.dtype))
return x


def select_groups(
adata: AnnData,
groups_order_subset: Iterable[str] | Literal["all"] = "all",
Expand Down
15 changes: 6 additions & 9 deletions src/scanpy/preprocessing/_deprecated/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from scipy import sparse

from ..._compat import CSBase, old_positionals
from ..._utils import dematrix


@old_positionals("max_fraction", "mult_with_mean")
Expand Down Expand Up @@ -39,15 +40,11 @@ def normalize_per_cell_weinreb16_deprecated(
msg = "Choose max_fraction between 0 and 1."
raise ValueError(msg)

counts_per_cell = x.sum(1)
if isinstance(counts_per_cell, np.matrix):
counts_per_cell = counts_per_cell.A1
gene_subset = np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0)
if isinstance(gene_subset, np.matrix):
gene_subset = gene_subset.A1
tc_include = x[:, gene_subset].sum(1)
if isinstance(tc_include, np.matrix):
tc_include = tc_include.A1
counts_per_cell = dematrix(x.sum(1)).ravel()
gene_subset = dematrix(
np.all(x <= counts_per_cell[:, None] * max_fraction, axis=0)
).ravel()
tc_include = dematrix(x[:, gene_subset].sum(1)).ravel()

x_norm = (
x.multiply(sparse.csr_matrix(1 / tc_include[:, None])) # noqa: TID251
Expand Down
8 changes: 2 additions & 6 deletions src/scanpy/preprocessing/_normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

from .. import logging as logg
from .._compat import CSBase, DaskArray, old_positionals
from .._utils import axis_mul_or_truediv, view_to_actual
from .._utils import axis_mul_or_truediv, dematrix, view_to_actual
from ..get import _get_obs_rep, _set_obs_rep

try:
Expand Down Expand Up @@ -224,11 +224,7 @@ def normalize_total( # noqa: PLR0912, PLR0915
counts_per_cell = stats.sum(x, axis=1)
if exclude_highly_expressed:
# at least one cell as more than max_fraction of counts per cell
hi_exp = x > counts_per_cell[:, None] * max_fraction
if isinstance(hi_exp, np.matrix):
hi_exp = hi_exp.A
elif isinstance(hi_exp, DaskArray) and isinstance(hi_exp._meta, np.matrix):
hi_exp = hi_exp.map_blocks(np.asarray, meta=np.array([], dtype=x.dtype))
hi_exp = dematrix(x > counts_per_cell[:, None] * max_fraction)
gene_subset = stats.sum(hi_exp, axis=0) == 0

msg += (
Expand Down
Loading
Loading