diff --git a/src/pyrovelocity/tasks/summarize.py b/src/pyrovelocity/tasks/summarize.py index c11662d08..350f86913 100644 --- a/src/pyrovelocity/tasks/summarize.py +++ b/src/pyrovelocity/tasks/summarize.py @@ -9,17 +9,18 @@ from pyrovelocity.analysis.analyze import pareto_frontier_genes from pyrovelocity.io import CompressedPickle from pyrovelocity.logging import configure_logging -from pyrovelocity.plots import cluster_violin_plots -from pyrovelocity.plots import plot_gene_ranking -from pyrovelocity.plots import plot_gene_selection_summary -from pyrovelocity.plots import plot_parameter_posterior_distributions -from pyrovelocity.plots import plot_shared_time_uncertainty -from pyrovelocity.plots import plot_vector_field_summary -from pyrovelocity.plots import posterior_curve -from pyrovelocity.plots import rainbowplot +from pyrovelocity.plots import ( + cluster_violin_plots, + plot_gene_ranking, + plot_gene_selection_summary, + plot_parameter_posterior_distributions, + plot_shared_time_uncertainty, + plot_vector_field_summary, + posterior_curve, + rainbowplot, +) from pyrovelocity.utils import save_anndata_counts_to_dataframe - __all__ = ["summarize_dataset"] logger = configure_logging(__name__) @@ -43,6 +44,8 @@ def summarize_dataset( Args: data_model (str): string containing the data set and model identifier, e.g. simulated_model1 + data_model_path (str | Path): + path to a model trained on a particualar data set, model_path (str | Path): path to the model, e.g. models/simulated_model1/model pyrovelocity_data_path (str | Path): path to the pyrovelocity data, @@ -54,23 +57,25 @@ def summarize_dataset( vector_field_basis (str): string containing the vector field basis identifier, e.g. umap reports_path (str | Path): path to the reports, e.g. reports + enable_experimental_plots (bool): flag to enable experimental plots Returns: Path: Top-level path to reports outputs for the data model combination, e.g. reports/simulated_model1 Examples: - >>> from pyrovelocity.tasks.summarize import summarize_dataset # xdoctest: +SKIP - >>> tmp = getfixture("tmp_path") # xdoctest: +SKIP + >>> # xdoctest: +SKIP + >>> from pyrovelocity.tasks.summarize import summarize_dataset + >>> tmp = getfixture("tmp_path") >>> summarize_dataset( - ... "simulated_model1", - ... "models/simulated_model1", - ... "models/simulated_model1/model", - ... "models/simulated_model1/pyrovelocity.pkl.zst", - ... "models/simulated_model1/postprocessed.h5ad", - ... "leiden", - ... "umap", - ... ) # xdoctest: +SKIP + ... data_model="simulated_model1", + ... data_model_path="models/simulated_model1", + ... model_path="models/simulated_model1/model", + ... pyrovelocity_data_path="models/simulated_model1/pyrovelocity.pkl.zst", + ... postprocessed_data_path="models/simulated_model1/postprocessed.h5ad", + ... cell_state="leiden", + ... vector_field_basis="umap", + ... ) """ logger.info(f"\n\nPlotting summary figure(s) in: {data_model}\n\n")