From 376cda39755ed936d697e841992b5a23837366a4 Mon Sep 17 00:00:00 2001
From: Parashar
Date: Thu, 21 Apr 2022 00:11:15 +0200
Subject: [PATCH] reimplemented find_markers_by_rank for prenormed case

---
 VERSION            |  2 +-
 requirements.txt   |  3 ++-
 scarf/datastore.py |  7 ++++++-
 scarf/markers.py   | 51 ++++++++++++++++++++++++++++++----------------
 4 files changed, 42 insertions(+), 21 deletions(-)

diff --git a/VERSION b/VERSION
index b72b05e..c0b8d59 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-0.19.3
+0.19.4
diff --git a/requirements.txt b/requirements.txt
index 31d67f3..c5be856 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,7 +13,7 @@ h5py
 numcodecs
 umap-learn
 scikit-learn
-scikit-network
+scikit-network==0.24.0
 scipy
 statsmodels
 seaborn
@@ -29,3 +29,4 @@ requests
 jinja2
 kneed
 ipywidgets
+joblib
diff --git a/scarf/datastore.py b/scarf/datastore.py
index 59340c9..77ba131 100644
--- a/scarf/datastore.py
+++ b/scarf/datastore.py
@@ -3664,8 +3664,9 @@ def run_marker_search(
         cell_key: str = None,
         threshold: float = 0.25,
         gene_batch_size: int = 50,
-        use_prenormed: bool = True,
+        use_prenormed: bool = False,
         prenormed_store: Optional[str] = None,
+        n_threads: int = None,
         **norm_params,
     ) -> None:
         """
@@ -3689,6 +3690,7 @@
                 This can speed up the results. (Default value: True)
             prenormed_store: If prenormalized values were computed in a custom manner then,
                 the Zarr group's location can be provided here. (Default value: None)
+            n_threads: Number of threads to use to run the marker search. Only used if use_prenormed is True.

         Returns:
             None
@@ -3702,6 +3704,8 @@
             )
         if cell_key is None:
             cell_key = "I"
+        if n_threads is None:
+            n_threads = self.nthreads
         assay = self._get_assay(from_assay)
         markers = find_markers_by_rank(
             assay,
@@ -3711,6 +3715,7 @@
             gene_batch_size,
             use_prenormed,
             prenormed_store,
+            n_threads,
             **norm_params,
         )
         z = self.z[assay.name]
diff --git a/scarf/markers.py b/scarf/markers.py
index a20d600..23fd8cb 100644
--- a/scarf/markers.py
+++ b/scarf/markers.py
@@ -8,7 +8,8 @@
 import pandas as pd
 from scipy.stats import linregress
 from typing import Optional
-
+from joblib import Parallel, delayed
+from scipy.stats import rankdata
 
 __all__ = [
     "find_markers_by_rank",
@@ -36,6 +37,7 @@ def find_markers_by_rank(
     batch_size: int,
     use_prenormed: bool,
     prenormed_store: Optional[str],
+    n_threads: int,
     **norm_params,
 ) -> dict:
     """
@@ -49,6 +51,7 @@
         batch_size:
         use_prenormed:
         prenormed_store:
+        n_threads:
 
     Returns:
 
@@ -70,6 +73,14 @@ def mean_rank_wrapper(v):
         """
         return calc_mean_rank(v.values)
 
+    def prenormed_mean_rank_wrapper(gene_idx):
+        mr = calc_mean_rank(rankdata(prenormed_store[gene_idx][:][cell_idx], method='dense'))
+        idx = mr > threshold
+        if np.any(idx):
+            return np.array([ii[idx], np.repeat(gene_idx, idx.sum()), mr[idx]])
+        else:
+            return None
+
     groups = assay.cells.fetch(group_key, cell_key)
     group_set = sorted(set(groups))
     n_groups = len(group_set)
@@ -87,12 +98,17 @@ def mean_rank_wrapper(v):
         use_prenormed = False
 
     if use_prenormed:
-        batch_iterator = read_prenormed_batches(
-            prenormed_store,
-            assay.cells.active_index(cell_key),
-            batch_size,
-            desc="Finding markers"
-        )
+        ii = np.array(list(rev_idx_map.values()))
+        cell_idx = assay.cells.active_index(cell_key)
+        batch_iterator = tqdmbar(prenormed_store.keys(), desc="Finding markers")
+        res = Parallel(n_jobs=n_threads)(delayed(prenormed_mean_rank_wrapper)(i) for i in batch_iterator)
+        res = pd.DataFrame(np.hstack([x for x in res if x is not None])).T
+        res[1] = res[1].astype(int)
+        res[2] = res[2].astype(float)
+        results = {}
+        for i in group_set:
+            results[i] = res[res[0] == str(i)].sort_values(by=2, ascending=False)[[1, 2]].set_index(1)[2]
+        return results
     else:
         batch_iterator = assay.iter_normed_feature_wise(
             cell_key,
@@ -99,19 +115,18 @@
             feat_key,
             batch_size,
             "Finding markers",
             **norm_params
         )
-
-    for val in batch_iterator:
-        res = val.rank(method="dense").astype(int).apply(mean_rank_wrapper)
-        # Removing genes that were below the threshold in all the groups
-        res = res.T[(res < threshold).sum() != n_groups]
-        for j in res:
-            results[rev_idx_map[j]].append(res[j][res[j] > threshold])
-
-    for i in results:
-        results[i] = pd.concat(results[i]).sort_values(ascending=False)
-    return results
+        for val in batch_iterator:
+            res = val.rank(method="dense").astype(int).apply(mean_rank_wrapper)
+            # Removing genes that were below the threshold in all the groups
+            res = res.T[(res < threshold).sum() != n_groups]
+            for j in res:
+                results[rev_idx_map[j]].append(res[j][res[j] > threshold])
+
+        for i in results:
+            results[i] = pd.concat(results[i]).sort_values(ascending=False)
+        return results
 
 
 def find_markers_by_regression(