Skip to content

Commit

Permalink
reinterpret argument settings (openproblems-bio#62)
Browse files Browse the repository at this point in the history
* reinterpret argument settings

* fix stability script

* validate input arguments
  • Loading branch information
rcannood authored Jun 2, 2024
1 parent 753f37c commit a161cfd
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 17 deletions.
1 change: 1 addition & 0 deletions scripts/run_stability_test.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ id: neurips-2023-data
sc_counts: resources/neurips-2023-raw/sc_counts_reannotated_with_counts.h5ad
method_ids: ['ground_truth', 'sample', 'mean_across_celltypes', 'mean_across_compounds']
layer: t # test a different layer
bootstrap_num_replicates: 2
publish_dir: "output/test_stability_analysis"
output_state: "state.yaml"
HERE
Expand Down
45 changes: 28 additions & 17 deletions src/task/methods/transformer_ensemble/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,34 +42,38 @@

# train and predict models
#
# Each entry configures one transformer-model variant of the ensemble:
#   - name:              identifier for the trained model / its weight
#   - mean_std:          augmentation mode ("mean_std" uses mean+std features,
#                        "mean" uses mean-only features)
#   - uncommon:          whether uncommon elements are excluded during training
#   - sampling_strategy: "k-means" (cluster-based sampling) or "random"
#   - weight:            ensemble coefficient applied when combining the
#                        per-model predictions (normalised by the total later)
argsets = [
    # Note by author - weight_df1: 0.5 (utilizing std, mean, and clustering sampling, yielding 0.551)
    {
        "name": "weight_df1",
        "mean_std": "mean_std",
        "uncommon": False,
        "sampling_strategy": "random",
        "weight": 0.5,
    },
    # Note by author - weight_df2: 0.25 (excluding uncommon elements, resulting in 0.559)
    {
        "name": "weight_df2",
        "mean_std": "mean_std",
        "uncommon": True,
        "sampling_strategy": "random",
        "weight": 0.25,
    },
    # Note by author - weight_df3: 0.25 (leveraging clustering sampling, achieving 0.575)
    {
        "name": "weight_df3",
        "mean_std": "mean_std",
        "uncommon": False,  # should this be set to False or True?
        "sampling_strategy": "k-means",
        "weight": 0.25,
    },
    # Note by author - weight_df4: 0.3 (incorporating mean, random sampling, and excluding std, attaining 0.554)
    {
        "name": "weight_df4",
        "mean_std": "mean",
        "uncommon": False,  # should this be set to False or True?
        "sampling_strategy": "random",
        "weight": 0.3,
    },
]


Expand All @@ -90,6 +94,8 @@
one_hot_encode_features, targets, one_hot_test = (
prepare_augmented_data_mean_only(de_train=de_train, id_map=id_map)
)
else:
raise ValueError("Invalid mean_std argument")

print(f"> Train model", flush=True)
if argset["sampling_strategy"] == "k-means":
Expand All @@ -104,7 +110,7 @@
device=device,
mean_std=argset["mean_std"],
)
else:
elif argset["sampling_strategy"] == "random":
label_reducer, scaler, transformer_model = train_non_k_means_strategy(
n_components=n_components,
d_model=d_model,
Expand All @@ -116,6 +122,8 @@
device=device,
mean_std=argset["mean_std"],
)
else:
raise ValueError("Invalid sampling_strategy argument")

print(f"> Predict model", flush=True)
unseen_data = torch.tensor(one_hot_test, dtype=torch.float32).to(device)
Expand Down Expand Up @@ -145,9 +153,12 @@
predictions.append(pred)

print(f"Combine predictions", flush=True)
# Combine the per-model predictions into one ensemble prediction:
# a weighted sum using each argset's "weight", normalised by the total
# weight so the coefficients need not sum to 1.
sum_weights = sum([argset["weight"] for argset in argsets])
weighted_pred = sum([
    pred * argset["weight"] / sum_weights
    for argset, pred in zip(argsets, predictions)
])


print('Write output to file', flush=True)
Expand Down

0 comments on commit a161cfd

Please sign in to comment.