Skip to content

Commit cbb0874

Browse files
committed
Support single-seq methods in ensembling
1 parent 5145796 commit cbb0874

File tree

1 file changed

+24
-11
lines changed

1 file changed

+24
-11
lines changed

posebench/models/ensemble_generation.py

Lines changed: 24 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
# -------------------------------------------------------------------------------------------------------------------------------------
44

55
import ast
6+
import copy
67
import glob
78
import logging
89
import multiprocessing
@@ -2243,6 +2244,12 @@ def main(cfg: DictConfig):
22432244
"""Generate predictions for a protein-ligand target pair using an ensemble of methods."""
22442245
os.makedirs(cfg.temp_protein_dir, exist_ok=True)
22452246

2247+
with open_dict(cfg):
2248+
# NOTE: besides their output directories, single-sequence baselines are treated like their multi-sequence counterparts
2249+
output_dir = copy.deepcopy(cfg.output_dir)
2250+
cfg.method = cfg.method.removesuffix("_ss")
2251+
cfg.output_dir = output_dir
2252+
22462253
if list(cfg.ensemble_methods) == ["neuralplexer"] and cfg.neuralplexer_no_ilcl:
22472254
with open_dict(cfg):
22482255
cfg.output_dir = cfg.output_dir.replace(
@@ -2360,7 +2367,7 @@ def main(cfg: DictConfig):
23602367
continue
23612368

23622369
# ensure an input protein structure is available
2363-
if type(row.protein_input) == str and os.path.exists(row.protein_input):
2370+
if isinstance(row.protein_input, str) and os.path.exists(row.protein_input):
23642371
temp_protein_filepath = row.protein_input
23652372
else:
23662373
if cfg.ensemble_benchmarking:
@@ -2371,7 +2378,7 @@ def main(cfg: DictConfig):
23712378
# NOTE: a placeholder protein sequence is used when making ligand-only predictions
23722379
row_protein_input = (
23732380
row.protein_input
2374-
if type(row.protein_input) == str and len(row.protein_input) > 0
2381+
if isinstance(row.protein_input, str) and len(row.protein_input) > 0
23752382
else LIGAND_ONLY_RECEPTOR_PLACEHOLDER_SEQUENCE
23762383
)
23772384
row_name = (
@@ -2443,15 +2450,21 @@ def main(cfg: DictConfig):
24432450
ranked_predictions,
24442451
temp_protein_filepath,
24452452
row.name,
2446-
None
2447-
if isinstance(row, np.ndarray) and np.isnan(row.ligand_numbers).any()
2448-
else row.ligand_numbers,
2449-
None
2450-
if isinstance(row, np.ndarray) and np.isnan(row.ligand_names).any()
2451-
else row.ligand_names,
2452-
None
2453-
if isinstance(row, np.ndarray) and np.isnan(row.ligand_tasks).any()
2454-
else row.ligand_tasks,
2453+
(
2454+
None
2455+
if isinstance(row, np.ndarray) and np.isnan(row.ligand_numbers).any()
2456+
else row.ligand_numbers
2457+
),
2458+
(
2459+
None
2460+
if isinstance(row, np.ndarray) and np.isnan(row.ligand_names).any()
2461+
else row.ligand_names
2462+
),
2463+
(
2464+
None
2465+
if isinstance(row, np.ndarray) and np.isnan(row.ligand_tasks).any()
2466+
else row.ligand_tasks
2467+
),
24552468
cfg,
24562469
)
24572470
logger.info(f"Ensemble generation for target {row.name} has been completed.")

0 commit comments

Comments
 (0)