From 315c46b03f981ee5f965abb931f78c62f9f07c20 Mon Sep 17 00:00:00 2001 From: Jan Schleicher <72506135+janschleicher@users.noreply.github.com> Date: Tue, 5 Dec 2023 08:34:55 +0100 Subject: [PATCH] 26 implement plotting function (#43) * Add skeleton for plotting function * Add plotting of QC variables; closes #26. * Remove debug print statement. * Add plotting of QC variables; closes #26. * Run black * Make quality control work inplace; closes #49. * Move plotting of QC covariates to quality_control function. * Create parent folders of figure directory if not present. --------- Co-authored-by: Sebastian Bischoff --- MORESCA/analysis_steps.py | 121 +++++++++++++++++++++++++++++----- MORESCA/config.gin | 13 +++- MORESCA/plotting.py | 36 ++++++++++ MORESCA/template.py | 11 ++-- MORESCA/tests/test-config.gin | 15 ++++- MORESCA/utils.py | 46 +++++++------ 6 files changed, 192 insertions(+), 50 deletions(-) create mode 100644 MORESCA/plotting.py diff --git a/MORESCA/analysis_steps.py b/MORESCA/analysis_steps.py index e198518..a9bf46d 100644 --- a/MORESCA/analysis_steps.py +++ b/MORESCA/analysis_steps.py @@ -4,6 +4,7 @@ import doubletdetection import gin +import matplotlib.pyplot as plt import numpy as np import pandas as pd import scanpy as sc @@ -12,6 +13,7 @@ from anndata import AnnData from MORESCA.utils import remove_cells_by_pct_counts, remove_genes, store_config_params +from MORESCA.plotting import plot_qc_vars try: from anticor_features.anticor_features import get_anti_cor_genes @@ -91,6 +93,9 @@ def quality_control( mt_threshold: Optional[Union[int, float, str, bool]], rb_threshold: Optional[Union[int, float, str, bool]], hb_threshold: Optional[Union[int, float, str, bool]], + figures: Optional[Union[Path, str]], + pre_qc_plots: Optional[bool], + post_qc_plots: Optional[bool], inplace: bool = True, ) -> Optional[AnnData]: """ @@ -107,6 +112,9 @@ def quality_control( mt_threshold: The threshold for the percentage of counts in mitochondrial genes. rb_threshold: The threshold for the percentage of counts in ribosomal genes. hb_threshold: The threshold for the percentage of counts in hemoglobin genes. + figures: The path to the output directory for the quality control plots. + pre_qc_plots: Whether to generate plots of QC covariates before quality control or not. + post_qc_plots: Whether to generate plots of QC covariates after quality control or not. inplace: Whether to perform the quality control steps in-place or return a modified copy of the AnnData object. Returns: @@ -135,6 +143,28 @@ def quality_control( if not apply: return None + # Quality control - calculate QC covariates + adata.obs["n_counts"] = adata.X.sum(1) + adata.obs["log_counts"] = np.log(adata.obs["n_counts"]) + adata.obs["n_genes"] = (adata.X > 0).sum(1) + + adata.var["mt"] = adata.var_names.str.contains("(?i)^MT-") + adata.var["rb"] = adata.var_names.str.contains("(?i)^RP[SL]") + adata.var["hb"] = adata.var_names.str.contains("(?i)^HB(?!EGF|S1L|P1).+") + + sc.pp.calculate_qc_metrics( + adata, qc_vars=["mt", "rb", "hb"], percent_top=[20], log1p=True, inplace=True + ) + + if pre_qc_plots: + # Make default directory if figures is None or empty string + if not figures: + figures = "figures/" + if isinstance(figures, str): + figures = Path(figures) + figures.mkdir(parents=True, exist_ok=True) + plot_qc_vars(adata, pre_qc=True, out_dir=figures) + if doublet_removal: clf = doubletdetection.BoostClassifier( n_iters=10, @@ -150,20 +180,7 @@ def quality_control( adata.obs["doublet"] = adata.obs["doublet"].astype(bool) adata.obs["doublet_score"] = clf.doublet_score() - adata = adata[(~adata.obs.doublet)] - - # Quality control - calculate QC covariates - adata.obs["n_counts"] = adata.X.sum(1) - adata.obs["log_counts"] = np.log(adata.obs["n_counts"]) - adata.obs["n_genes"] = (adata.X > 0).sum(1) - - adata.var["mt"] = adata.var_names.str.contains("(?i)^MT-") - adata.var["rb"] = adata.var_names.str.contains("(?i)^RP[SL]") - adata.var["hb"] = adata.var_names.str.contains("(?i)^HB(?!EGF|S1L|P1).+") - - sc.pp.calculate_qc_metrics( - adata, qc_vars=["mt", "rb", "hb"], percent_top=[20], log1p=True, inplace=True - ) + adata._inplace_subset_obs(~adata.obs.doublet) if outlier_removal: adata.obs["outlier"] = ( @@ -172,11 +189,11 @@ def quality_control( | is_outlier(adata, "pct_counts_in_top_20_genes", 5) ) - adata = adata[(~adata.obs.outlier)] + adata._inplace_subset_obs(~adata.obs.outlier) match n_genes_by_counts: case n_genes_by_counts if isinstance(n_genes_by_counts, float | int): - adata = adata[adata.obs.n_genes_by_counts < n_genes_by_counts, :] + adata._inplace_subset_obs(adata.obs.n_genes_by_counts < n_genes_by_counts) case "auto": pass case False | None: @@ -191,6 +208,15 @@ def quality_control( sc.pp.filter_cells(adata, min_genes=min_genes) sc.pp.filter_genes(adata, min_cells=min_cells) + if post_qc_plots: + # Make default directory if figures is None or empty string + if not figures: + figures = "figures/" + if isinstance(figures, str): + figures = Path(figures) + figures.mkdir(parents=True, exist_ok=True) + plot_qc_vars(adata, pre_qc=False, out_dir=figures) + if not inplace: return adata @@ -748,3 +774,66 @@ def diff_gene_exp( if not inplace: return adata + + +@gin.configurable +def umap( + adata: AnnData, + apply: bool, + inplace: bool = True, +) -> Optional[AnnData]: + if not inplace: + adata = adata.copy() + + store_config_params( + adata=adata, + analysis_step=plotting.__name__, + apply=apply, + params={ + key: val for key, val in locals().items() if key not in ["adata", "inplace"] + }, + ) + + if not apply: + return None + + sc.tl.umap(adata=adata) + + if not inplace: + return adata + + +@gin.configurable +def plotting( + adata: AnnData, + apply: bool, + umap: bool, + path: Path, + inplace: bool = True, +) -> Optional[AnnData]: + # TODO: Check before merging if we changed adata + if not inplace: + adata = adata.copy() + + store_config_params( + adata=adata, + analysis_step=plotting.__name__, + apply=apply, + params={ + key: val for key, val in locals().items() if key not in ["adata", "inplace"] + }, + ) + + if not apply: + return None + + path = Path(path) + path.mkdir(parents=True, exist_ok=True) + + if umap: + sc.pl.umap(adata=adata, show=False) + plt.savefig(Path(path, "umap.png")) + plt.close() + + if not inplace: + return adata diff --git a/MORESCA/config.gin b/MORESCA/config.gin index 5bd8014..7130138 100644 --- a/MORESCA/config.gin +++ b/MORESCA/config.gin @@ -5,9 +5,12 @@ quality_control: min_genes = 200 min_cells = 10 n_genes_by_counts = None - mt_threshold = 50 + mt_threshold = 10 rb_threshold = 10 hb_threshold = 2 + figures = "figures/" + pre_qc_plots = True + post_qc_plots = True normalization: apply = True @@ -52,3 +55,11 @@ diff_gene_exp: groupby = "leiden_r1.0" use_raw = True tables = False + +umap: + apply = True + +plotting: + apply = True + umap = True + path = "figures/" \ No newline at end of file diff --git a/MORESCA/plotting.py b/MORESCA/plotting.py new file mode 100644 index 0000000..9ac2a5c --- /dev/null +++ b/MORESCA/plotting.py @@ -0,0 +1,36 @@ +from pathlib import Path +import scanpy as sc +from anndata import AnnData +from matplotlib import pyplot as plt +import seaborn as sns + + +def plot_qc_vars(adata: AnnData, pre_qc: bool, out_dir: Path) -> None: + # Plot cell level QC metrics + qc_vars_cells = [ + "n_genes_by_counts", + "total_counts", + "pct_counts_mt", + "pct_counts_rb", + "pct_counts_hb", + ] + fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9, 6)) + sc.pl.scatter( + adata, x="total_counts", y="n_genes_by_counts", ax=axs.flat[0], show=False + ) + for qc_var, ax in zip(qc_vars_cells, axs.flat[1:]): + sns.violinplot(adata.obs[qc_var], ax=ax, cut=0) + sns.stripplot(adata.obs[qc_var], jitter=0.4, s=1, color="black", ax=ax) + fig.tight_layout() + fig.savefig(Path(out_dir, f"qc_vars_cells_{'pre' if pre_qc else 'post'}_qc.png")) + plt.close() + + # Plot gene level QC metrics + qc_vars_genes = ["n_cells_by_counts", "pct_dropout_by_counts"] + fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 3)) + for qc_var, ax in zip(qc_vars_genes, axs.flat): + sns.violinplot(adata.var[qc_var], ax=ax, cut=0) + sns.stripplot(adata.var[qc_var], jitter=0.4, s=1, color="black", ax=ax) + fig.tight_layout() + fig.savefig(Path(out_dir, f"qc_vars_genes_{'pre' if pre_qc else 'post'}_qc.png")) + plt.close() diff --git a/MORESCA/template.py b/MORESCA/template.py index 7265450..90d412d 100644 --- a/MORESCA/template.py +++ b/MORESCA/template.py @@ -13,6 +13,8 @@ pca, quality_control, scaling, + umap, + plotting, ) @@ -23,13 +25,6 @@ def run_analysis( verbose: bool, result_path: Path = Path("results"), ) -> None: - FIGURE_PATH = Path("figures") - FIGURE_PATH_PRE = Path(FIGURE_PATH, "preQC") - FIGURE_PATH_POST = Path(FIGURE_PATH, "postQC") - - FIGURE_PATH.mkdir(exist_ok=True) - FIGURE_PATH_PRE.mkdir(exist_ok=True) - FIGURE_PATH_POST.mkdir(exist_ok=True) result_path.mkdir(exist_ok=True) gin.parse_config_file(config_path) @@ -47,6 +42,8 @@ def run_analysis( neighborhood_graph(adata=adata) clustering(adata=adata) diff_gene_exp(adata=adata) + umap(adata=adata) + plotting(adata=adata) adata.write(Path(result_path, "data_processed.h5ad")) diff --git a/MORESCA/tests/test-config.gin b/MORESCA/tests/test-config.gin index 4367a44..e5184ce 100644 --- a/MORESCA/tests/test-config.gin +++ b/MORESCA/tests/test-config.gin @@ -5,9 +5,12 @@ quality_control: min_genes = 200 min_cells = 10 n_genes_by_counts = 8000 - mt_threshold = 15 + mt_threshold = "auto" rb_threshold = 10 hb_threshold = 2 + figures = "figures/" + pre_qc_plots = True + post_qc_plots = True normalization: apply = True @@ -52,4 +55,12 @@ diff_gene_exp: use_raw = False layer = "counts" corr_method = "benjamini-hochberg" - tables = False \ No newline at end of file + tables = False + +umap: + apply = True + +plotting: + apply = True + umap = False + path = "figures/" diff --git a/MORESCA/utils.py b/MORESCA/utils.py index c8cd52a..7d2395f 100644 --- a/MORESCA/utils.py +++ b/MORESCA/utils.py @@ -49,7 +49,6 @@ def remove_cells_by_pct_counts( genes: str, threshold: Optional[Union[int, float, str, bool]], inplace: bool = True, - save: bool = False, ) -> Optional[AnnData]: """ Remove cells from an AnnData object based on the percentage of counts in specific gene categories. @@ -84,12 +83,12 @@ def remove_cells_by_pct_counts( threshold, bool ): if genes == "rb": - adata = adata[adata.obs[f"pct_counts_{genes}"] > threshold, :] + adata._inplace_subset_obs(adata.obs[f"pct_counts_{genes}"] > threshold) else: - adata = adata[adata.obs[f"pct_counts_{genes}"] < threshold, :] + adata._inplace_subset_obs(adata.obs[f"pct_counts_{genes}"] < threshold) case "auto": if genes == "mt": - adata = ddqc(adata, inplace=False) + ddqc(adata) else: raise ValueError(f"Auto selection for {genes}_threshold not implemented.") case False | None: @@ -97,8 +96,6 @@ def remove_cells_by_pct_counts( case _: raise ValueError("Error.") - if save and isinstance(save, Path | str): - adata.write(save) if not inplace: return adata @@ -151,32 +148,32 @@ def ddqc(adata: AnnData, inplace: bool = True) -> Optional[AnnData]: if not inplace: adata = adata.copy() - adata_raw = adata.copy() + adata_copy = adata.copy() sc.pp.calculate_qc_metrics( - adata, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True + adata_copy, qc_vars=["mt"], percent_top=None, log1p=False, inplace=True ) - adata = adata[adata.obs.pct_counts_mt <= 80, :] + adata_copy._inplace_subset_obs(adata_copy.obs.pct_counts_mt <= 80) # Todo: can this be removed? - adata.layers["counts"] = adata.X.copy() - sc.pp.normalize_total(adata, target_sum=1e4) - sc.pp.log1p(adata) + adata_copy.layers["counts"] = adata_copy.X.copy() + sc.pp.normalize_total(adata_copy, target_sum=1e4) + sc.pp.log1p(adata_copy) sc.pp.highly_variable_genes( - adata, flavor="seurat_v3", n_top_genes=2000, layer="counts" + adata_copy, flavor="seurat_v3", n_top_genes=2000, layer="counts" ) - sc.pp.scale(adata) - sc.tl.pca(adata) - sc.pp.neighbors(adata, n_neighbors=20, n_pcs=50, metric="euclidean") - sc.tl.leiden(adata, resolution=1.4) + sc.pp.scale(adata_copy) + sc.tl.pca(adata_copy) + sc.pp.neighbors(adata_copy, n_neighbors=20, n_pcs=50, metric="euclidean") + sc.tl.leiden(adata_copy, resolution=1.4) # Directly apply the quality control checks and create the 'passed' mask - passed = np.ones(adata.n_obs, dtype=bool) - for cluster in adata.obs["leiden"].unique(): - indices = adata.obs["leiden"] == cluster - pct_counts_mt_cluster = adata.obs.loc[indices, "pct_counts_mt"].values - total_counts_cluster = adata.obs.loc[indices, "total_counts"].values - n_genes_cluster = adata.obs.loc[indices, "n_genes"].values + passed = np.ones(adata_copy.n_obs, dtype=bool) + for cluster in adata_copy.obs["leiden"].unique(): + indices = adata_copy.obs["leiden"] == cluster + pct_counts_mt_cluster = adata_copy.obs.loc[indices, "pct_counts_mt"].values + total_counts_cluster = adata_copy.obs.loc[indices, "total_counts"].values + n_genes_cluster = adata_copy.obs.loc[indices, "n_genes"].values passing_mask_mt = is_passing_upper(pct_counts_mt_cluster, nmads=3) passing_mask_counts = is_passing_lower( @@ -186,7 +183,8 @@ def ddqc(adata: AnnData, inplace: bool = True) -> Optional[AnnData]: passed[indices] = passing_mask_mt & passing_mask_counts & passing_mask_genes - adata = adata_raw[passed].copy() + passed = adata_copy[passed].obs_names + adata._inplace_subset_obs(passed.values) if not inplace: return adata