Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@

## BREAKING

* `differential_expression/create_pseudobulks`: Removed functionality to filter psuedobulk samples based on number of aggregated samples threshold, as this functionality is now covered in `filter/delimit_count` (PR #1044).
* `differential_expression/create_pseudobulks`: Removed functionality to filter pseudobulk samples based on number of aggregated samples threshold, as this functionality is now covered in `filter/delimit_count` (PR #1044).

* `annotate/celltypist`: This component now requires a raw count layer as input, which will be lognormalized with a target sum of 10000, the count format required by CellTypist (PR #1083).

## NEW FUNCTIONALITY

Expand Down
6 changes: 3 additions & 3 deletions src/annotate/celltypist/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ argument_groups:
required: false
- name: "--input_layer"
type: string
description: The layer in the input data containing log normalized counts to be used for cell type annotation if .X is not to be used.
description: The layer in the input data containing raw counts to be used for cell type annotation if .X is not to be used.
- name: "--input_var_gene_names"
type: string
required: false
Expand All @@ -50,7 +50,7 @@ argument_groups:
required: false
- name: "--reference_layer"
type: string
description: The layer in the reference data to be used for cell type annotation if .X is not to be used. Data are expected to be processed in the same way as the --input query dataset.
description: The layer in the reference data containing raw counts to be used for cell type annotation if .X is not to be used.
required: false
- name: "--reference_obs_target"
type: string
Expand Down Expand Up @@ -152,7 +152,7 @@ engines:
packages:
- celltypist==1.6.3
- type: python
__merge__: [ /src/base/requirements/anndata_mudata.yaml, .]
__merge__: [ /src/base/requirements/anndata_mudata.yaml, /src/base/requirements/scanpy.yaml, .]
__merge__: [ /src/base/requirements/python_test_setup.yaml, .]
runners:
- type: executable
Expand Down
92 changes: 51 additions & 41 deletions src/annotate/celltypist/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@
import mudata as mu
import anndata as ad
import pandas as pd
import numpy as np
import scanpy as sc

## VIASH START
par = {
"input": "resources_test/pbmc_1k_protein_v3/pbmc_1k_protein_v3_mms.h5mu",
"output": "output.h5mu",
"modality": "rna",
# "reference": None,
"reference": "resources_test/annotation_test_data/TS_Blood_filtered.h5mu",
"model": None,
# "model": "resources_test/annotation_test_data/celltypist_model_Immune_All_Low.pkl",
"input_layer": "log_normalized",
"reference_layer": "log_normalized",
"input_reference_gene_overlap": 100,
"reference_obs_target": "cell_ontology_class",
"reference_var_input": None,
"check_expression": False,
"feature_selection": True,
"majority_voting": True,
"output_compression": "gzip",
Expand All @@ -44,10 +41,43 @@
logger = setup_logger()


def check_celltypist_format(indata):
    """Report whether ``indata`` looks log1p-normalized to 10000 counts.

    Only the first row (``indata[0]``) is inspected: it is back-transformed
    with ``expm1`` and summed, and the check passes when that total is
    within 1 of 10000.

    Returns
    -------
    ``True`` when the first row matches the expected format, ``False``
    otherwise.
    """
    first_row_total = np.expm1(indata[0]).sum()
    deviation = np.abs(first_row_total - 10000)
    return not deviation > 1
def setup_anndata(
    adata: ad.AnnData,
    layer: str | None = None,
    gene_names: str | None = None,
    var_input: str | None = None,
) -> ad.AnnData:
    """Build a fresh AnnData with lognormalized counts in ``.X`` for CellTypist.

    The counts are normalized to a target sum of 10000 per cell and then
    log1p-transformed.

    Parameters
    ----------
    adata
        AnnData object to take the counts and gene index from.
    layer
        Layer holding the counts to lognormalize; ``.X`` is used when empty.
    gene_names
        Forwarded to ``set_var_index``; callers pass the name of a ``.var``
        column with the gene names (presumably the index is kept when empty).
    var_input
        When given, forwarded to ``subset_vars`` to restrict the genes
        (e.g. a boolean ``.var`` column marking highly variable genes).

    Returns
    -------
    A new AnnData object holding only the lognormalized matrix and the
    gene index.
    """
    prepared = set_var_index(adata, gene_names)

    if var_input:
        prepared = subset_vars(prepared, var_input)

    # Work on a copy so that normalization never touches the source matrix.
    if layer:
        counts = prepared.layers[layer].copy()
    else:
        counts = prepared.X.copy()

    gene_index = pd.DataFrame(index=prepared.var.index)
    celltypist_adata = ad.AnnData(X=counts, var=gene_index)

    sc.pp.normalize_total(celltypist_adata, target_sum=10000)
    sc.pp.log1p(celltypist_adata)

    return celltypist_adata


def main(par):
Expand All @@ -63,17 +93,8 @@ def main(par):
input_modality = input_adata.copy()

# Provide correct format of query data for celltypist annotation
## Sanitize gene names and set as index
input_modality = set_var_index(input_modality, par["input_var_gene_names"])
## Fetch lognormalized counts
lognorm_counts = (
input_modality.layers[par["input_layer"]].copy()
if par["input_layer"]
else input_modality.X.copy()
)
## Create AnnData object
input_modality = ad.AnnData(
X=lognorm_counts, var=pd.DataFrame(index=input_modality.var.index)
input_modality = setup_anndata(
input_modality, par["input_layer"], par["input_var_gene_names"]
)

if par["model"]:
Expand All @@ -86,18 +107,15 @@ def main(par):
)

elif par["reference"]:
reference_modality = mu.read_h5mu(par["reference"]).mod[par["modality"]]

# subset to HVG if required
if par["reference_var_input"]:
reference_modality = subset_vars(
reference_modality, par["reference_var_input"]
)

# Set var names to the desired gene name format (gene symbol, ensembl id, etc.)
# CellTypist requires query gene names to be in index
reference_modality = set_var_index(
reference_modality, par["reference_var_gene_names"]
reference_adata = mu.read_h5mu(par["reference"]).mod[par["modality"]]
reference_modality = reference_adata.copy()

# Provide correct format of reference data for celltypist training
reference_modality = setup_anndata(
reference_modality,
par["reference_layer"],
par["reference_var_gene_names"],
par["reference_var_input"],
)

# Ensure enough overlap between genes in query and reference
Expand All @@ -107,18 +125,10 @@ def main(par):
min_gene_overlap=par["input_reference_gene_overlap"],
)

reference_matrix = (
reference_modality.layers[par["reference_layer"]]
if par["reference_layer"]
else reference_modality.X
)

labels = reference_modality.obs[par["reference_obs_target"]]

logger.info("Training CellTypist model on reference")
model = celltypist.train(
reference_matrix,
labels=labels,
reference_modality.X,
labels=reference_adata.obs[par["reference_obs_target"]],
genes=reference_modality.var.index,
C=par["C"],
max_iter=par["max_iter"],
Expand Down
54 changes: 0 additions & 54 deletions src/annotate/celltypist/test.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import sys
import os
import pytest
import subprocess
import re
import mudata as mu
from openpipeline_testutils.asserters import assert_annotation_objects_equal

Expand All @@ -27,12 +25,8 @@ def test_simple_execution(run_component, random_h5mu_path):
[
"--input",
input_file,
"--input_layer",
"log_normalized",
"--reference",
reference_file,
"--reference_layer",
"log_normalized",
"--reference_obs_target",
"cell_ontology_class",
"--reference_var_gene_names",
Expand Down Expand Up @@ -75,12 +69,8 @@ def test_set_params(run_component, random_h5mu_path):
[
"--input",
input_file,
"--input_layer",
"log_normalized",
"--reference",
reference_file,
"--reference_layer",
"log_normalized",
"--reference_obs_target",
"cell_ontology_class",
"--reference_var_gene_names",
Expand Down Expand Up @@ -159,49 +149,5 @@ def test_with_model(run_component, random_h5mu_path):
)


def test_fail_invalid_input_expression(run_component, random_h5mu_path):
    """Check that the component rejects data that are not lognormalized.

    Two runs are expected to fail with CellTypist's "Invalid expression
    matrix" message: one where neither query nor reference layer is set,
    and one where only the query layer is set.
    """
    output_file = random_h5mu_path()
    expected_error = r"Invalid expression matrix, expect log1p normalized expression to 10000 counts per cell"

    # fails because input data are not lognormalized
    with pytest.raises(subprocess.CalledProcessError) as err:
        run_component(
            [
                "--input",
                input_file,
                "--reference",
                reference_file,
                "--reference_var_gene_names",
                "ensemblid",
                "--output",
                output_file,
            ]
        )
    assert re.search(
        expected_error,
        err.value.stdout.decode("utf-8"),
    )

    # fails because reference data are not lognormalized
    # NOTE: the argument is `--input_layer` (not `--layer`); with the wrong
    # name the query layer is presumably never set, so this case would not
    # reach the reference check it is meant to cover.
    with pytest.raises(subprocess.CalledProcessError) as err:
        run_component(
            [
                "--input",
                input_file,
                "--input_layer",
                "log_normalized",
                "--reference",
                reference_file,
                "--reference_var_gene_names",
                "ensemblid",
                "--output",
                output_file,
            ]
        )
    assert re.search(
        expected_error,
        err.value.stdout.decode("utf-8"),
    )


if __name__ == "__main__":
sys.exit(pytest.main([__file__]))
Loading