Skip to content

Commit 8aa22d0

Browse files
committed
fix: use DEVICE.type in compile args
1 parent 8805ee0 commit 8aa22d0

File tree

1 file changed

+11
-9
lines changed

1 file changed

+11
-9
lines changed

tests/models/test_nequip_framework.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path
3939
"--mode",
4040
"aotinductor",
4141
"--device",
42-
"cuda",
42+
DEVICE.type,
4343
"--target",
4444
"ase",
4545
]
@@ -48,7 +48,7 @@ def compiled_nequip_model_path(tmp_path_factory: pytest.TempPathFactory) -> Path
4848
return output_path
4949

5050

51-
@pytest.fixture
51+
@pytest.fixture(scope="session")
5252
def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel:
5353
"""Create an NequIPModel wrapper for the pretrained model."""
5454
compiled_model, (r_max, type_names) = from_compiled_model(
@@ -62,16 +62,18 @@ def nequip_model(compiled_nequip_model_path: Path) -> NequIPFrameworkModel:
6262
)
6363

6464

65-
@pytest.fixture
66-
def nequip_calculator(model_path_nequip: Path) -> NequIPCalculator:
65+
@pytest.fixture(scope="session")
66+
def nequip_calculator(compiled_nequip_model_path: Path) -> NequIPCalculator:
6767
"""Create an NequIPCalculator for the pretrained model."""
68-
return NequIPCalculator.from_compiled_model(str(model_path_nequip), device=DEVICE)
68+
return NequIPCalculator.from_compiled_model(
69+
str(compiled_nequip_model_path), device=DEVICE
70+
)
6971

7072

71-
def test_nequip_initialization(model_path_nequip: Path) -> None:
73+
def test_nequip_initialization(compiled_nequip_model_path: Path) -> None:
7274
"""Test that the NequIP model initializes correctly."""
7375
compiled_model, (r_max, type_names) = from_compiled_model(
74-
model_path_nequip, device=DEVICE
76+
compiled_nequip_model_path, device=DEVICE
7577
)
7678
model = NequIPFrameworkModel(
7779
model=compiled_model,
@@ -82,7 +84,7 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
8284
assert model._device == DEVICE # noqa: SLF001
8385

8486

85-
test_metatomic_consistency = make_model_calculator_consistency_test(
87+
test_nequip_consistency = make_model_calculator_consistency_test(
8688
test_name="nequip",
8789
model_fixture_name="nequip_model",
8890
calculator_fixture_name="nequip_calculator",
@@ -92,7 +94,7 @@ def test_nequip_initialization(model_path_nequip: Path) -> None:
9294
device=DEVICE,
9395
)
9496

95-
test_metatomic_model_outputs = make_validate_model_outputs_test(
97+
test_nequip_model_outputs = make_validate_model_outputs_test(
9698
model_fixture_name="nequip_model",
9799
dtype=DTYPE,
98100
device=DEVICE,

0 commit comments

Comments
 (0)