Skip to content

Commit 85586d6

Browse files
Parametrize on available devices
1 parent 1d3a844 commit 85586d6

File tree

2 files changed

+23
-19
lines changed

2 files changed

+23
-19
lines changed

tests/ignite/conftest.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,16 @@
1313
import ignite.distributed as idist
1414

1515

16+
@pytest.fixture(
17+
params=[
18+
"cpu",
19+
pytest.param("cuda", marks=pytest.mark.skipif(not torch.cuda.is_available(), reason="Skip if no CUDA support")),
20+
]
21+
)
22+
def available_device(request):
23+
return request.param
24+
25+
1626
@pytest.fixture()
1727
def dirname():
1828
path = Path(tempfile.mkdtemp())

tests/ignite/metrics/test_psnr.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,6 @@
99
from ignite.metrics import PSNR
1010
from ignite.utils import manual_seed
1111

12-
from tests.ignite import cpu_and_maybe_cuda
13-
1412

1513
def test_zero_div():
1614
psnr = PSNR(1.0)
@@ -30,35 +28,31 @@ def test_invalid_psnr():
3028
psnr.update((y_pred, y.squeeze(dim=0)))
3129

3230

33-
@pytest.fixture
34-
def test_data(request):
31+
@pytest.fixture(params=["float", "YCbCr", "uint8", "NHW shape"])
32+
def test_data(request, available_device):
3533
manual_seed(42)
36-
device = request.param.device
37-
if request.param.sample_type == "float":
38-
y_pred = torch.rand(8, 3, 28, 28, device=device)
34+
if request.param == "float":
35+
y_pred = torch.rand(8, 3, 28, 28, device=available_device)
3936
y = y_pred * 0.8
40-
elif request.param.sample_type == "YCbCr":
41-
y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
42-
y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=device)
43-
elif request.param.sample_type == "uint8":
44-
y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=device)
37+
elif request.param == "YCbCr":
38+
y_pred = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
39+
y = torch.randint(16, 236, (4, 1, 12, 12), dtype=torch.uint8, device=available_device)
40+
elif request.param == "uint8":
41+
y_pred = torch.randint(0, 256, (4, 3, 16, 16), dtype=torch.uint8, device=available_device)
4542
y = (y_pred * 0.8).to(torch.uint8)
46-
elif request.param.sample_type == "NHW shape":
47-
y_pred = torch.rand(8, 28, 28, device=device)
43+
elif request.param == "NHW shape":
44+
y_pred = torch.rand(8, 28, 28, device=available_device)
4845
y = y_pred * 0.8
4946
else:
5047
raise ValueError(f"Wrong fixture parameter, given {request.param}")
5148
return (y_pred, y)
5249

5350

54-
@pytest.mark.parametrize("device", cpu_and_maybe_cuda(), indirect=True)
55-
@pytest.mark.parametrize("sample_type", ["float", "YCbCr", "uint8", "NHW shape"], indirect=True)
56-
def test_psnr(test_data):
51+
def test_psnr(test_data, available_device):
5752
y_pred, y = test_data
58-
device = idist.device()
5953
data_range = (y.max() - y.min()).cpu().item()
6054

61-
psnr = PSNR(data_range=data_range, device=device)
55+
psnr = PSNR(data_range=data_range, device=available_device)
6256
psnr.update(test_data)
6357
psnr_compute = psnr.compute()
6458

0 commit comments

Comments
 (0)