Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions lambench/metrics/downstream_tasks_metrics.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,11 @@ vacancy:
domain: Inorganic Materials
metrics: [MAE]
dummy: {"MAE": 4.381}
binding_energy:
domain: Molecules
metrics: [MAE]
dummy: {"MAE": 8.098}
rxn_barrier:
domain: Molecules
metrics: [MAE]
dummy: {"MAE": 20.975}
2 changes: 2 additions & 0 deletions lambench/metrics/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,8 @@ def process_domain_specific_for_one_model(model: BaseLargeAtomModel):
"wiggle150",
"elastic",
"vacancy",
"binding_energy",
"rxn_barrier",
]:
applicability_results[record.task_name] = record.metrics
return applicability_results
Expand Down
25 changes: 18 additions & 7 deletions lambench/models/ase_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def __init__(self, *args, **kwargs):
self._calc = None

@property
def calc(self) -> Calculator:
def calc(self, head=None) -> Calculator:
"""ASE Calculator with the model loaded."""
calculator_dispatch = {
"MACE": self._init_mace_calculator,
Expand All @@ -101,7 +101,6 @@ def calc(self) -> Calculator:
f"Model {self.model_name} is not supported by ASEModel, using EMT as default calculator."
)
self._calc = EMT()

else:
self._calc = calculator_dispatch[self.model_family]()
return self._calc
Expand All @@ -114,10 +113,12 @@ def calc(self, value: Calculator):
def _init_mace_calculator(self) -> Calculator:
from mace.calculators import mace_mp

if self.model_domain == "molecules":
head = "omol"
else:
head = "oc20_usemppbe"
return mace_mp(
model=self.model_name.split("_")[-1],
device="cuda",
default_dtype="float64",
model=self.model_path, device="cuda", default_dtype="float64", head=head
)

def _init_orb_calculator(self) -> Calculator:
Expand All @@ -134,7 +135,7 @@ def _init_sevennet_calculator(self) -> Calculator:

model_config = {"model": self.model_name, "device": "cuda"}
if self.model_name == "7net-mf-ompa":
model_config["modal"] = "mpa"
model_config["modal"] = "omat24"
return SevenNetCalculator(**model_config)

def _init_equiformer_calculator(self) -> Calculator:
Expand Down Expand Up @@ -171,7 +172,7 @@ def _init_dp_calculator(self) -> Calculator:
else:
return DP(
model=self.model_path,
head="MP_traj_v024_alldata_mixu",
head="Omat24",
)

def _init_grace_calculator(self) -> Calculator:
Expand Down Expand Up @@ -290,6 +291,16 @@ def evaluate(
elif task.task_name == "vacancy":
from lambench.tasks.calculator.vacancy.vacancy import run_inference

assert task.test_data is not None
return {"metrics": run_inference(self, task.test_data)}
elif task.task_name == "rxn_barrier":
from lambench.tasks.calculator.rxn_barrier.barrier import run_inference

assert task.test_data is not None
return {"metrics": run_inference(self, task.test_data)}
elif task.task_name == "binding_energy":
from lambench.tasks.calculator.binding.binding import run_inference

assert task.test_data is not None
return {"metrics": run_inference(self, task.test_data)}
else:
Expand Down
60 changes: 60 additions & 0 deletions lambench/tasks/calculator/binding/binding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
"""
The test data is retrieved from:
J. Chem. Inf. Model. 2020, 60, 3, 1453–1460

https://pubs.acs.org/doi/10.1021/acs.jcim.9b01171

Only the PLF547 dataset is used.

"""

from ase.io import read
import numpy as np
from tqdm import tqdm
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
from pathlib import Path
from lambench.models.ase_models import ASEModel
import logging


def run_inference(
    model: ASEModel,
    test_data: Path,
) -> dict[str, float]:
    """Evaluate binding-energy predictions of ``model`` on the dataset in ``test_data``.

    For every sample the binding energy is computed as
    ``E(combined) - E(active_site) - E(drug)`` with the model's ASE calculator,
    converted with ``EV_TO_KCAL`` and compared against the stored label.

    Args:
        model: Model wrapper whose ``calc`` attribute is attached to each
            structure as its ASE calculator.
        test_data: Directory containing ``active_site.traj``, ``drug.traj``,
            ``combined.traj`` and ``labels.npy``. The four files are iterated
            in lockstep, so they are presumably index-aligned (TODO confirm
            against the dataset preparation script).

    Returns:
        A dict with ``"MAE"`` and ``"RMSE"`` over the samples that could be
        evaluated, and ``"success_rate"`` (evaluated samples divided by the
        number of labels). When no sample could be evaluated, ``"MAE"`` and
        ``"RMSE"`` are NaN and ``"success_rate"`` is 0.0 instead of raising.
    """
    active_site_atoms = read(test_data / "active_site.traj", ":")
    drug_atoms = read(test_data / "drug.traj", ":")
    combined_atoms = read(test_data / "combined.traj", ":")
    labels = np.load(test_data / "labels.npy")

    EV_TO_KCAL = 23.06092234465

    num_labels = len(labels)
    # zip() below stops at the shortest input; make a silent truncation visible.
    num_samples = min(
        len(active_site_atoms), len(drug_atoms), len(combined_atoms), num_labels
    )
    if not (
        len(active_site_atoms)
        == len(drug_atoms)
        == len(combined_atoms)
        == num_labels
    ):
        logging.warning(
            "Binding dataset files differ in length "
            "(active_site=%d, drug=%d, combined=%d, labels=%d); "
            "only the first %d samples are evaluated.",
            len(active_site_atoms),
            len(drug_atoms),
            len(combined_atoms),
            num_labels,
            num_samples,
        )

    calc = model.calc
    preds = []
    success_labels = []

    for site, drug, combo, label in tqdm(
        zip(active_site_atoms, drug_atoms, combined_atoms, labels),
        total=num_samples,
    ):
        try:
            for atoms in (site, drug, combo):
                atoms.calc = calc
                # Expose charge and spin under the "fparam" key; presumably
                # read by calculators that accept frame parameters (verify
                # against the calculator implementations).
                atoms.info.update(
                    {"fparam": np.array([atoms.info["charge"], atoms.info["spin"]])}
                )

            site_energy = site.get_potential_energy()
            drug_energy = drug.get_potential_energy()
            combo_energy = combo.get_potential_energy()

            binding_energy = combo_energy - site_energy - drug_energy
            preds.append(binding_energy * EV_TO_KCAL)
            success_labels.append(label)
        except Exception as e:
            # Best effort: a failing sample lowers success_rate but must not
            # abort the whole benchmark.
            logging.warning(
                "Failed to calculate binding energy for one sample: %s", e
            )
            continue

    if not preds:
        # The sklearn metrics raise on empty input; report the total failure
        # through the metrics instead of crashing after the whole run.
        logging.warning("No binding energy could be calculated for any sample.")
        return {
            "MAE": float("nan"),
            "RMSE": float("nan"),
            "success_rate": 0.0,
        }

    return {
        "MAE": mean_absolute_error(success_labels, preds),  # kcal/mol
        "RMSE": root_mean_squared_error(success_labels, preds),  # kcal/mol
        "success_rate": len(success_labels) / num_labels,
    }
6 changes: 6 additions & 0 deletions lambench/tasks/calculator/calculator_tasks.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,9 @@ elastic:
vacancy:
test_data: /bohr/lambench-vacancy-a2xo/v1
calculator_params: null
binding_energy:
test_data: /bohr/lambench-binding-dlc6/v1/PLF547
calculator_params: null
rxn_barrier:
test_data: /bohr/lambench-BH876-uplk/v1/BH876
calculator_params: null
78 changes: 78 additions & 0 deletions lambench/tasks/calculator/rxn_barrier/barrier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
"""
The test data is retrieved from:

@misc{liang2025gold,
title={Gold-Standard Chemical Database 137 (GSCDB137): A diverse set of accurate energy differences for assessing and developing density functionals},
author={Jiashu Liang and Martin Head-Gordon},
year={2025},
eprint={2508.13468},
archivePrefix={arXiv},
primaryClass={physics.chem-ph},
url={https://arxiv.org/abs/2508.13468},
}

https://github.com/JiashuLiang/GSCDB

Only the BH876 dataset is used.

"""

from ase.io import read
import pandas as pd
import numpy as np
from tqdm import tqdm
from sklearn.metrics import root_mean_squared_error, mean_absolute_error
from pathlib import Path
from lambench.models.ase_models import ASEModel
import logging


def run_inference(
    model: ASEModel,
    test_data: Path,
) -> dict[str, float]:
    """Evaluate reaction-energy predictions of ``model`` on the dataset in ``test_data``.

    Each row of ``stoichiometry.csv`` describes one reaction as a
    comma-separated list ``coef,ID,coef,ID,...``. The predicted value is the
    coefficient-weighted sum of the model energies of the referenced
    structures, converted with ``EV_TO_KCAL``, and is compared against the
    row's ``Reference`` value converted with ``HARTREE_TO_KCAL``.

    Args:
        model: Model wrapper whose ``calc`` attribute is attached to each
            structure as its ASE calculator.
        test_data: Directory containing ``lookup_table.csv`` (with an ``ID``
            column), ``stoichiometry.csv`` (with ``Stoichiometry`` and
            ``Reference`` columns) and ``BH876.traj``. Row *k* of the lookup
            table is used as the index of the structure in the trajectory.

    Returns:
        A dict with ``"MAE"`` and ``"RMSE"`` over the reactions that could be
        evaluated, and ``"success_rate"`` (evaluated reactions divided by the
        number of rows in the stoichiometry table). When no reaction could be
        evaluated, ``"MAE"`` and ``"RMSE"`` are NaN and ``"success_rate"`` is
        0.0 instead of raising.
    """
    lookup_table = pd.read_csv(test_data / "lookup_table.csv")
    stoichiometry = pd.read_csv(test_data / "stoichiometry.csv")
    traj = read(test_data / "BH876.traj", ":")

    EV_TO_KCAL = 23.06092234465
    HARTREE_TO_KCAL = 627.50947406

    # Map each species ID to its row position once instead of scanning the
    # table for every term; setdefault keeps the first occurrence of an ID.
    structure_index_by_id = {}
    for position, species_id in enumerate(lookup_table["ID"]):
        structure_index_by_id.setdefault(species_id, position)

    # Energies are cached per structure so a species that appears in several
    # reactions is only evaluated once.
    energy_by_index = {}

    preds = []
    labels = []
    num_reactions = len(stoichiometry)

    calc = model.calc

    for _, row in tqdm(stoichiometry.iterrows(), total=num_reactions):
        try:
            terms = row["Stoichiometry"].split(",")
            pred = 0.0
            # Terms alternate coefficient, species ID; an unpaired trailing
            # entry is ignored.
            for coefficient, species_id in zip(terms[0::2], terms[1::2]):
                stoi = float(coefficient)
                structure_index = structure_index_by_id[species_id]
                if structure_index not in energy_by_index:
                    atoms = traj[structure_index]
                    # Expose charge and spin under the "fparam" key;
                    # presumably read by calculators that accept frame
                    # parameters (verify against the calculators).
                    atoms.info.update(
                        {
                            "fparam": np.array(
                                [atoms.info["charge"], atoms.info["spin"]]
                            )
                        }
                    )
                    atoms.calc = calc
                    energy_by_index[structure_index] = atoms.get_potential_energy()
                pred += stoi * energy_by_index[structure_index]
            reference = row["Reference"] * HARTREE_TO_KCAL
        except Exception as e:
            # Best effort: a failing reaction lowers success_rate but must
            # not abort the whole benchmark.
            logging.warning(
                "Failed to calculate reaction energy for reaction: %s. Error: %s",
                row["Stoichiometry"],
                e,
            )
        else:
            preds.append(pred * EV_TO_KCAL)
            labels.append(reference)

    if not preds:
        # The sklearn metrics raise on empty input; report the total failure
        # through the metrics instead of crashing after the whole run.
        logging.warning("No reaction energy could be calculated for any reaction.")
        return {
            "MAE": float("nan"),
            "RMSE": float("nan"),
            "success_rate": 0.0,
        }

    return {
        "MAE": mean_absolute_error(labels, preds),  # kcal/mol
        "RMSE": root_mean_squared_error(labels, preds),  # kcal/mol
        "success_rate": len(preds) / num_reactions,
    }