Skip to content

Commit

Permalink
Use target cell_key when mapping (#122)
Browse files Browse the repository at this point in the history
* Use target cell_key when mapping

* Fix test
  • Loading branch information
oskbor authored Sep 30, 2024
1 parent dc08e06 commit 8a57bfb
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 8 deletions.
18 changes: 15 additions & 3 deletions scarf/datastore/mapping_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def run_mapping(
target_assay: Assay,
target_name: str,
target_feat_key: str,
target_cell_key: str = "I",
from_assay: Optional[str] = None,
cell_key: str = "I",
feat_key: Optional[str] = None,
Expand All @@ -57,6 +58,7 @@ def run_mapping(
target_name: Name of target data. This is used to keep track of projections in the Zarr hierarchy
target_feat_key: This will be used to name the location where the normalized target data will be saved in its own
Zarr hierarchy.
target_cell_key: Cell key for the target data. (Default value: 'I')
from_assay: Name of assay to be used. If no value is provided then the default assay will be used.
cell_key: Cell key. Should be the same as the one that was used in the desired graph. (Default value: 'I')
feat_key: Feature key. Should be same as the one that was used in the desired graph. By default, the latest
Expand Down Expand Up @@ -119,6 +121,7 @@ def run_mapping(
cell_key,
feat_key,
target_feat_key,
target_cell_key,
filter_null,
exclude_missing,
self.nthreads,
Expand Down Expand Up @@ -151,15 +154,23 @@ def run_mapping(
logger.warning(f"`save_k` was decreased to {ann_obj.k}")
save_k = ann_obj.k
target_data = daskarr.from_zarr(
target_assay.z[f"normed__I__{target_feat_key}/data"], inline_array=True
target_assay.z[f"normed__{target_cell_key}__{target_feat_key}/data"],
inline_array=True,
)
if run_coral is True:
# Reversing coral here to correct target data
coral(
target_data, ann_obj.data, target_assay, target_feat_key, self.nthreads
target_data,
ann_obj.data,
target_assay,
target_feat_key,
target_cell_key,
self.nthreads,
)
target_data = daskarr.from_zarr(
target_assay.z[f"normed__I__{target_feat_key}/data_coral"],
target_assay.z[
f"normed__{target_cell_key}__{target_feat_key}/data_coral"
],
inline_array=True,
)
if ann_obj.method == "pca" and run_coral is False:
Expand Down Expand Up @@ -328,6 +339,7 @@ def get_target_classes(
store = self.zw[store_loc]
indices = store["indices"][:]
dists = store["distances"][:]

preds = []
weights = 1 - (dists / dists.max(axis=1).reshape(-1, 1))
for n in range(indices.shape[0]):
Expand Down
11 changes: 6 additions & 5 deletions scarf/mapping_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,14 +38,15 @@ def _correlation_alignment(s: daskarr, t: daskarr, nthreads: int) -> daskarr:
return daskarr.dot(s, a_coral)


def coral(source_data, target_data, assay, feat_key: str, nthreads: int):
def coral(source_data, target_data, assay, feat_key: str, cell_key: str, nthreads: int):
"""Applies CORAL error correction to the input data.
Args:
source_data ():
target_data ():
assay ():
feat_key ():
cell_key ():
nthreads ():
"""
from .writers import dask_to_zarr
Expand Down Expand Up @@ -87,7 +88,7 @@ def coral(source_data, target_data, assay, feat_key: str, nthreads: int):
dask_to_zarr(
data,
assay.z["/"],
f"{assay.z.name}/normed__I__{feat_key}/data_coral",
f"{assay.z.name}/normed__{cell_key}__{feat_key}/data_coral",
1000,
nthreads,
msg="Writing out coral corrected data",
Expand Down Expand Up @@ -149,6 +150,7 @@ def align_features(
source_cell_key: str,
source_feat_key: str,
target_feat_key: str,
target_cell_key: str,
filter_null: bool,
exclude_missing: bool,
nthreads: int,
Expand Down Expand Up @@ -185,11 +187,10 @@ def align_features(
norm_params = source_assay.z[normed_loc].attrs["subset_params"]
sorted_t_idx = np.array(sorted(t_idx[t_idx != -1]))

# TODO: add target cell key
normed_data = target_assay.normed(
target_assay.cells.active_index("I"), sorted_t_idx, **norm_params
target_assay.cells.active_index(target_cell_key), sorted_t_idx, **norm_params
)
loc = f"{target_assay.z.name}/normed__I__{target_feat_key}/data"
loc = f"{target_assay.z.name}/normed__{target_cell_key}__{target_feat_key}/data"

og = create_zarr_dataset(
target_assay.z["/"], loc, (1000,), "float64", (normed_data.shape[0], len(t_idx))
Expand Down

0 comments on commit 8a57bfb

Please sign in to comment.