Skip to content

Commit

Permalink
26 implement plotting function (#43)
Browse files Browse the repository at this point in the history
* Add skeleton for plotting function

* Add plotting of QC variables; closes #26.

* Remove debug print statement.

* Add plotting of QC variables; closes #26.

* Run black

* Make quality control work inplace; closes #49.

* Move plotting of QC covariates to quality_control function.

* Create parent folders of figure directory if not present.

---------

Co-authored-by: Sebastian Bischoff <sebastian@salzreute.de>
  • Loading branch information
janschleicher and Baschdl authored Dec 5, 2023
1 parent 55b5c98 commit 315c46b
Show file tree
Hide file tree
Showing 6 changed files with 192 additions and 50 deletions.
121 changes: 105 additions & 16 deletions MORESCA/analysis_steps.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import doubletdetection
import gin
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import scanpy as sc
Expand All @@ -12,6 +13,7 @@
from anndata import AnnData

from MORESCA.utils import remove_cells_by_pct_counts, remove_genes, store_config_params
from MORESCA.plotting import plot_qc_vars

try:
from anticor_features.anticor_features import get_anti_cor_genes
Expand Down Expand Up @@ -91,6 +93,9 @@ def quality_control(
mt_threshold: Optional[Union[int, float, str, bool]],
rb_threshold: Optional[Union[int, float, str, bool]],
hb_threshold: Optional[Union[int, float, str, bool]],
figures: Optional[Union[Path, str]],
pre_qc_plots: Optional[bool],
post_qc_plots: Optional[bool],
inplace: bool = True,
) -> Optional[AnnData]:
"""
Expand All @@ -107,6 +112,9 @@ def quality_control(
mt_threshold: The threshold for the percentage of counts in mitochondrial genes.
rb_threshold: The threshold for the percentage of counts in ribosomal genes.
hb_threshold: The threshold for the percentage of counts in hemoglobin genes.
figures: The path to the output directory for the quality control plots.
pre_qc_plots: Whether to generate plots of QC covariates before quality control or not.
post_qc_plots: Whether to generate plots of QC covariates after quality control or not.
inplace: Whether to perform the quality control steps in-place or return a modified copy of the AnnData object.
Returns:
Expand Down Expand Up @@ -135,6 +143,28 @@ def quality_control(
if not apply:
return None

# Quality control - calculate QC covariates
adata.obs["n_counts"] = adata.X.sum(1)
adata.obs["log_counts"] = np.log(adata.obs["n_counts"])
adata.obs["n_genes"] = (adata.X > 0).sum(1)

adata.var["mt"] = adata.var_names.str.contains("(?i)^MT-")
adata.var["rb"] = adata.var_names.str.contains("(?i)^RP[SL]")
adata.var["hb"] = adata.var_names.str.contains("(?i)^HB(?!EGF|S1L|P1).+")

sc.pp.calculate_qc_metrics(
adata, qc_vars=["mt", "rb", "hb"], percent_top=[20], log1p=True, inplace=True
)

if pre_qc_plots:
# Make default directory if figures is None or empty string
if not figures:
figures = "figures/"
if isinstance(figures, str):
figures = Path(figures)
figures.mkdir(parents=True, exist_ok=True)
plot_qc_vars(adata, pre_qc=True, out_dir=figures)

if doublet_removal:
clf = doubletdetection.BoostClassifier(
n_iters=10,
Expand All @@ -150,20 +180,7 @@ def quality_control(
adata.obs["doublet"] = adata.obs["doublet"].astype(bool)
adata.obs["doublet_score"] = clf.doublet_score()

adata = adata[(~adata.obs.doublet)]

# Quality control - calculate QC covariates
adata.obs["n_counts"] = adata.X.sum(1)
adata.obs["log_counts"] = np.log(adata.obs["n_counts"])
adata.obs["n_genes"] = (adata.X > 0).sum(1)

adata.var["mt"] = adata.var_names.str.contains("(?i)^MT-")
adata.var["rb"] = adata.var_names.str.contains("(?i)^RP[SL]")
adata.var["hb"] = adata.var_names.str.contains("(?i)^HB(?!EGF|S1L|P1).+")

sc.pp.calculate_qc_metrics(
adata, qc_vars=["mt", "rb", "hb"], percent_top=[20], log1p=True, inplace=True
)
adata._inplace_subset_obs(~adata.obs.doublet)

if outlier_removal:
adata.obs["outlier"] = (
Expand All @@ -172,11 +189,11 @@ def quality_control(
| is_outlier(adata, "pct_counts_in_top_20_genes", 5)
)

adata = adata[(~adata.obs.outlier)]
adata._inplace_subset_obs(~adata.obs.outlier)

match n_genes_by_counts:
case n_genes_by_counts if isinstance(n_genes_by_counts, float | int):
adata = adata[adata.obs.n_genes_by_counts < n_genes_by_counts, :]
adata._inplace_subset_obs(adata.obs.n_genes_by_counts < n_genes_by_counts)
case "auto":
pass
case False | None:
Expand All @@ -191,6 +208,15 @@ def quality_control(
sc.pp.filter_cells(adata, min_genes=min_genes)
sc.pp.filter_genes(adata, min_cells=min_cells)

if post_qc_plots:
# Make default directory if figures is None or empty string
if not figures:
figures = "figures/"
if isinstance(figures, str):
figures = Path(figures)
figures.mkdir(parents=True, exist_ok=True)
plot_qc_vars(adata, pre_qc=False, out_dir=figures)

if not inplace:
return adata

Expand Down Expand Up @@ -748,3 +774,66 @@ def diff_gene_exp(

if not inplace:
return adata


@gin.configurable
def umap(
    adata: AnnData,
    apply: bool,
    inplace: bool = True,
) -> Optional[AnnData]:
    """
    Run scanpy's UMAP step (`sc.tl.umap`) on an AnnData object.

    Args:
        adata: The AnnData object to run the UMAP step on.
        apply: Whether to run the UMAP step or not.
        inplace: Whether to work in-place or on a copy of the AnnData object.

    Returns:
        The modified copy of the AnnData object if `apply` is True and
        `inplace` is False, otherwise None.
    """
    if not inplace:
        adata = adata.copy()

    # Record the configuration under this step's own name. It was previously
    # stored under `plotting.__name__`, which filed the UMAP parameters under
    # the wrong analysis step.
    store_config_params(
        adata=adata,
        analysis_step=umap.__name__,
        apply=apply,
        params={
            key: val for key, val in locals().items() if key not in ["adata", "inplace"]
        },
    )

    if not apply:
        return None

    sc.tl.umap(adata=adata)

    if not inplace:
        return adata


@gin.configurable
def plotting(
    adata: AnnData,
    apply: bool,
    umap: bool,
    path: Path,
    inplace: bool = True,
) -> Optional[AnnData]:
    """
    Save the configured plots for an AnnData object to disk.

    Args:
        adata: The AnnData object to plot.
        apply: Whether to generate plots or not.
        umap: Whether to save a UMAP plot as `umap.png` or not.
        path: The output directory for the plots; it is created (including
            parent folders) if it does not exist.
        inplace: Whether to work in-place or on a copy of the AnnData object.

    Returns:
        The (copied) AnnData object if `apply` is True and `inplace` is False,
        otherwise None.
    """
    # TODO: Check before merging if we changed adata
    if not inplace:
        adata = adata.copy()

    # locals() must be evaluated before any further local is bound so that
    # only the call arguments end up in the stored parameters.
    store_config_params(
        adata=adata,
        analysis_step=plotting.__name__,
        apply=apply,
        params={
            name: value
            for name, value in locals().items()
            if name not in ("adata", "inplace")
        },
    )

    if not apply:
        return None

    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    if umap:
        sc.pl.umap(adata=adata, show=False)
        plt.savefig(out_dir / "umap.png")
        plt.close()

    return None if inplace else adata
13 changes: 12 additions & 1 deletion MORESCA/config.gin
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ quality_control:
min_genes = 200
min_cells = 10
n_genes_by_counts = None
mt_threshold = 50
mt_threshold = 10
rb_threshold = 10
hb_threshold = 2
figures = "figures/"
pre_qc_plots = True
post_qc_plots = True

normalization:
apply = True
Expand Down Expand Up @@ -52,3 +55,11 @@ diff_gene_exp:
groupby = "leiden_r1.0"
use_raw = True
tables = False

umap:
apply = True

plotting:
apply = True
umap = True
path = "figures/"
36 changes: 36 additions & 0 deletions MORESCA/plotting.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
from pathlib import Path
import scanpy as sc
from anndata import AnnData
from matplotlib import pyplot as plt
import seaborn as sns


def plot_qc_vars(adata: AnnData, pre_qc: bool, out_dir: Path) -> None:
    """
    Plot cell-level and gene-level QC covariates and save them as PNG files.

    Two figures are written to `out_dir`: `qc_vars_cells_<stage>_qc.png` and
    `qc_vars_genes_<stage>_qc.png`, where `<stage>` is `pre` or `post`.

    Args:
        adata: The AnnData object; `adata.obs` must contain the columns
            `n_genes_by_counts`, `total_counts`, `pct_counts_mt`,
            `pct_counts_rb` and `pct_counts_hb`, and `adata.var` must contain
            `n_cells_by_counts` and `pct_dropout_by_counts`.
        pre_qc: Whether the plots show the data before (True) or after (False)
            quality control; only affects the file names.
        out_dir: The directory the figures are saved to. It is not created
            here.
    """
    stage = "pre" if pre_qc else "post"

    # Plot cell level QC metrics
    qc_vars_cells = [
        "n_genes_by_counts",
        "total_counts",
        "pct_counts_mt",
        "pct_counts_rb",
        "pct_counts_hb",
    ]
    fig, axs = plt.subplots(nrows=2, ncols=3, figsize=(9, 6))
    # The first panel is a scatter plot; the remaining five panels hold one
    # violin/strip plot per cell-level covariate.
    sc.pl.scatter(
        adata, x="total_counts", y="n_genes_by_counts", ax=axs.flat[0], show=False
    )
    for qc_var, ax in zip(qc_vars_cells, axs.flat[1:]):
        sns.violinplot(adata.obs[qc_var], ax=ax, cut=0)
        sns.stripplot(adata.obs[qc_var], jitter=0.4, s=1, color="black", ax=ax)
    fig.tight_layout()
    fig.savefig(Path(out_dir, f"qc_vars_cells_{stage}_qc.png"))
    # Close this figure by handle instead of relying on pyplot's notion of the
    # "current" figure, so the figure created above is always released.
    plt.close(fig)

    # Plot gene level QC metrics
    qc_vars_genes = ["n_cells_by_counts", "pct_dropout_by_counts"]
    fig, axs = plt.subplots(nrows=1, ncols=2, figsize=(6, 3))
    for qc_var, ax in zip(qc_vars_genes, axs.flat):
        sns.violinplot(adata.var[qc_var], ax=ax, cut=0)
        sns.stripplot(adata.var[qc_var], jitter=0.4, s=1, color="black", ax=ax)
    fig.tight_layout()
    fig.savefig(Path(out_dir, f"qc_vars_genes_{stage}_qc.png"))
    plt.close(fig)
11 changes: 4 additions & 7 deletions MORESCA/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
pca,
quality_control,
scaling,
umap,
plotting,
)


Expand All @@ -23,13 +25,6 @@ def run_analysis(
verbose: bool,
result_path: Path = Path("results"),
) -> None:
FIGURE_PATH = Path("figures")
FIGURE_PATH_PRE = Path(FIGURE_PATH, "preQC")
FIGURE_PATH_POST = Path(FIGURE_PATH, "postQC")

FIGURE_PATH.mkdir(exist_ok=True)
FIGURE_PATH_PRE.mkdir(exist_ok=True)
FIGURE_PATH_POST.mkdir(exist_ok=True)
result_path.mkdir(exist_ok=True)

gin.parse_config_file(config_path)
Expand All @@ -47,6 +42,8 @@ def run_analysis(
neighborhood_graph(adata=adata)
clustering(adata=adata)
diff_gene_exp(adata=adata)
umap(adata=adata)
plotting(adata=adata)

adata.write(Path(result_path, "data_processed.h5ad"))

Expand Down
15 changes: 13 additions & 2 deletions MORESCA/tests/test-config.gin
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ quality_control:
min_genes = 200
min_cells = 10
n_genes_by_counts = 8000
mt_threshold = 15
mt_threshold = "auto"
rb_threshold = 10
hb_threshold = 2
figures = "figures/"
pre_qc_plots = True
post_qc_plots = True

normalization:
apply = True
Expand Down Expand Up @@ -52,4 +55,12 @@ diff_gene_exp:
use_raw = False
layer = "counts"
corr_method = "benjamini-hochberg"
tables = False
tables = False

umap:
apply = True

plotting:
apply = True
umap = False
path = "figures/"
Loading

0 comments on commit 315c46b

Please sign in to comment.