@@ -39,7 +39,7 @@ def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path
3939 "--mode" ,
4040 "aotinductor" ,
4141 "--device" ,
42- "cuda" ,
42+ DEVICE . type ,
4343 "--target" ,
4444 "ase" ,
4545 ]
@@ -48,7 +48,7 @@ def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path
4848 return output_path
4949
5050
51- @pytest .fixture
51+ @pytest .fixture ( scope = "session" )
5252def nequip_model (compiled_nequip_model_path : Path ) -> NequIPFrameworkModel :
5353 """Create an NequIPModel wrapper for the pretrained model."""
5454 compiled_model , (r_max , type_names ) = from_compiled_model (
@@ -62,16 +62,18 @@ def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel:
6262 )
6363
6464
65- @pytest .fixture
66- def nequip_calculator (model_path_nequip : Path ) -> NequIPCalculator :
65+ @pytest .fixture ( scope = "session" )
66+ def nequip_calculator (compiled_nequip_model_path : Path ) -> NequIPCalculator :
6767 """Create an NequIPCalculator for the pretrained model."""
68- return NequIPCalculator .from_compiled_model (str (model_path_nequip ), device = DEVICE )
68+ return NequIPCalculator .from_compiled_model (
69+ str (compiled_nequip_model_path ), device = DEVICE
70+ )
6971
7072
71- def test_nequip_initialization (model_path_nequip : Path ) -> None :
73+ def test_nequip_initialization (compiled_nequip_model_path : Path ) -> None :
7274 """Test that the NequIP model initializes correctly."""
7375 compiled_model , (r_max , type_names ) = from_compiled_model (
74- model_path_nequip , device = DEVICE
76+ compiled_nequip_model_path , device = DEVICE
7577 )
7678 model = NequIPFrameworkModel (
7779 model = compiled_model ,
@@ -82,7 +84,7 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
8284 assert model ._device == DEVICE # noqa: SLF001
8385
8486
85- test_metatomic_consistency = make_model_calculator_consistency_test (
87+ test_nequip_consistency = make_model_calculator_consistency_test (
8688 test_name = "nequip" ,
8789 model_fixture_name = "nequip_model" ,
8890 calculator_fixture_name = "nequip_calculator" ,
@@ -92,7 +94,7 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
9294 device = DEVICE ,
9395)
9496
95- test_metatomic_model_outputs = make_validate_model_outputs_test (
97+ test_nequip_model_outputs = make_validate_model_outputs_test (
9698 model_fixture_name = "nequip_model" ,
9799 dtype = DTYPE ,
98100 device = DEVICE ,
0 commit comments