11import traceback
2- import urllib .request
3- from enum import StrEnum
42from pathlib import Path
53
64import pytest
75
8- from tests .conftest import DEVICE
9- from tests .models .conftest import make_model_calculator_consistency_test
6+ from tests .conftest import DEVICE , DTYPE
7+ from tests .models .conftest import (
8+ consistency_test_simstate_fixtures ,
9+ make_model_calculator_consistency_test ,
10+ make_validate_model_outputs_test ,
11+ )
1012
1113
1214try :
1315 from nequip .ase import NequIPCalculator
16+ from nequip .scripts .compile import main
1417
1518 from torch_sim .models .nequip_framework import (
1619 NequIPFrameworkModel ,
2225 )
2326
2427
25- class NequIPUrls (StrEnum ):
26- """Checkpoint download URLs for NequIP models."""
27-
28- Si = "https://github.com/abhijeetgangan/pt_model_checkpoints/raw/refs/heads/main/nequip/Si.nequip.pth"
29-
30-
3128@pytest .fixture (scope = "session" )
32- def model_path_nequip (tmp_path_factory : pytest .TempPathFactory ) -> Path :
33- tmp_path = tmp_path_factory .mktemp ("nequip_checkpoints" )
34- model_name = "Si.nequip.pth"
35- model_path = Path (tmp_path ) / model_name
36-
37- if not model_path .is_file ():
38- urllib .request .urlretrieve (NequIPUrls .Si , model_path ) # noqa: S310
29+ def compiled_nequip_model_path (tmp_path_factory : pytest .TempPathFactory ) -> Path :
30+ """Compile NequIP OAM-L model from nequip.net."""
31+ tmp_path = tmp_path_factory .mktemp ("nequip_compiled" )
32+ output_model_name = "mir-group__NequIP-OAM-L__0.1.nequip.pt2"
33+ output_path = Path (tmp_path ) / output_model_name
34+
35+ main (
36+ args = [
37+ "nequip.net:mir-group/NequIP-OAM-L:0.1" ,
38+ str (output_path ),
39+ "--mode" ,
40+ "aotinductor" ,
41+ "--device" ,
42+ DEVICE .type ,
43+ "--target" ,
44+ "ase" ,
45+ ]
46+ )
3947
40- return model_path
48+ return output_path
4149
4250
43- @pytest .fixture
44- def nequip_model (model_path_nequip : Path ) -> NequIPFrameworkModel :
51+ @pytest .fixture ( scope = "session" )
52+ def nequip_model (compiled_nequip_model_path : Path ) -> NequIPFrameworkModel :
4553 """Create an NequIPModel wrapper for the pretrained model."""
4654 compiled_model , (r_max , type_names ) = from_compiled_model (
47- model_path_nequip , device = DEVICE
55+ compiled_nequip_model_path , device = DEVICE
4856 )
4957 return NequIPFrameworkModel (
5058 model = compiled_model ,
@@ -54,16 +62,18 @@ def nequip_model(model_path_nequip: Path) -> NequIPFrameworkModel:
5462 )
5563
5664
57- @pytest .fixture
58- def nequip_calculator (model_path_nequip : Path ) -> NequIPCalculator :
65+ @pytest .fixture ( scope = "session" )
66+ def nequip_calculator (compiled_nequip_model_path : Path ) -> NequIPCalculator :
5967 """Create an NequIPCalculator for the pretrained model."""
60- return NequIPCalculator .from_compiled_model (str (model_path_nequip ), device = DEVICE )
68+ return NequIPCalculator .from_compiled_model (
69+ str (compiled_nequip_model_path ), device = DEVICE
70+ )
6171
6272
63- def test_nequip_initialization (model_path_nequip : Path ) -> None :
73+ def test_nequip_initialization (compiled_nequip_model_path : Path ) -> None :
6474 """Test that the NequIP model initializes correctly."""
6575 compiled_model , (r_max , type_names ) = from_compiled_model (
66- model_path_nequip , device = DEVICE
76+ compiled_nequip_model_path , device = DEVICE
6777 )
6878 model = NequIPFrameworkModel (
6979 model = compiled_model ,
@@ -78,7 +88,14 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
7888 test_name = "nequip" ,
7989 model_fixture_name = "nequip_model" ,
8090 calculator_fixture_name = "nequip_calculator" ,
81- sim_state_names = ("si_sim_state" , "rattled_si_sim_state" ),
91+ sim_state_names = consistency_test_simstate_fixtures ,
92+ energy_atol = 5e-5 ,
93+ dtype = DTYPE ,
94+ device = DEVICE ,
8295)
8396
84- # TODO (AG): Test multi element models
97+ test_nequip_model_outputs = make_validate_model_outputs_test (
98+ model_fixture_name = "nequip_model" ,
99+ dtype = DTYPE ,
100+ device = DEVICE ,
101+ )
0 commit comments