Skip to content

Commit

Permalink
move cell type checking to ahead of training with each drug (openprob…
Browse files Browse the repository at this point in the history
…lems-bio#39)

* randomly select cell type if default does not exist

* move cell type checking to ahead of training with each drug

* Update src/task/methods/scape/script.py

Co-authored-by: Robrecht Cannoodt <rcannood@gmail.com>

* Update src/task/methods/scape/script.py

Co-authored-by: Robrecht Cannoodt <rcannood@gmail.com>

* Update src/task/methods/scape/script.py

Co-authored-by: Robrecht Cannoodt <rcannood@gmail.com>

* Update src/task/methods/scape/script.py

Co-authored-by: Robrecht Cannoodt <rcannood@gmail.com>

---------

Co-authored-by: Mengbo Wang <wang4887@gilbreth-h002.rcac.purdue.edu>
Co-authored-by: Robrecht Cannoodt <rcannood@gmail.com>
  • Loading branch information
3 people authored May 20, 2024
1 parent a6413da commit ca8854b
Showing 1 changed file with 24 additions and 9 deletions.
33 changes: 24 additions & 9 deletions src/task/methods/scape/script.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@
epochs_enhanced = 2,
n_genes = 10,
n_genes_enhanced = 10,
n_drugs = 5,
min_n_top_drugs = 0,
# n_drugs = 5,
n_drugs = None,
# min_n_top_drugs = 0,
min_n_top_drugs = 50,
)
meta = dict(
temp_dir = "/tmp"
Expand All @@ -47,12 +49,23 @@
# load log pvals
df_de = scape.io.load_slogpvals(par['de_train']).drop(columns=["id", "split"], axis=1, errors="ignore")

# if held-out cell type is not in the data, select a random cell type
if par["cell"] not in df_de.index.get_level_values("cell_type").unique():
print(f"Input cell type ({par['cell']}) not found in the data.")
par["cell"] = np.random.choice(df_de.index.get_level_values("cell_type").unique())
print(f"Randomly selecting a cell type from the data: {par['cell']}.")

def confirm_celltype(df_de, cell, sm_name=None):
    """Return a cell type that is present in ``df_de``, preferring ``cell``.

    The candidate cell types are the unique values of the ``cell_type``
    index level of ``df_de``. When ``sm_name`` is given, only rows whose
    ``sm_name`` index level equals it are considered.

    Args:
        df_de: DataFrame whose index has a ``cell_type`` level (and an
            ``sm_name`` level when ``sm_name`` is passed).
        cell: Requested cell type.
        sm_name: Optional drug name used to restrict the candidate rows.

    Returns:
        ``cell`` unchanged if it is among the candidates; otherwise a
        candidate picked with ``np.random.choice`` (a notice is printed).

    Raises:
        ValueError: If there are no candidate cell types to choose from.
    """
    if sm_name is None:
        cells = df_de.index.get_level_values("cell_type").unique()
    else:
        # Restrict the candidates to the rows recorded for this drug.
        drug_mask = df_de.index.get_level_values("sm_name") == sm_name
        cells = df_de[drug_mask].index.get_level_values("cell_type").unique()

    if cell in cells:
        return cell

    # Build the scope text up front: a conditional expression binds looser
    # than "+", so mixing the two inline silently drops parts of the message.
    scope = "data" if sm_name is None else f"drug {sm_name} data"
    print(f"Input cell type ({cell}) not found in the {scope}.")
    if len(cells) == 0:
        # np.random.choice would raise a less helpful ValueError here.
        raise ValueError(f"No cell types available in the {scope}.")
    cell_ = np.random.choice(cells)
    print(f"Randomly selecting a cell type from the data: {cell_}.")
    return cell_

par["cell"] = confirm_celltype(df_de, par["cell"])

# load logfc
adata = anndata.read_h5ad(par["de_train_h5ad"])
Expand Down Expand Up @@ -85,8 +98,9 @@
for i, d in enumerate(drugs):
print(i, d)
scm = scape.model.create_default_model(par["n_genes"], df_de, df_lfc)
cell = confirm_celltype(df_de, par["cell"], d)
result = scm.train(
val_cells=[par["cell"]],
val_cells=[cell],
val_drugs=[d],
input_columns=top_genes,
epochs=par["epochs"],
Expand Down Expand Up @@ -133,8 +147,9 @@
for i, d in enumerate(top_drugs):
print(i, d)
scm = scape.model.create_default_model(par["n_genes_enhanced"], df_de_c, df_lfc_c)
cell = confirm_celltype(df_de, par["cell"], d)
result = scm.train(
val_cells=[par["cell"]],
val_cells=[cell],
val_drugs=[d],
input_columns=top_genes,
epochs=par["epochs_enhanced"],
Expand Down

0 comments on commit ca8854b

Please sign in to comment.