-
Notifications
You must be signed in to change notification settings - Fork 642
Add neighbors_from_distance for computing neighborhood graphs from precomputed distance matrices #3627
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add neighbors_from_distance for computing neighborhood graphs from precomputed distance matrices #3627
Changes from all commits
f76dc7b
f092469
7ffa1ec
c0d0c52
68652a7
948319a
6a64330
793351f
92d8e26
198c4fb
e7fb67a
14cb441
0ce8c15
914b87d
d285203
50705b3
4730667
040b8b7
c03b863
473a437
ec586df
43dcfc0
8a3588c
293f568
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
Added `neighbors_from_distance`, a function for computing neighborhood graphs from a precomputed distance matrix using the UMAP or Gaussian method. {smaller}`A. Karesh` |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,15 +21,17 @@ | |
from .._utils import NeighborsView, _doc_params, get_literal_vals | ||
from . import _connectivity | ||
from ._common import ( | ||
_get_indices_distances_from_dense_matrix, | ||
_get_indices_distances_from_sparse_matrix, | ||
_get_sparse_matrix_from_indices_distances, | ||
) | ||
from ._connectivity import umap | ||
from ._doc import doc_n_pcs, doc_use_rep | ||
from ._types import _KnownTransformer, _Method | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Callable, MutableMapping | ||
from typing import Any, Literal, NotRequired | ||
from typing import Any, Literal, NotRequired, Unpack | ||
|
||
from anndata import AnnData | ||
from igraph import Graph | ||
|
@@ -58,11 +60,18 @@ class KwdsForTransformer(TypedDict): | |
random_state: _LegacyRandom | ||
|
||
|
||
class NeighborsDict(TypedDict): # noqa: D101 | ||
connectivities_key: str | ||
distances_key: str | ||
params: NeighborsParams | ||
rp_forest: NotRequired[RPForestDict] | ||
|
||
|
||
class NeighborsParams(TypedDict): # noqa: D101 | ||
n_neighbors: int | ||
method: _Method | ||
random_state: _LegacyRandom | ||
metric: _Metric | _MetricFn | ||
metric: _Metric | _MetricFn | None | ||
metric_kwds: NotRequired[Mapping[str, Any]] | ||
use_rep: NotRequired[str] | ||
n_pcs: NotRequired[int] | ||
|
@@ -74,11 +83,12 @@ def neighbors( # noqa: PLR0913 | |
n_neighbors: int = 15, | ||
n_pcs: int | None = None, | ||
*, | ||
distances: np.ndarray | SpBase | None = None, | ||
use_rep: str | None = None, | ||
knn: bool = True, | ||
method: _Method = "umap", | ||
transformer: KnnTransformerLike | _KnownTransformer | None = None, | ||
metric: _Metric | _MetricFn = "euclidean", | ||
metric: _Metric | _MetricFn | None = None, | ||
metric_kwds: Mapping[str, Any] = MappingProxyType({}), | ||
random_state: _LegacyRandom = 0, | ||
key_added: str | None = None, | ||
|
@@ -135,6 +145,8 @@ def neighbors( # noqa: PLR0913 | |
Use :func:`rapids_singlecell.pp.neighbors` instead. | ||
metric | ||
A known metric’s name or a callable that returns a distance. | ||
If `distances` is given, this parameter is simply stored in `.uns` (see below), | ||
otherwise defaults to `'euclidean'`. | ||
|
||
*ignored if ``transformer`` is an instance.* | ||
metric_kwds | ||
|
@@ -186,6 +198,20 @@ def neighbors( # noqa: PLR0913 | |
:doc:`/how-to/knn-transformers` | ||
|
||
""" | ||
if distances is not None: | ||
if callable(metric): | ||
msg = "`metric` must be a string if `distances` is given." | ||
raise TypeError(msg) | ||
# if a precomputed distance matrix is provided, skip the PCA and distance computation | ||
return neighbors_from_distance( | ||
adata, | ||
distances, | ||
n_neighbors=n_neighbors, | ||
metric=metric, | ||
method=method, | ||
) | ||
if metric is None: | ||
metric = "euclidean" | ||
start = logg.info("computing neighbors") | ||
adata = adata.copy() if copy else adata | ||
if adata.is_view: # we shouldn't need this here... | ||
|
@@ -203,51 +229,124 @@ def neighbors( # noqa: PLR0913 | |
random_state=random_state, | ||
) | ||
|
||
if key_added is None: | ||
key_added = "neighbors" | ||
conns_key = "connectivities" | ||
dists_key = "distances" | ||
else: | ||
conns_key = key_added + "_connectivities" | ||
dists_key = key_added + "_distances" | ||
|
||
adata.uns[key_added] = {} | ||
|
||
neighbors_dict = adata.uns[key_added] | ||
|
||
neighbors_dict["connectivities_key"] = conns_key | ||
neighbors_dict["distances_key"] = dists_key | ||
|
||
neighbors_dict["params"] = NeighborsParams( | ||
key_added, neighbors_dict = _get_metadata( | ||
key_added, | ||
n_neighbors=neighbors.n_neighbors, | ||
method=method, | ||
random_state=random_state, | ||
metric=metric, | ||
**({} if not metric_kwds else dict(metric_kwds=metric_kwds)), | ||
**({} if use_rep is None else dict(use_rep=use_rep)), | ||
**({} if n_pcs is None else dict(n_pcs=n_pcs)), | ||
) | ||
if metric_kwds: | ||
neighbors_dict["params"]["metric_kwds"] = metric_kwds | ||
if use_rep is not None: | ||
neighbors_dict["params"]["use_rep"] = use_rep | ||
if n_pcs is not None: | ||
neighbors_dict["params"]["n_pcs"] = n_pcs | ||
|
||
adata.obsp[dists_key] = neighbors.distances | ||
adata.obsp[conns_key] = neighbors.connectivities | ||
|
||
if neighbors.rp_forest is not None: | ||
neighbors_dict["rp_forest"] = neighbors.rp_forest | ||
|
||
adata.uns[key_added] = neighbors_dict | ||
adata.obsp[neighbors_dict["distances_key"]] = neighbors.distances | ||
adata.obsp[neighbors_dict["connectivities_key"]] = neighbors.connectivities | ||
|
||
logg.info( | ||
" finished", | ||
time=start, | ||
deep=( | ||
f"added to `.uns[{key_added!r}]`\n" | ||
f" `.obsp[{dists_key!r}]`, distances for each pair of neighbors\n" | ||
f" `.obsp[{conns_key!r}]`, weighted adjacency matrix" | ||
f" `.obsp[{neighbors_dict['distances_key']!r}]`, distances for each pair of neighbors\n" | ||
f" `.obsp[{neighbors_dict['connectivities_key']!r}]`, weighted adjacency matrix" | ||
), | ||
) | ||
return adata if copy else None | ||
|
||
|
||
def neighbors_from_distance(
    adata: AnnData,
    distances: np.ndarray | SpBase,
    *,
    n_neighbors: int = 15,
    metric: _Metric | None = None,
    method: _Method = "umap",  # default to umap
    key_added: str | None = None,
) -> AnnData:
    """Compute neighbors from a precomputed distance matrix.

    The input matrix is copied before its diagonal is zeroed,
    so the object passed as `distances` is never modified.

    Parameters
    ----------
    adata
        Annotated data matrix.
    distances
        Precomputed dense or sparse distance matrix.
    n_neighbors
        Number of nearest neighbors to use in the graph.
    metric
        Name of metric used to compute `distances`.
        It is only recorded in the stored parameters.
    method
        Method to use for computing the graph, `'umap'` or `'gauss'`.
    key_added
        Optional key under which to store the results. Default is 'neighbors'.

    Returns
    -------
    adata
        Annotated data with computed distances and connectivities.

    Raises
    ------
    NotImplementedError
        If `method` is neither `'umap'` nor `'gauss'`.
    """
    # NOTE(review): a reviewer asked for the keyword defaults in the signature
    # to be removed so that callers must forward every option explicitly.
    # They are kept here so the signature stays backward-compatible.
    if isinstance(distances, SpBase):
        # `copy=True`: converting an existing CSR matrix would otherwise share
        # its buffers, and the in-place edits below would leak into the
        # caller's matrix.
        distances = sparse.csr_matrix(distances, copy=True)  # noqa: TID251
        distances.setdiag(0)
        distances.eliminate_zeros()
    else:
        # `np.array` always copies (unlike `np.asarray`), so zeroing the
        # diagonal cannot overwrite the caller's array.
        distances = np.array(distances)
        np.fill_diagonal(distances, 0)

    if method == "umap":
        if isinstance(distances, CSRBase):
            knn_indices, knn_distances = _get_indices_distances_from_sparse_matrix(
                distances, n_neighbors
            )
        else:
            knn_indices, knn_distances = _get_indices_distances_from_dense_matrix(
                distances, n_neighbors
            )
        connectivities = umap(
            knn_indices, knn_distances, n_obs=adata.n_obs, n_neighbors=n_neighbors
        )
    elif method == "gauss":
        # The sparse representation is what gets passed on and stored below.
        distances = sparse.csr_matrix(distances)  # noqa: TID251
        connectivities = _connectivity.gauss(distances, n_neighbors, knn=True)
    else:
        msg = f"Method {method} not implemented."
        raise NotImplementedError(msg)

    key_added, neighbors_dict = _get_metadata(
        key_added,
        n_neighbors=n_neighbors,
        method=method,
        random_state=0,
        metric=metric,
    )
    adata.uns[key_added] = neighbors_dict
    adata.obsp[neighbors_dict["distances_key"]] = distances
    adata.obsp[neighbors_dict["connectivities_key"]] = connectivities
    return adata
|
||
|
||
def _get_metadata(
    key_added: str | None,
    **params: Unpack[NeighborsParams],
) -> tuple[str, NeighborsDict]:
    """Build the storage key and metadata entry for a neighbors graph.

    Parameters
    ----------
    key_added
        Key requested by the caller, or `None` for the default layout.
    **params
        Parameters recorded under the `params` entry of the metadata.

    Returns
    -------
    The key to store the metadata under, and the metadata itself.
    With `key_added=None`, the key is `'neighbors'` and the matrix keys are
    `'connectivities'` and `'distances'`; otherwise the matrix keys are
    prefixed with `key_added` and an underscore.
    """
    if key_added is None:
        uns_key, prefix = "neighbors", ""
    else:
        uns_key, prefix = key_added, f"{key_added}_"
    metadata = NeighborsDict(
        connectivities_key=f"{prefix}connectivities",
        distances_key=f"{prefix}distances",
        params=params,
    )
    return uns_key, metadata
|
||
|
||
class FlatTree(NamedTuple): # noqa: D101 | ||
hyperplanes: None | ||
offsets: None | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
from scanpy import Neighbors | ||
from scanpy._compat import CSBase | ||
from testing.scanpy._helpers import anndata_v0_8_constructor_compat | ||
from testing.scanpy._helpers.data import pbmc68k_reduced | ||
|
||
if TYPE_CHECKING: | ||
from typing import Literal | ||
|
@@ -241,3 +242,26 @@ def test_restore_n_neighbors(neigh, conv): | |
ad.uns["neighbors"] = dict(connectivities=conv(neigh.connectivities)) | ||
neigh_restored = Neighbors(ad) | ||
assert neigh_restored.n_neighbors == 1 | ||
|
||
|
||
def test_neighbors_distance_equivalence():
    """Passing precomputed distances reproduces the regular neighbors graph."""
    # NOTE(review): a reviewer asked for this test to be parametrized; the
    # request is cut off on the page, so what to parametrize over is unknown.
    reference = pbmc68k_reduced()
    from_distances = reference.copy()

    sc.pp.neighbors(reference)
    # Feed the distances computed above back in as a precomputed matrix.
    sc.pp.neighbors(from_distances, distances=reference.obsp["distances"])

    for obsp_key in ("connectivities", "distances"):
        np.testing.assert_allclose(
            reference.obsp[obsp_key].toarray(),
            from_distances.obsp[obsp_key].toarray(),
            rtol=1e-5,
        )

    # Apart from `metric`, the stored parameters must match.
    params_ref = reference.uns["neighbors"]["params"].copy()
    params_dist = from_distances.uns["neighbors"]["params"].copy()
    assert params_ref.pop("metric") == "euclidean"
    assert params_dist.pop("metric") is None
    assert params_ref == params_dist
Uh oh!
There was an error while loading. Please reload this page.