Skip to content

Commit

Permalink
feat(nbs): add support for model comparison to pancreas notebook
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <cameron.ray.smith@gmail.com>
  • Loading branch information
cameronraysmith committed Jul 31, 2024
1 parent b98ca30 commit 7cb5f77
Showing 1 changed file with 109 additions and 45 deletions.
154 changes: 109 additions & 45 deletions nbs/tutorials/pancreas/pancreas.qmd
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ from IPython.display import Image, display
Please see the [guide on interactive results review](/guides/interactive/interactive.qmd) for a
general review of how to download and review results. Here we use the same
approach without any of the explanatory text to retrieve model results to
generate plots for model selection and evaluation
generate plots for model selection and evaluation.

## Get results

Expand All @@ -40,88 +40,152 @@ to represent the results of applying the model to real data.
:::


### Setup remote connection
### Identify results of interest

```{python}
#| label: instantiate-remote-client
from flytekit.remote.remote import FlyteRemote
from flytekit.configuration import Config
# | label: get-workflow-io
from pyrovelocity.io.cluster import get_remote_task_results
(
model1_postprocessing_inputs,
model1_postprocessing_outputs,
) = get_remote_task_results(
execution_id="pyrovelocity-py311-defaul-fe20d09-dev-pzh-17b5092261c84e8695a",
task_id="f26c1pjy-0-dn1-0-dn3",
)
remote = FlyteRemote(
Config.for_endpoint("flyte.cluster.pyrovelocity.net"),
(
model2_postprocessing_inputs,
model2_postprocessing_outputs,
) = get_remote_task_results(
execution_id="pyrovelocity-py311-defaul-fe20d09-dev-pzh-17b5092261c84e8695a",
task_id="f26c1pjy-0-dn1-0-dn6",
)
```

### Identify results of interest
### Download results

```{python}
# | label: get-workflow-io
workflow_inputs = remote.get(
"flyte://v1/pyrovelocity/development/pyrovelocity-py311-defaul-fe20d09-dev-pzh-17b5092261c84e8695a/f26c1pjy-0-dn1-0-dn6/i"
# | label: download-outputs
from pyrovelocity.io.gcs import download_blob_from_uri
model1_pyrovelocity_data = download_blob_from_uri(
blob_uri=model1_postprocessing_outputs.o0.pyrovelocity_data.path,
download_filename_prefix=f"{model1_postprocessing_inputs.training_outputs.data_model}",
)
postprocessing_outputs = remote.get(
"flyte://v1/pyrovelocity/development/pyrovelocity-py311-defaul-fe20d09-dev-pzh-17b5092261c84e8695a/f26c1pjy-0-dn1-0-dn6/o"
model1_postprocessed_data = download_blob_from_uri(
blob_uri=model1_postprocessing_outputs.o0.postprocessed_data.path,
download_filename_prefix=f"{model1_postprocessing_inputs.training_outputs.data_model}",
)
model2_pyrovelocity_data = download_blob_from_uri(
blob_uri=model2_postprocessing_outputs.o0.pyrovelocity_data.path,
download_filename_prefix=f"{model2_postprocessing_inputs.training_outputs.data_model}",
)
model2_postprocessed_data = download_blob_from_uri(
blob_uri=model2_postprocessing_outputs.o0.postprocessed_data.path,
download_filename_prefix=f"{model2_postprocessing_inputs.training_outputs.data_model}",
)
```


## Analyze results

### Model 1

#### Load data

```{python}
# | label: model1-load-postprocessed-data
# | output: true
import scanpy as sc
from pyrovelocity.utils import print_anndata
adata = sc.read(model1_postprocessed_data)
print_anndata(adata)
```

```{python}
# | label: create-outputs-dict
from omegaconf import OmegaConf
from flytekit.interaction.string_literals import literal_map_string_repr
from pyrovelocity.utils import print_config_tree
inputs_dict = literal_map_string_repr(workflow_inputs.literals)
inputs_dictconfig = OmegaConf.create(inputs_dict)
print_config_tree(inputs_dict)
outputs_dict = literal_map_string_repr(postprocessing_outputs.literals)
outputs_dictconfig = OmegaConf.create(outputs_dict)
print_config_tree(outputs_dict)
# | label: model1-load-posterior-samples
# | output: true
from pyrovelocity.utils import pretty_print_dict
from pyrovelocity.io import CompressedPickle
posterior_samples = CompressedPickle.load(model1_pyrovelocity_data)
pretty_print_dict(posterior_samples)
```

### Download results
#### Extract results of interest

```{python}
#| label: download-outputs
from pyrovelocity.io.gcs import download_blob_from_uri
# | label: model1-extract-gene-selection
from pyrovelocity.analysis.analyze import pareto_frontier_genes
pyrovelocity_data = download_blob_from_uri(
outputs_dictconfig.o0.pyrovelocity_data.path
volcano_data = posterior_samples["gene_ranking"]
number_of_marker_genes = min(
max(int(len(volcano_data) * 0.1), 4), 6, len(volcano_data)
)
postprocessed_data = download_blob_from_uri(
outputs_dictconfig.o0.postprocessed_data.path
putative_marker_genes = pareto_frontier_genes(
volcano_data, number_of_marker_genes
)
```

#### Generate plots

## Analyze results
```{python}
# | label: model1-generate-gene-selection-summary-plot
# | output: false
from pyrovelocity.plots import plot_gene_selection_summary
vector_field_basis = model1_postprocessing_inputs.preprocess_data_args.vector_field_basis
cell_state = model1_postprocessing_inputs.preprocess_data_args.cell_state
plot_gene_selection_summary(
adata=adata,
posterior_samples=posterior_samples,
basis=vector_field_basis,
cell_state=cell_state,
plot_name="gene_selection_summary_plot.pdf",
selected_genes=putative_marker_genes,
show_marginal_histograms=False,
)
```

```{python}
# | label: model1-show-gene-selection-summary-plot
# | code-fold: true
# | output: true
display(Image(filename=f"gene_selection_summary_plot.pdf.png"))
```

### Model 2

### Load data
#### Load data

```{python}
# | label: load-postprocessed-data
# | label: model2-load-postprocessed-data
# | output: true
import scanpy as sc
from pyrovelocity.utils import print_anndata
adata = sc.read(postprocessed_data)
adata = sc.read(model2_postprocessed_data)
print_anndata(adata)
```

```{python}
# | label: load-posterior-samples
# | label: model2-load-posterior-samples
# | output: true
from pyrovelocity.utils import pretty_print_dict
from pyrovelocity.io import CompressedPickle
posterior_samples = CompressedPickle.load(pyrovelocity_data)
posterior_samples = CompressedPickle.load(model2_pyrovelocity_data)
pretty_print_dict(posterior_samples)
```

### Extract results of interest
#### Extract results of interest

```{python}
# | label: extract-gene-selection
# | label: model2-extract-gene-selection
from pyrovelocity.analysis.analyze import pareto_frontier_genes
volcano_data = posterior_samples["gene_ranking"]
Expand All @@ -133,15 +197,15 @@ putative_marker_genes = pareto_frontier_genes(
)
```

### Generate plots
#### Generate plots

```{python}
# | label: generate-gene-selection-summary-plot
# | label: model2-generate-gene-selection-summary-plot
# | output: false
from pyrovelocity.plots import plot_gene_selection_summary
vector_field_basis = inputs_dictconfig.preprocess_data_args.vector_field_basis
cell_state = inputs_dictconfig.preprocess_data_args.cell_state
vector_field_basis = model2_postprocessing_inputs.preprocess_data_args.vector_field_basis
cell_state = model2_postprocessing_inputs.preprocess_data_args.cell_state
plot_gene_selection_summary(
adata=adata,
Expand All @@ -155,7 +219,7 @@ plot_gene_selection_summary(
```

```{python}
# | label: show-gene-selection-summary-plot
# | label: model2-show-gene-selection-summary-plot
# | code-fold: true
# | output: true
display(Image(filename=f"gene_selection_summary_plot.pdf.png"))
Expand Down

0 comments on commit 7cb5f77

Please sign in to comment.