Skip to content

Commit

Permalink
Implement alternate stability analysis (#58)
Browse files Browse the repository at this point in the history
* add sc_counts bootstrap method

* remove stability wf

* add novel stability wf

* fix name

* also passthrough layer argument

* turn r component into python component

* fix for list bug

* simplify wf

* fix wf
  • Loading branch information
rcannood authored Jun 1, 2024
1 parent e8451aa commit 408e292
Show file tree
Hide file tree
Showing 7 changed files with 259 additions and 144 deletions.
11 changes: 4 additions & 7 deletions scripts/run_stability_tw.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,17 @@ RUN_ID="stability_$(date +%Y-%m-%d_%H-%M-%S)"
publish_dir="s3://openproblems-data/resources/dge_perturbation_prediction/results/${RUN_ID}"

cat > /tmp/params.yaml << HERE
id: dge_perturbation_task
input_states: s3://openproblems-bio/public/neurips-2023-competition/workflow-resources/neurips-2023-data/state.yaml
output_state: "state.yaml"
id: neurips-2023-data
sc_counts: s3://openproblems-bio/public/neurips-2023-competition/sc_counts_reannotated_with_counts.h5ad
layer: clipped_sign_log10_pval
publish_dir: "$publish_dir"
rename_keys: "de_train_h5ad:de_train_h5ad,de_test_h5ad:de_test_h5ad,id_map:id_map"
settings: '{"stability": true, "stability_obs_fraction": 0.99, "stability_var_fraction": 0.99}'
HERE

tw launch https://github.com/openproblems-bio/task-dge-perturbation-prediction.git \
--revision main_build \
--pull-latest \
--main-script target/nextflow/workflows/run_benchmark/main.nf \
--main-script target/nextflow/workflows/run_stability_analysis/main.nf \
--workspace 53907369739130 \
--compute-env 6TeIFgV5OY4pJCk8I0bfOh \
--params-file /tmp/params.yaml \
--entry-name auto \
--config src/common/nextflow_helpers/labels_tw.config
46 changes: 46 additions & 0 deletions src/task/process_dataset/bootstrap_sc_counts/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
functionality:
name: bootstrap_sc_counts
namespace: process_dataset
info:
type: process_dataset
type_info:
label: Bootstrap
summary: A component to generate bootstraps of a dataset.
description: |
This component generates bootstraps of a dataset. It takes as input a parquet file and an h5ad file and generates bootstraps of the dataset. The bootstraps are saved as parquet and h5ad files.
argument_groups:
- name: Inputs
arguments:
- name: --input
type: file
required: true
direction: input
example: resources/neurips-2023-raw/sc_counts_reannotated_with_counts.h5ad
- name: Outputs
arguments:
- name: --output
type: file
required: true
direction: output
example: sc_counts_bootstrap.h5ad
- name: Arguments
arguments:
- name: --obs_fraction
type: double
required: true
default: 0.95
description: Fraction of the obs of the sc_counts to include in each bootstrap.
- name: --var_fraction
type: double
required: true
default: 0.95
description: Fraction of the var of the sc_counts to include in each bootstrap.
resources:
- type: python_script
path: script.py
platforms:
- type: docker
image: ghcr.io/openproblems-bio/base_python:1.0.4
- type: nextflow
directives:
label: [ midtime, highmem, midcpu ]
32 changes: 32 additions & 0 deletions src/task/process_dataset/bootstrap_sc_counts/script.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import anndata as ad
import numpy as np

# VIASH START
par = {
"input": "resources/neurips-2023-raw/sc_counts_reannotated_with_counts.h5ad",
"output": "output/sc_counts_bootstrapped_*.h5ad",
"obs_fraction": 0.95,
"var_fraction": 1
}
# VIASH END

# Load data
input_data = ad.read_h5ad(par["input"])

# Sample indices
obs_ix = np.random.choice(
input_data.obs_names,
int(input_data.n_obs * par["obs_fraction"]),
replace=False
)
var_ix = np.random.choice(
input_data.var_names,
int(input_data.n_vars * par["var_fraction"]),
replace=False
)

# Subset AnnData object
output_data = input_data[obs_ix, var_ix].copy()

# Write output
output_data.write_h5ad(par["output"], compression="gzip")
31 changes: 0 additions & 31 deletions src/task/workflows/run_benchmark/config.vsh.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,13 +29,6 @@ functionality:
direction: output
description: A yaml file containing the scores of each of the methods
default: score_uns.yaml
- name: "--stability_scores"
type: file
required: false
must_exist: false
direction: output
description: A yaml file containing the scores of each of the methods on bootstrapped datasets
default: stability_uns.yaml
- name: "--method_configs"
type: file
required: true
Expand Down Expand Up @@ -66,28 +59,6 @@ functionality:
type: string
multiple: true
description: A list of metric ids to run. If not specified, all metric will be run.
- name: Stability Analysis.
description: Run a stability analysis on the methods.
arguments:
- name: --stability
type: boolean
description: Whether to run a stability analysis on the methods.
default: false
- name: --stability_num_replicates
type: integer
required: true
default: 10
description: Number of bootstraps to generate.
- name: --stability_obs_fraction
type: double
required: true
default: 0.95
description: Fraction of the training dataset obs to include in each bootstrap.
- name: --stability_var_fraction
type: double
required: true
default: 0.95
description: Fraction of the training & test dataset var to include in each bootstrap.
resources:
- type: nextflow_script
path: main.nf
Expand Down Expand Up @@ -115,8 +86,6 @@ functionality:
- name: metrics/mean_rowwise_error_r
- name: metrics/mean_cosine_sim_r
- name: metrics/mean_correlation_r
- name: process_dataset/bootstrap
- name: process_dataset/generate_id_map
repositories:
- name: openproblemsv2
type: github
Expand Down
126 changes: 20 additions & 106 deletions src/task/workflows/run_benchmark/main.nf
Original file line number Diff line number Diff line change
Expand Up @@ -24,34 +24,6 @@ metrics = [
mean_correlation_r
]

// which arguments to pass to the methods
methodFromState = { id, state, comp ->
def new_args = [
de_train_h5ad: state.de_train_h5ad,
id_map: state.id_map,
layer: state.layer,
output: 'predictions/$id.$key.output.h5ad',
output_model: null
]
if (comp.config.functionality.info.type == "control_method") {
new_args.de_test_h5ad = state.de_test_h5ad
}
new_args
}

// where to store the method output
methodToState = ["prediction": "output"]

// which arguments to pass to the metrics
metricFromState = [
de_test_h5ad: "de_test_h5ad",
layer: "layer",
prediction: "prediction"
]

// where to store the metric output
metricToState = ["metric_output": "output"]

// helper workflow for starting a workflow based on lists of yaml files
workflow auto {
findStates(params, meta.config)
Expand All @@ -75,10 +47,26 @@ workflow run_wf {
| run_benchmark_fun(
methods: methods,
metrics: metrics,
methodFromState: methodFromState,
methodToState: methodToState,
metricFromState: metricFromState,
metricToState: metricToState,
methodFromState: { id, state, comp ->
def new_args = [
de_train_h5ad: state.de_train_h5ad,
id_map: state.id_map,
layer: state.layer,
output: 'predictions/$id.$key.output.h5ad',
output_model: null
]
if (comp.config.functionality.info.type == "control_method") {
new_args.de_test_h5ad = state.de_test_h5ad
}
new_args
},
methodToState: ["prediction": "output"],
metricFromState: [
de_test_h5ad: "de_test_h5ad",
layer: "layer",
prediction: "prediction"
],
metricToState: ["metric_output": "output"],
methodAuto: [publish: "state"]
)
| joinStates { ids, states ->
Expand All @@ -90,15 +78,6 @@ workflow run_wf {
["output", [scores: score_uns_file]]
}

/**************************
* RUN STABILITY ANALYSIS *
**************************/
stability_ch = input_ch
| filter{ id, state ->
state.stability
}
| stability_wf

/******************************
* GENERATE OUTPUT YAML FILES *
******************************/
Expand All @@ -114,7 +93,6 @@ workflow run_wf {
// merge all of the output data
output_ch = score_ch
| mix(metadata_ch)
| mix(stability_ch)
| joinStates{ ids, states ->
def mergedStates = states.inject([:]) { acc, m -> acc + m }
[ids[0], mergedStates]
Expand All @@ -125,70 +103,6 @@ workflow run_wf {
}


workflow stability_wf {
take: input_ch

main:
output_ch = input_ch

| bootstrap.run(
fromState: [
train_h5ad: "de_train_h5ad",
test_h5ad: "de_test_h5ad",
num_replicates: "stability_num_replicates",
obs_fraction: "stability_obs_fraction",
var_fraction: "stability_var_fraction"
],

toState: [
de_train_h5ad: "output_train_h5ad",
de_test_h5ad: "output_test_h5ad"
]
)

// flatten bootstraps
| flatMap { id, state ->
return [state.de_train_h5ad, state.de_test_h5ad]
.transpose()
.withIndex()
.collect{ el, idx ->
[
id + "-bootstrap" + idx,
state + [
replicate: idx,
de_train_h5ad: el[0],
de_test_h5ad: el[1]
]
]
}
}

| generate_id_map.run(
fromState: [de_test_h5ad: "de_test_h5ad"],
toState: [id_map: "id_map"]
)

| run_benchmark_fun(
keyPrefix: "stability_",
methods: methods,
metrics: metrics,
methodFromState: methodFromState,
methodToState: methodToState,
metricFromState: metricFromState,
metricToState: metricToState
)

| joinStates { ids, states ->
def stability_uns = states.collect{it.score_uns}
def stability_uns_yaml_blob = toYamlBlob(stability_uns)
def stability_uns_file = tempFile("stability_uns.yaml")
stability_uns_file.write(stability_uns_yaml_blob)

["output", [stability_scores: stability_uns_file]]
}

emit: output_ch
}



Expand Down
80 changes: 80 additions & 0 deletions src/task/workflows/run_stability_analysis/config.vsh.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
functionality:
name: "run_stability_analysis"
namespace: "workflows"
argument_groups:
- name: Inputs
arguments:
- name: --sc_counts
__merge__: ../../api/file_sc_counts.yaml
required: true
direction: input
- name: "--id"
type: string
description: Unique identifier of the dataset.
required: true
- name: --layer
type: string
direction: input
default: sign_log10_pval
description: Which layer to use for prediction and evaluation.
- name: Bootstrapping arguments
description: Define the sampling strategy for the stability analysis.
arguments:
- name: --bootstrap_num_replicates
type: integer
required: true
default: 10
description: Number of bootstraps to generate.
- name: --bootstrap_obs_fraction
type: double
required: true
default: 0.95
description: Fraction of the obs of the sc_counts to include in each bootstrap.
- name: --bootstrap_var_fraction
type: double
required: true
default: 1
description: Fraction of the var of the sc_counts to include in each bootstrap.
- name: Outputs
arguments:
- name: "--scores"
type: file
required: true
direction: output
description: A yaml file containing the scores of each of the methods
default: score_uns.yaml
- name: Arguments
arguments:
- name: "--method_ids"
type: string
multiple: true
description: A list of method ids to run. If not specified, all methods will be run.
- name: "--metric_ids"
type: string
multiple: true
description: A list of metric ids to run. If not specified, all metric will be run.
resources:
- type: nextflow_script
path: main.nf
entrypoint: run_wf
- type: file
path: "../../api/task_info.yaml"
dependencies:
- name: process_dataset/bootstrap_sc_counts
- name: workflows/process_dataset
- name: workflows/run_benchmark
repositories:
- name: openproblemsv2
type: github
repo: openproblems-bio/openproblems-v2
tag: main_build
platforms:
- type: nextflow
config:
script: |
process.errorStrategy = 'ignore'
trace {
enabled = true
overwrite = true
file = "${params.publish_dir}/trace.txt"
}
Loading

0 comments on commit 408e292

Please sign in to comment.