Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion lambench/models/ase_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,16 @@ def evaluate(
import torch

torch.set_default_dtype(torch.float32)
return self.run_ase_dptest(self, task.test_data, task.dispersion_correction)
# Use corresponding DFT label for models supporting OMol25 on Molecules tasks
if isinstance(task.test_data, dict):
if self.supports_omol and self.model_domain == "molecules":
data_path = task.test_data["wB97"]
else:
data_path = task.test_data["PBE"]
else:
data_path = task.test_data

return self.run_ase_dptest(self, data_path, task.dispersion_correction)
elif isinstance(task, CalculatorTask):
if task.task_name == "nve_md":
from lambench.tasks.calculator.nve_md.nve_md import (
Expand Down
2 changes: 1 addition & 1 deletion lambench/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ class BaseTask(BaseModel):
"""

task_name: str
test_data: Path
test_data: Path | dict[str, Path]
task_config: ClassVar[Path]
model_config = ConfigDict(extra="allow")
workdir: Path = Path(tempfile.gettempdir()) / "lambench"
Expand Down
12 changes: 8 additions & 4 deletions lambench/tasks/direct/direct_tasks.yml
Original file line number Diff line number Diff line change
@@ -1,13 +1,9 @@
ANI:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
HEA25_S:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25S"
HEA25_bulk:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HEA25"
MoS2:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MoS2"
MD22:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
REANN_CO2_Ni100:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/REANN_CO2_Ni100"
NequIP_NC_2022:
Expand All @@ -24,6 +20,10 @@ HPt_NC_2022:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/HPt_NC2022"
Ca_batteries_CM2021:
test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/Ca_batteries"
AQM:
test_data:
PBE: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000"
wB97: "/bohr/temp-lambench-ood-5zz5/v3/AQM-sol-PBE__downsampled_1000_OMol-wb97mv-def2tzvpd-ORCA600"
## DEPRECATED
# Collision:
# test_data: "/bohr/lambench-ood-zwtr/v2/LAMBench-TestData-v2/Collision"
Expand All @@ -39,3 +39,7 @@ Ca_batteries_CM2021:
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/HEMC_HEMB"
# Torsionnet500:
# test_data: "/bohr/lambench-ood-zwtr/v1/OOD_test_data_v2/raw_torsionnet500"
# ANI:
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/ANI"
# MD22:
# test_data: "/bohr/lambench-ood-zwtr/v4/LAMBench-TestData-v3/MD22"
12 changes: 11 additions & 1 deletion lambench/workflow/dflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,16 @@ def submit_tasks_dflow(
name = f"{task.task_name}--{model.model_name}"
# dflow task name should be alphanumeric
name = "".join([c if c.isalnum() else "-" for c in name])
if task.test_data is not None:
# handle dict type test_data, NOTE: if the datasets are in the same parent folder, only need to upload the artifact once.
task_data = (
list(task.test_data.values())[0]
if isinstance(task.test_data, dict)
else task.test_data
)
else:
task_data = []
logging.warning(f"Submitting task {name} with test data paths: {task_data}")

dflow_task = Task(
name=name,
Expand All @@ -69,7 +79,7 @@ def submit_tasks_dflow(
"task": task,
"model": model,
},
artifacts={"dataset": get_dataset([model.model_path, task.test_data])},
artifacts={"dataset": get_dataset([model.model_path, task_data])},
executor=DispatcherExecutor(
machine_dict={
"batch_type": "Bohrium",
Expand Down
2 changes: 1 addition & 1 deletion lambench/workflow/entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def gather_task_type(
continue # Regular ASEModel does not support PropertyFinetuneTask
for task_name, task_params in task_configs.items():
if (task_names and task_name not in task_names) or task_class.__name__ in (
model_param["skip_tasks"]
model_param.get("skip_tasks", [])
):
continue
task = task_class(task_name=task_name, **task_params)
Expand Down