Skip to content

Commit 68c26a8

Browse files
committed
move to tests/ignite/__init__.py
1 parent 5aee3be commit 68c26a8

File tree

4 files changed

+13
-14
lines changed

4 files changed

+13
-14
lines changed

tests/ignite/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,14 @@
33

44
def cpu_and_maybe_cuda():
55
return ("cpu",) + (("cuda",) if torch.cuda.is_available() else ())
6+
7+
8+
def is_mps_available_and_functional():
9+
if not torch.backends.mps.is_available():
10+
return False
11+
try:
12+
# Try to allocate a small tensor on the MPS device
13+
torch.tensor([1.0], device="mps")
14+
return True
15+
except RuntimeError:
16+
return False

tests/ignite/distributed/utils/test_serial.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import ignite.distributed as idist
44
from ignite.distributed.comp_models.base import _torch_version_le_112
5+
from tests.ignite import is_mps_available_and_functional
56
from tests.ignite.distributed.utils import (
67
_sanity_check,
78
_test_distrib__get_max_length,
@@ -12,7 +13,6 @@
1213
_test_distrib_new_group,
1314
_test_sync,
1415
)
15-
from ....utils_for_tests import is_mps_available_and_functional
1616

1717

1818
def test_no_distrib(capsys):

tests/ignite/engine/test_create_supervised.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
)
2626
from ignite.metrics import MeanSquaredError
2727

28-
from ...utils_for_tests import is_mps_available_and_functional # type: ignore
28+
from tests.ignite import is_mps_available_and_functional
2929

3030

3131
class DummyModel(torch.nn.Module):

tests/utils_for_tests.py

Lines changed: 0 additions & 12 deletions
This file was deleted.

0 commit comments

Comments
 (0)