|
25 | 25 | )
|
26 | 26 | from ignite.metrics import MeanSquaredError
|
27 | 27 |
|
| 28 | +from tests.ignite import is_mps_available_and_functional |
| 29 | + |
28 | 30 |
|
29 | 31 | class DummyModel(torch.nn.Module):
|
30 | 32 | def __init__(self, output_as_list=False):
|
@@ -485,7 +487,7 @@ def test_create_supervised_trainer_on_cuda():
|
485 | 487 | _test_create_mocked_supervised_trainer(model_device=model_device, trainer_device=trainer_device)
|
486 | 488 |
|
487 | 489 |
|
488 |
| -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") |
| 490 | +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") |
489 | 491 | def test_create_supervised_trainer_on_mps():
|
490 | 492 | model_device = trainer_device = "mps"
|
491 | 493 | _test_create_supervised_trainer_wrong_accumulation(model_device=model_device, trainer_device=trainer_device)
|
@@ -666,14 +668,14 @@ def test_create_supervised_evaluator_on_cuda_with_model_on_cpu():
|
666 | 668 | _test_mocked_supervised_evaluator(evaluator_device="cuda")
|
667 | 669 |
|
668 | 670 |
|
669 |
| -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") |
| 671 | +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") |
670 | 672 | def test_create_supervised_evaluator_on_mps():
|
671 | 673 | model_device = evaluator_device = "mps"
|
672 | 674 | _test_create_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
|
673 | 675 | _test_mocked_supervised_evaluator(model_device=model_device, evaluator_device=evaluator_device)
|
674 | 676 |
|
675 | 677 |
|
676 |
| -@pytest.mark.skipif(not (_torch_version_le_112 and torch.backends.mps.is_available()), reason="Skip if no MPS") |
| 678 | +@pytest.mark.skipif(not (_torch_version_le_112 and is_mps_available_and_functional()), reason="Skip if no MPS") |
677 | 679 | def test_create_supervised_evaluator_on_mps_with_model_on_cpu():
|
678 | 680 | _test_create_supervised_evaluator(evaluator_device="mps")
|
679 | 681 | _test_mocked_supervised_evaluator(evaluator_device="mps")
|
|
0 commit comments