Skip to content

Commit

Permalink
Merge pull request #57 from claassenlab/56-default-values-for-functio…
Browse files Browse the repository at this point in the history
…n-arguments

Add default values for arguments of all analysis steps. Closes #56.
  • Loading branch information
janschleicher authored Jan 23, 2024
2 parents 4be6538 + 6b820d6 commit 3bfad9d
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 37 deletions.
110 changes: 73 additions & 37 deletions MORESCA/analysis_steps.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import warnings
from pathlib import Path
from typing import Optional, Union
from typing import Optional, Union, List, Tuple

import doubletdetection
import gin
Expand Down Expand Up @@ -85,17 +85,19 @@ def load_data(data_path) -> AnnData:
def quality_control(
adata: AnnData,
apply: bool,
doublet_removal: bool,
outlier_removal: bool,
min_genes: int,
min_cells: int,
n_genes_by_counts: Optional[Union[float, str, bool]],
mt_threshold: Optional[Union[int, float, str, bool]],
rb_threshold: Optional[Union[int, float, str, bool]],
hb_threshold: Optional[Union[int, float, str, bool]],
figures: Optional[Union[Path, str]],
pre_qc_plots: Optional[bool],
post_qc_plots: Optional[bool],
doublet_removal: bool = False,
outlier_removal: bool = False,
min_genes: Optional[Union[float, int, bool]] = None,
min_counts: Optional[Union[float, int, bool]] = None,
max_counts: Optional[Union[float, int, bool]] = None,
min_cells: Optional[Union[float, int, bool]] = None,
n_genes_by_counts: Optional[Union[float, int, str, bool]] = None,
mt_threshold: Optional[Union[int, float, str, bool]] = None,
rb_threshold: Optional[Union[int, float, str, bool]] = None,
hb_threshold: Optional[Union[int, float, str, bool]] = None,
figures: Optional[Union[Path, str]] = None,
pre_qc_plots: Optional[bool] = None,
post_qc_plots: Optional[bool] = None,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand All @@ -104,9 +106,9 @@ def quality_control(
Args:
adata: An AnnData object to perform quality control on.
apply: Whether to apply the quality control steps or not.
doublet_removal: Whether to perform doublet removal or not.
outlier_removal: Whether to remove outliers or not.
min_genes: The minimum number of genes required for a cell to pass quality control.
min_counts: The minimum total counts required for a cell to pass quality control.
max_counts: The maximum total counts allowed for a cell to pass quality control.
min_cells: The minimum number of cells required for a gene to pass quality control.
n_genes_by_counts: The threshold for the number of genes detected per cell.
mt_threshold: The threshold for the percentage of counts in mitochondrial genes.
Expand All @@ -115,6 +117,8 @@ def quality_control(
figures: The path to the output directory for the quality control plots.
pre_qc_plots: Whether to generate plots of QC covariates before quality control or not.
post_qc_plots: Whether to generate plots of QC covariates after quality control or not.
doublet_removal: Whether to perform doublet removal or not.
outlier_removal: Whether to remove outliers or not.
inplace: Whether to perform the quality control steps in-place or return a modified copy of the AnnData object.
Returns:
Expand Down Expand Up @@ -205,8 +209,37 @@ def quality_control(
remove_cells_by_pct_counts(adata=adata, genes="rb", threshold=rb_threshold)
remove_cells_by_pct_counts(adata=adata, genes="hb", threshold=hb_threshold)

sc.pp.filter_cells(adata, min_genes=min_genes)
sc.pp.filter_genes(adata, min_cells=min_cells)
match min_genes:
case min_genes if isinstance(min_genes, float | int):
sc.pp.filter_cells(adata, min_genes=min_genes)
case False | None:
print("No removal based on min_genes.")
case _:
raise ValueError("Invalid value for min_genes.")

match min_counts:
case min_counts if isinstance(min_counts, float | int):
sc.pp.filter_cells(adata, min_counts=min_counts)
case False | None:
print("No removal based on min_counts.")
case _:
raise ValueError("Invalid value for min_counts.")

match max_counts:
case max_counts if isinstance(max_counts, float | int):
sc.pp.filter_cells(adata, max_counts=max_counts)
case False | None:
print("No removal based on max_counts.")
case _:
raise ValueError("Invalid value for max_counts.")

match min_cells:
case min_cells if isinstance(min_cells, float | int):
sc.pp.filter_genes(adata, min_cells=min_cells)
case False | None:
print("No removal based on min_cells.")
case _:
raise ValueError("Invalid value for min_cells.")

if post_qc_plots:
# Make default directory if figures is None or empty string
Expand All @@ -225,10 +258,10 @@ def quality_control(
def normalization(
adata: AnnData,
apply: bool,
method: str,
remove_mt: Optional[bool],
remove_rb: Optional[bool],
remove_hb: Optional[bool],
method: Optional[str] = "log1pPF",
remove_mt: Optional[bool] = False,
remove_rb: Optional[bool] = False,
remove_hb: Optional[bool] = False,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand All @@ -239,7 +272,7 @@ def normalization(
apply: Whether to apply the normalization steps or not.
method: The normalization method to use. Available options are:
- "log1pCP10k": Normalize total counts to 10,000 and apply log1p transformation.
- "log1PF": Normalize counts per cell to median of total counts and apply log1p transformation.
- "log1pPF": Normalize counts per cell to median of total counts and apply log1p transformation.
- "PFlog1pPF": Normalize counts per cell to median of total counts, apply log1p transformation, and normalize again using the median of total counts.
- "analytical_pearson": Normalize using analytical Pearson residuals.
remove_mt: Whether to remove mitochondrial genes or not.
Expand Down Expand Up @@ -273,7 +306,7 @@ def normalization(
case "log1pCP10k":
sc.pp.normalize_total(adata, target_sum=10e4)
sc.pp.log1p(adata)
case "log1PF":
case "log1pPF":
sc.pp.normalize_total(adata, target_sum=None)
sc.pp.log1p(adata)
case "PFlog1pPF":
Expand Down Expand Up @@ -312,8 +345,8 @@ def normalization(
def feature_selection(
adata: AnnData,
apply: bool,
method: str,
number_features=None,
method: Optional[str] = "seurat",
number_features: Optional[int] = None,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand Down Expand Up @@ -408,7 +441,7 @@ def feature_selection(
def scaling(
adata: AnnData,
apply: bool,
max_value: Optional[Union[int, float]],
max_value: Optional[Union[int, float]] = None,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand Down Expand Up @@ -450,7 +483,7 @@ def pca(
adata: AnnData,
apply: bool,
n_comps: int = 50,
use_highly_variable: int = True,
use_highly_variable: bool = True,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand Down Expand Up @@ -491,8 +524,8 @@ def pca(
def batch_effect_correction(
adata: AnnData,
apply: bool,
method: str,
batch_key: str,
method: Optional[str] = "harmony",
batch_key: str = "batch",
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand Down Expand Up @@ -536,6 +569,8 @@ def batch_effect_correction(

match method:
case "harmony":
if "X_pca" not in adata.obsm_keys():
raise KeyError("X_pca not in adata.obsm. Run PCA first.")
sce.pp.harmony_integrate(
adata=adata,
key=batch_key,
Expand All @@ -557,8 +592,8 @@ def batch_effect_correction(
def neighborhood_graph(
adata: AnnData,
apply: bool,
n_neighbors: int,
n_pcs: int,
n_neighbors: int = 15,
n_pcs: Optional[int] = None,
metric: str = "cosine",
inplace: bool = True,
) -> Optional[AnnData]:
Expand Down Expand Up @@ -599,7 +634,8 @@ def neighborhood_graph(
n_pcs=n_pcs,
use_rep="X_pca_corrected"
if "X_pca_corrected" in adata.obsm_keys()
else "X_pca",
else "X_pca" if "X_pca" in adata.obsm_keys()
else None,
metric=metric,
random_state=0,
)
Expand All @@ -612,8 +648,8 @@ def neighborhood_graph(
def clustering(
adata: AnnData,
apply: bool,
method: str,
resolution=None,
method: str = "leiden",
resolution: Union[float, int, List[Union[float, int]], Tuple[Union[float, int]]] = 1.0,
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand Down Expand Up @@ -681,8 +717,8 @@ def clustering(
def diff_gene_exp(
adata: AnnData,
apply: bool,
method: str,
groupby: str,
method: str = "wilcoxon",
groupby: str = "leiden_r1.0",
use_raw: bool = False,
layer: str = "counts",
corr_method: str = "benjamini-hochberg",
Expand Down Expand Up @@ -807,8 +843,8 @@ def umap(
def plotting(
adata: AnnData,
apply: bool,
umap: bool,
path: Path,
umap: bool = True,
path: Path = Path("figures"),
inplace: bool = True,
) -> Optional[AnnData]:
# TODO: Check before merging if we changed adata
Expand Down
2 changes: 2 additions & 0 deletions MORESCA/tests/test-config.gin
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ quality_control:
doublet_removal = False
outlier_removal = True
min_genes = 200
min_counts = 10
max_counts = 6000
min_cells = 10
n_genes_by_counts = 8000
mt_threshold = "auto"
Expand Down

0 comments on commit 3bfad9d

Please sign in to comment.