Skip to content

Change the default number of nearest neighbors search in Ingest #1111

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
May 18, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions scanpy/tests/test_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,19 @@ def test_neighbors(adatas):
assert percent_correct > 0.99


@pytest.mark.parametrize('n', [3, 4])
def test_neighbors_defaults(adatas, n):
adata_ref = adatas[0].copy()
adata_new = adatas[1].copy()

sc.pp.neighbors(adata_ref, n_neighbors=n)

ing = sc.tl.Ingest(adata_ref)
ing.fit(adata_new)
ing.neighbors()
assert ing._indices.shape[1] == n


@pytest.mark.skipif(
pkg_version("anndata") < sc.tl._ingest.ANNDATA_MIN_VERSION,
reason="`AnnData.concatenate` does not concatenate `.obsm` in old anndata versions",
Expand Down
11 changes: 8 additions & 3 deletions scanpy/tools/_ingest.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,9 +186,11 @@ class Ingest:
"""

def _init_umap(self, adata):
from umap import UMAP
import umap as u

self._umap = UMAP(
u.umap_._HAVE_PYNNDESCENT = False

self._umap = u.UMAP(
metric=self._metric,
random_state=adata.uns['umap']['params'].get('random_state', 0),
)
Expand Down Expand Up @@ -396,7 +398,7 @@ def fit(self, adata_new):
self._adata_new = adata_new
self._obsm['rep'] = self._same_rep()

def neighbors(self, k=10, queue_size=5, random_state=0):
def neighbors(self, k=None, queue_size=5, random_state=0):
"""\
Calculate neighbors of `adata_new` observations in `adata`.

Expand All @@ -412,6 +414,9 @@ def neighbors(self, k=10, queue_size=5, random_state=0):
train = self._rep
test = self._obsm['rep']

if k is None:
k = self._n_neighbors

init = self._initialise_search(
self._rp_forest, train, test, int(k * queue_size), rng_state=rng_state,
)
Expand Down