Skip to content

Commit

Permalink
Merge branch 'refs/heads/dev_experiments' into dev
Browse files (browse the repository at this point in the history)
# Conflicts:
#	experiments/DTI/visualize.py
  • Loading branch information
Old-Shatterhand committed Sep 3, 2024
2 parents: 76c633d + 0258181; commit 6d66bde
Show file tree
Hide file tree
Showing 11 changed files with 493 additions and 236 deletions.
3 changes: 3 additions & 0 deletions experiments/DTI/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -175,8 +175,11 @@ def main(full_path: Path):
full_path: Path to the folder holding the runs for all tools
"""
for tool in TECHNIQUES:
if tool == "datasail":
continue
train_tool(full_path, tool)


if __name__ == '__main__':
main(Path(sys.argv[1]))

243 changes: 130 additions & 113 deletions experiments/DTI/visualize.py

Large diffs are not rendered by default.

12 changes: 7 additions & 5 deletions experiments/MPP/split.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -40,8 +40,8 @@ def split_w_datasail(base_path: Path, name: str, techniques: List[str], solver:
# print("DataSAIL skipping", name)
# return

with open(base_path / "time.txt", "w") as time:
print("Start", file=time)
# with open(base_path / "time.txt", "w") as time:
# print("Start", file=time)

df = prep_moleculenet(name)
start = T.time()
Expand All @@ -56,8 +56,8 @@ def split_w_datasail(base_path: Path, name: str, techniques: List[str], solver:
max_sec=1000,
epsilon=0.1,
)
with open(base_path / "time.txt", "a") as time:
print("I1+C1", T.time() - start, file=time)
# with open(base_path / "time.txt", "a") as time:
# print("I1+C1", T.time() - start, file=time)

save_datasail_splits(base_path, df, "ID", [(t, t) for t in techniques], e_splits=e_splits)

Expand Down Expand Up @@ -163,7 +163,7 @@ def split(full_path, name, solver="GUROBI"):
"""
Split the MoleculeNet datasets using different techniques.
"""
split_w_datasail(full_path / "datasail" / name, name, techniques=["I1e", "C1e"], solver=solver)
split_w_datasail(full_path / "datasail" / name, name, techniques=["I1e"], solver=solver)
# split_w_deepchem(full_path / "deepchem" / name, name, techniques=SPLITTERS.keys())
# split_w_lohi(full_path / "lohi" / name, name)

Expand All @@ -177,6 +177,8 @@ def specific():


if __name__ == '__main__':
split_w_datasail(Path("/") / "scratch" / "SCRATCH_SAS" / "roman" / "DataSAIL" / "v10" / "MPP" / "datasail" / "hiv", "hiv", ["I1e"])
exit(0)
if len(sys.argv) == 1:
specific()
elif len(sys.argv) == 2:
Expand Down
18 changes: 13 additions & 5 deletions experiments/MPP/train.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -134,7 +134,11 @@ def train_run(run_path: Path, data_path: Path, name: str, model: str) -> float:
m.fit(x_train, y_train)

test_predictions = m.predict(x_test)
test_perf = metric[DATASETS[name][2]](y_test, test_predictions)
scoring = metric[DATASETS[name][2]]
if name == "muv":
test_perf = np.mean([scoring(y_test[:, i], test_predictions[:, i]) for i in range(y_test.shape[1])])
else:
test_perf = scoring(y_test, test_predictions)

return test_perf

Expand Down Expand Up @@ -174,9 +178,8 @@ def train_model(base_path: Path, data_path: Path, model: str, tool: str, name: s
pd.DataFrame: Dataframe of the performance of the models
"""
perf = {}
# for tech in set(TECHNIQUES[tool]).intersection(set(DRUG_TECHNIQUES)):
tech = "C1e"
perf.update(train_tech(base_path / tech, data_path, model, tech, name))
for tech in ["I1e", "C1e"]: # set(TECHNIQUES[tool]).intersection(set(DRUG_TECHNIQUES)):
perf.update(train_tech(base_path / tech, data_path, model, tech, name))
# message(tool, name, model[:-2], tech)
df = pd.DataFrame(list(perf.items()), columns=["name", "perf"])
df["model"] = model
Expand Down Expand Up @@ -224,14 +227,19 @@ def train(full_path: Path, name: Optional[str] = None) -> None:
"""
if name is None:
for name in DATASETS:
# train_dataset(full_path, name)
if name in ["qm7", "qm8", "qm9", "lipophilicity", "esol", "freesolv", "pcba", "tox21", "clintox", "muv"]:
continue
train_tool(full_path, "datasail", name)
# train_dataset(full_path, name)
else:
train_dataset(full_path, name)


if __name__ == '__main__':
train_tool(Path(sys.argv[1]), "datasail", "hiv")
exit(0)
if len(sys.argv) == 2:
train(Path(sys.argv[1]))
# train_tool(Path(sys.argv[1]), "datasail", "muv")
elif len(sys.argv) == 3:
train_dataset(Path(sys.argv[1]), sys.argv[2])
Loading

0 comments on commit 6d66bde

Please sign in to comment.