Skip to content

Commit 53cb989

Browse files
committed
fea: test multi-atom models
1 parent 9151970 commit 53cb989

File tree

2 files changed

+46
-29
lines changed

2 files changed

+46
-29
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ ignore = [
100100
"FIX002", # Line contains TODO, consider resolving the issue
101101
"N803", # Variable name should be lowercase
102102
"N806", # Uppercase letters in variable names
103-
"PLC0415", # import` should be at the top-level of a file
103+
"PLC0415", # import should be at the top-level of a file
104104
"PLR0912", # too many branches
105105
"PLR0913", # too many function arguments
106106
"PLR2004", # Magic value used in comparison, consider replacing {value} with a constant variable
Lines changed: 45 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
import traceback
2-
import urllib.request
3-
from enum import StrEnum
42
from pathlib import Path
53

64
import pytest
75

8-
from tests.conftest import DEVICE
9-
from tests.models.conftest import make_model_calculator_consistency_test
6+
from tests.conftest import DEVICE, DTYPE
7+
from tests.models.conftest import (
8+
consistency_test_simstate_fixtures,
9+
make_model_calculator_consistency_test,
10+
make_validate_model_outputs_test,
11+
)
1012

1113

1214
try:
1315
from nequip.ase import NequIPCalculator
16+
from nequip.scripts.compile import main
1417

1518
from torch_sim.models.nequip_framework import (
1619
NequIPFrameworkModel,
@@ -22,29 +25,34 @@
2225
)
2326

2427

25-
class NequIPUrls(StrEnum):
26-
"""Checkpoint download URLs for NequIP models."""
27-
28-
Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth"
29-
30-
3128
@pytest.fixture(scope="session")
32-
def model_path_nequip(tmp_path_factory: pytest.TempPathFactory) -> Path:
33-
tmp_path = tmp_path_factory.mktemp("nequip_checkpoints")
34-
model_name = "Si.nequip.pth"
35-
model_path = Path(tmp_path) / model_name
36-
37-
if not model_path.is_file():
38-
urllib.request.urlretrieve(NequIPUrls.Si, model_path) # noqa: S310
29+
def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path:
30+
"""Compile NequIP OAM-L model from nequip.net."""
31+
tmp_path = tmp_path_factory.mktemp("nequip_compiled")
32+
output_model_name = "mir-group__NequIP-OAM-L__0.1.nequip.pt2"
33+
output_path = Path(tmp_path) / output_model_name
34+
35+
main(
36+
args=[
37+
"nequip.net:mir-group/NequIP-OAM-L:0.1",
38+
str(output_path),
39+
"--mode",
40+
"aotinductor",
41+
"--device",
42+
DEVICE.type,
43+
"--target",
44+
"ase",
45+
]
46+
)
3947

40-
return model_path
48+
return output_path
4149

4250

43-
@pytest.fixture
44-
def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel:
51+
@pytest.fixture(scope="session")
52+
def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel:
4553
"""Create an NequIPModel wrapper for the pretrained model."""
4654
compiled_model, (r_max, type_names) = from_compiled_model(
47-
model_path_nequip, device=DEVICE
55+
compiled_nequip_model_path, device=DEVICE
4856
)
4957
return NequIPFrameworkModel(
5058
model=compiled_model,
@@ -54,16 +62,18 @@ def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel:
5462
)
5563

5664

57-
@pytest.fixture
58-
def nequip_calculator(model_path_nequip: Path) -> NequIPCalculator:
65+
@pytest.fixture(scope="session")
66+
def nequip_calculator(compiled_nequip_model_path: Path) -> NequIPCalculator:
5967
"""Create an NequIPCalculator for the pretrained model."""
60-
return NequIPCalculator.from_compiled_model(str(model_path_nequip), device=DEVICE)
68+
return NequIPCalculator.from_compiled_model(
69+
str(compiled_nequip_model_path), device=DEVICE
70+
)
6171

6272

63-
def test_nequip_initialization(model_path_nequip: Path) -> None:
73+
def test_nequip_initialization(compiled_nequip_model_path: Path) -> None:
6474
"""Test that the NequIP model initializes correctly."""
6575
compiled_model, (r_max, type_names) = from_compiled_model(
66-
model_path_nequip, device=DEVICE
76+
compiled_nequip_model_path, device=DEVICE
6777
)
6878
model = NequIPFrameworkModel(
6979
model=compiled_model,
@@ -78,7 +88,14 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
7888
test_name="nequip",
7989
model_fixture_name="nequip_model",
8090
calculator_fixture_name="nequip_calculator",
81-
sim_state_names=("si_sim_state", "rattled_si_sim_state"),
91+
sim_state_names=consistency_test_simstate_fixtures,
92+
energy_atol=5e-5,
93+
dtype=DTYPE,
94+
device=DEVICE,
8295
)
8396

84-
# TODO (AG): Test multi element models
97+
test_nequip_model_outputs = make_validate_model_outputs_test(
98+
model_fixture_name="nequip_model",
99+
dtype=DTYPE,
100+
device=DEVICE,
101+
)

0 commit comments

Comments
 (0)