Skip to content

Commit 20d6b5b

Browse files
authored
Skip tests when MPS is not functional (#3249)
1 parent 37d9a67 commit 20d6b5b

File tree

4 files changed

+24
-3
lines changed

4 files changed

+24
-3
lines changed

tests/ignite/__init__.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,3 +3,14 @@
33

44
def cpu_and_maybe_cuda():
    """Return the device strings to test on: always ``"cpu"``, plus ``"cuda"``
    when a CUDA device is available in this environment.
    """
    devices = ["cpu"]
    if torch.cuda.is_available():
        devices.append("cuda")
    return tuple(devices)
6+
7+
8+
def is_mps_available_and_functional():
    """Return ``True`` when the MPS backend is both advertised and usable.

    ``torch.backends.mps.is_available()`` can report ``True`` on setups where
    allocating on the device still fails at runtime, so probe the device with
    a one-element tensor before trusting it.
    """
    if not torch.backends.mps.is_available():
        return False
    try:
        # Smoke test: allocate a single-element tensor on the MPS device.
        torch.tensor([1.0], device="mps")
    except RuntimeError:
        # Backend is advertised but the device cannot actually be used.
        return False
    return True

tests/ignite/distributed/test_auto.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212

1313
import ignite.distributed as idist
1414
from ignite.distributed.auto import auto_dataloader, auto_model, auto_optim, DistributedProxySampler
15+
from tests.ignite import is_mps_available_and_functional
1516

1617

1718
class DummyDS(Dataset):
@@ -179,6 +180,9 @@ def _test_auto_model_optimizer(ws, device):
179180
assert optimizer.backward_passes_per_step == backward_passes_per_step
180181

181182

183+
@pytest.mark.skipif(
184+
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
185+
)
182186
def test_auto_methods_no_dist():
183187
_test_auto_dataloader(1, 1, batch_size=1)
184188
_test_auto_dataloader(1, 1, batch_size=10, num_workers=2)

tests/ignite/distributed/test_launcher.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99

1010
import ignite.distributed as idist
1111
from ignite.distributed.utils import has_hvd_support, has_native_dist_support, has_xla_support
12+
from tests.ignite import is_mps_available_and_functional
1213

1314

1415
def test_parallel_wrong_inputs():
@@ -54,6 +55,9 @@ def execute(cmd, env=None):
5455
return str(process.stdout.read()) + str(process.stderr.read())
5556

5657

58+
@pytest.mark.skipif(
59+
torch.backends.mps.is_available() and not is_mps_available_and_functional(), reason="Skip if MPS not functional"
60+
)
5761
def test_check_idist_parallel_no_dist(exec_filepath):
5862
cmd = [sys.executable, "-u", exec_filepath]
5963
out = execute(cmd)

tests/ignite/engine/test_create_supervised.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525
)
2626
from ignite.metrics import MeanSquaredError
2727

28+
from tests.ignite import is_mps_available_and_functional
29+
2830

2931
class DummyModel(torch.nn.Module):
3032
def __init__(self, output_as_list=False):
@@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
485487
_test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
486488

487489

488-
@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
490+
@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
489491
def test_create_supervised_trainer_on_mps():
490492
model_device = trainer_device = "mps"
491493
_test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
@@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
666668
_test_mocked_supervised_evaluator(evaluator_device="cuda")
667669

668670

669-
@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
671+
@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
670672
def test_create_supervised_evaluator_on_mps():
671673
model_device = evaluator_device = "mps"
672674
_test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
673675
_test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
674676

675677

676-
@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS")
678+
@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS")
677679
def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
678680
_test_create_supervised_evaluator(evaluator_device="mps")
679681
_test_mocked_supervised_evaluator(evaluator_device="mps")

0 commit comments

Comments
 (0)