Commit ec6aa33

Merge similar test components with parameterized (#7663)
### Description

I noticed that some test cases contain the same duplicated asserts. Having multiple asserts in one test case can cause potential issues: when the first assert fails, the test case stops and the second assert is never checked. Using `@parameterized.expand` resolves this, and the caching also saves execution time.

Added sign-offs from #7648

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [x] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Han Wang <freddie.wanah@gmail.com>
Co-authored-by: YunLiu <55491388+KumoLiu@users.noreply.github.com>
1 parent a59676f commit ec6aa33

15 files changed: +243 -473 lines changed
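For context, here is a minimal, self-contained sketch (illustrative only, not taken from the MONAI test suite) of the pattern applied throughout the files below: `@parameterized.expand` generates one independent test case per tuple, so a failing case no longer prevents the remaining cases from running.

```python
import unittest

from parameterized import parameterized


class TestSquare(unittest.TestCase):
    # Each tuple is expanded into its own generated test method,
    # so every case is reported even when an earlier one fails.
    @parameterized.expand([(2, 4), (3, 9), (-4, 16)])
    def test_square(self, value, expected):
        self.assertEqual(value * value, expected)


if __name__ == "__main__":
    unittest.main()
```

Each merged test in this commit follows the same shape: the shared setup stays in the method body, while the per-case inputs and expected values move into the decorator's list.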

tests/test_affine_transform.py

Lines changed: 8 additions & 19 deletions
@@ -133,28 +133,17 @@ def test_to_norm_affine_ill(self, affine, src_size, dst_size, align_corners):
 
 class TestAffineTransform(unittest.TestCase):
 
-    def test_affine_shift(self):
-        affine = torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]])
-        image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
-        out = AffineTransform(align_corners=False)(image, affine)
-        out = out.detach().cpu().numpy()
-        expected = [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]
-        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
-
-    def test_affine_shift_1(self):
-        affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]])
-        image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
-        out = AffineTransform(align_corners=False)(image, affine)
-        out = out.detach().cpu().numpy()
-        expected = [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]
-        np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
-
-    def test_affine_shift_2(self):
-        affine = torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]])
+    @parameterized.expand(
+        [
+            (torch.as_tensor([[1.0, 0.0, 0.0], [0.0, 1.0, -1.0]]), [[[[0, 4, 1, 3], [0, 7, 6, 8], [0, 3, 5, 3]]]]),
+            (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]), [[[[0, 0, 0, 0], [0, 4, 1, 3], [0, 7, 6, 8]]]]),
+            (torch.as_tensor([[1.0, 0.0, -1.0], [0.0, 1.0, 0.0]]), [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]),
+        ]
+    )
+    def test_affine_transforms(self, affine, expected):
         image = torch.as_tensor([[[[4.0, 1.0, 3.0, 2.0], [7.0, 6.0, 8.0, 5.0], [3.0, 5.0, 3.0, 6.0]]]])
         out = AffineTransform(align_corners=False)(image, affine)
         out = out.detach().cpu().numpy()
-        expected = [[[[0, 0, 0, 0], [4, 1, 3, 2], [7, 6, 8, 5]]]]
         np.testing.assert_allclose(out, expected, atol=1e-5, rtol=_rtol)
 
     def test_zoom(self):

tests/test_compute_f_beta.py

Lines changed: 16 additions & 20 deletions
@@ -15,6 +15,7 @@
 
 import numpy as np
 import torch
+from parameterized import parameterized
 
 from monai.metrics import FBetaScore
 from tests.utils import assert_allclose
@@ -33,26 +34,21 @@ def test_expecting_success_and_device(self):
         assert_allclose(result, torch.Tensor([0.714286]), atol=1e-6, rtol=1e-6)
         np.testing.assert_equal(result.device, y_pred.device)
 
-    def test_expecting_success2(self):
-        metric = FBetaScore(beta=0.5)
-        metric(
-            y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
-        )
-        assert_allclose(metric.aggregate()[0], torch.Tensor([0.609756]), atol=1e-6, rtol=1e-6)
-
-    def test_expecting_success3(self):
-        metric = FBetaScore(beta=2)
-        metric(
-            y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]])
-        )
-        assert_allclose(metric.aggregate()[0], torch.Tensor([0.862069]), atol=1e-6, rtol=1e-6)
-
-    def test_denominator_is_zero(self):
-        metric = FBetaScore(beta=2)
-        metric(
-            y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]])
-        )
-        assert_allclose(metric.aggregate()[0], torch.Tensor([0.0]), atol=1e-6, rtol=1e-6)
+    @parameterized.expand(
+        [
+            (0.5, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.609756])),  # success_beta_0_5
+            (2, torch.Tensor([[1, 0, 1], [0, 1, 0], [1, 0, 1]]), torch.Tensor([0.862069])),  # success_beta_2
+            (
+                2,  # success_beta_2, denominator_zero
+                torch.Tensor([[0, 0, 0], [0, 0, 0], [0, 0, 0]]),
+                torch.Tensor([0.0]),
+            ),
+        ]
+    )
+    def test_success_and_zero(self, beta, y, expected_score):
+        metric = FBetaScore(beta=beta)
+        metric(y_pred=torch.Tensor([[1, 1, 1], [1, 1, 1], [1, 1, 1]]), y=y)
+        assert_allclose(metric.aggregate()[0], expected_score, atol=1e-6, rtol=1e-6)
 
     def test_number_of_dimensions_less_than_2_should_raise_error(self):
         metric = FBetaScore()

tests/test_global_mutual_information_loss.py

Lines changed: 25 additions & 15 deletions
@@ -15,6 +15,7 @@
 
 import numpy as np
 import torch
+from parameterized import parameterized
 
 from monai import transforms
 from monai.losses.image_dissimilarity import GlobalMutualInformationLoss
@@ -116,24 +117,33 @@ def transformation(translate_params=(0.0, 0.0, 0.0), rotate_params=(0.0, 0.0, 0.
 
 class TestGlobalMutualInformationLossIll(unittest.TestCase):
 
-    def test_ill_shape(self):
+    @parameterized.expand(
+        [
+            (torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float)),  # mismatched_simple_dims
+            (
+                torch.ones((1, 3, 3), dtype=torch.float),
+                torch.ones((1, 3), dtype=torch.float),
+            ),  # mismatched_advanced_dims
+        ]
+    )
+    def test_ill_shape(self, input1, input2):
         loss = GlobalMutualInformationLoss()
-        with self.assertRaisesRegex(ValueError, ""):
-            loss.forward(torch.ones((1, 2), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
-        with self.assertRaisesRegex(ValueError, ""):
-            loss.forward(torch.ones((1, 3, 3), dtype=torch.float), torch.ones((1, 3), dtype=torch.float, device=device))
-
-    def test_ill_opts(self):
+        with self.assertRaises(ValueError):
+            loss.forward(input1, input2)
+
+    @parameterized.expand(
+        [
+            (0, "mean", ValueError, ""),  # num_bins_zero
+            (-1, "mean", ValueError, ""),  # num_bins_negative
+            (64, "unknown", ValueError, ""),  # reduction_unknown
+            (64, None, ValueError, ""),  # reduction_none
+        ]
+    )
+    def test_ill_opts(self, num_bins, reduction, expected_exception, expected_message):
         pred = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
         target = torch.ones((1, 3, 3, 3, 3), dtype=torch.float, device=device)
-        with self.assertRaisesRegex(ValueError, ""):
-            GlobalMutualInformationLoss(num_bins=0)(pred, target)
-        with self.assertRaisesRegex(ValueError, ""):
-            GlobalMutualInformationLoss(num_bins=-1)(pred, target)
-        with self.assertRaisesRegex(ValueError, ""):
-            GlobalMutualInformationLoss(reduction="unknown")(pred, target)
-        with self.assertRaisesRegex(ValueError, ""):
-            GlobalMutualInformationLoss(reduction=None)(pred, target)
+        with self.assertRaisesRegex(expected_exception, expected_message):
+            GlobalMutualInformationLoss(num_bins=num_bins, reduction=reduction)(pred, target)
 
 
 if __name__ == "__main__":

tests/test_hausdorff_loss.py

Lines changed: 6 additions & 16 deletions
@@ -219,17 +219,12 @@ def test_ill_opts(self):
         with self.assertRaisesRegex(ValueError, ""):
             HausdorffDTLoss(reduction=None)(chn_input, chn_target)
 
-    def test_input_warnings(self):
+    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
+    def test_input_warnings(self, include_background, softmax, to_onehot_y):
         chn_input = torch.ones((1, 1, 1, 3))
         chn_target = torch.ones((1, 1, 1, 3))
         with self.assertWarns(Warning):
-            loss = HausdorffDTLoss(include_background=False)
-            loss.forward(chn_input, chn_target)
-        with self.assertWarns(Warning):
-            loss = HausdorffDTLoss(softmax=True)
-            loss.forward(chn_input, chn_target)
-        with self.assertWarns(Warning):
-            loss = HausdorffDTLoss(to_onehot_y=True)
+            loss = HausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
             loss.forward(chn_input, chn_target)
 
 
@@ -256,17 +251,12 @@ def test_ill_opts(self):
         with self.assertRaisesRegex(ValueError, ""):
             LogHausdorffDTLoss(reduction=None)(chn_input, chn_target)
 
-    def test_input_warnings(self):
+    @parameterized.expand([(False, False, False), (False, True, False), (False, False, True)])
+    def test_input_warnings(self, include_background, softmax, to_onehot_y):
         chn_input = torch.ones((1, 1, 1, 3))
         chn_target = torch.ones((1, 1, 1, 3))
         with self.assertWarns(Warning):
-            loss = LogHausdorffDTLoss(include_background=False)
-            loss.forward(chn_input, chn_target)
-        with self.assertWarns(Warning):
-            loss = LogHausdorffDTLoss(softmax=True)
-            loss.forward(chn_input, chn_target)
-        with self.assertWarns(Warning):
-            loss = LogHausdorffDTLoss(to_onehot_y=True)
+            loss = LogHausdorffDTLoss(include_background=include_background, softmax=softmax, to_onehot_y=to_onehot_y)
             loss.forward(chn_input, chn_target)
 
 

tests/test_median_filter.py

Lines changed: 7 additions & 14 deletions
@@ -15,27 +15,20 @@
 
 import numpy as np
 import torch
+from parameterized import parameterized
 
 from monai.networks.layers import MedianFilter
 
 
 class MedianFilterTestCase(unittest.TestCase):
+    @parameterized.expand([(torch.ones(1, 1, 2, 3, 5), [1, 2, 4]), (torch.ones(1, 1, 4, 3, 4), 1)])  # 3d_big  # 3d
+    def test_3d(self, input_tensor, radius):
+        filter = MedianFilter(radius).to(torch.device("cpu:0"))
 
-    def test_3d_big(self):
-        a = torch.ones(1, 1, 2, 3, 5)
-        g = MedianFilter([1, 2, 4]).to(torch.device("cpu:0"))
+        expected = input_tensor.numpy()
+        output = filter(input_tensor).cpu().numpy()
 
-        expected = a.numpy()
-        out = g(a).cpu().numpy()
-        np.testing.assert_allclose(out, expected, rtol=1e-5)
-
-    def test_3d(self):
-        a = torch.ones(1, 1, 4, 3, 4)
-        g = MedianFilter(1).to(torch.device("cpu:0"))
-
-        expected = a.numpy()
-        out = g(a).cpu().numpy()
-        np.testing.assert_allclose(out, expected, rtol=1e-5)
+        np.testing.assert_allclose(output, expected, rtol=1e-5)
 
     def test_3d_radii(self):
         a = torch.ones(1, 1, 4, 3, 2)

tests/test_multi_scale.py

Lines changed: 18 additions & 11 deletions
@@ -58,17 +58,24 @@ def test_shape(self, input_param, input_data, expected_val):
         result = MultiScaleLoss(**input_param).forward(**input_data)
         np.testing.assert_allclose(result.detach().cpu().numpy(), expected_val, rtol=1e-5)
 
-    def test_ill_opts(self):
-        with self.assertRaisesRegex(ValueError, ""):
-            MultiScaleLoss(loss=dice_loss, kernel="none")
-        with self.assertRaisesRegex(ValueError, ""):
-            MultiScaleLoss(loss=dice_loss, scales=[-1])(
-                torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
-            )
-        with self.assertRaisesRegex(ValueError, ""):
-            MultiScaleLoss(loss=dice_loss, scales=[-1], reduction="none")(
-                torch.ones((1, 1, 3), device=device), torch.ones((1, 1, 3), device=device)
-            )
+    @parameterized.expand(
+        [
+            ({"loss": dice_loss, "kernel": "none"}, None, None),  # kernel_none
+            ({"loss": dice_loss, "scales": [-1]}, torch.ones((1, 1, 3)), torch.ones((1, 1, 3))),  # scales_negative
+            (
+                {"loss": dice_loss, "scales": [-1], "reduction": "none"},
+                torch.ones((1, 1, 3)),
+                torch.ones((1, 1, 3)),
+            ),  # scales_negative_reduction_none
+        ]
+    )
+    def test_ill_opts(self, kwargs, input, target):
+        if input is None and target is None:
+            with self.assertRaisesRegex(ValueError, ""):
+                MultiScaleLoss(**kwargs)
+        else:
+            with self.assertRaisesRegex(ValueError, ""):
+                MultiScaleLoss(**kwargs)(input, target)
 
     def test_script(self):
         input_param, input_data, expected_val = TEST_CASES[0]

tests/test_optional_import.py

Lines changed: 8 additions & 19 deletions
@@ -13,22 +13,20 @@
 
 import unittest
 
+from parameterized import parameterized
+
 from monai.utils import OptionalImportError, exact_version, optional_import
 
 
 class TestOptionalImport(unittest.TestCase):
 
-    def test_default(self):
-        my_module, flag = optional_import("not_a_module")
+    @parameterized.expand(["not_a_module", "torch.randint"])
+    def test_default(self, import_module):
+        my_module, flag = optional_import(import_module)
         self.assertFalse(flag)
         with self.assertRaises(OptionalImportError):
             my_module.test
 
-        my_module, flag = optional_import("torch.randint")
-        with self.assertRaises(OptionalImportError):
-            self.assertFalse(flag)
-            print(my_module.test)
-
     def test_import_valid(self):
         my_module, flag = optional_import("torch")
         self.assertTrue(flag)
@@ -47,18 +45,9 @@ def test_import_wrong_number(self):
         self.assertTrue(flag)
         print(my_module.randint(1, 2, (1, 2)))
 
-    def test_import_good_number(self):
-        my_module, flag = optional_import("torch", "0")
-        my_module.nn
-        self.assertTrue(flag)
-        print(my_module.randint(1, 2, (1, 2)))
-
-        my_module, flag = optional_import("torch", "0.0.0.1")
-        my_module.nn
-        self.assertTrue(flag)
-        print(my_module.randint(1, 2, (1, 2)))
-
-        my_module, flag = optional_import("torch", "1.1.0")
+    @parameterized.expand(["0", "0.0.0.1", "1.1.0"])
+    def test_import_good_number(self, version_number):
+        my_module, flag = optional_import("torch", version_number)
         my_module.nn
         self.assertTrue(flag)
         print(my_module.randint(1, 2, (1, 2)))

tests/test_perceptual_loss.py

Lines changed: 3 additions & 5 deletions
@@ -113,12 +113,10 @@ def test_1d(self):
         with self.assertRaises(NotImplementedError):
             PerceptualLoss(spatial_dims=1)
 
-    def test_medicalnet_on_2d_data(self):
+    @parameterized.expand(["medicalnet_resnet10_23datasets", "medicalnet_resnet50_23datasets"])
+    def test_medicalnet_on_2d_data(self, network_type):
         with self.assertRaises(ValueError):
-            PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet10_23datasets")
-
-        with self.assertRaises(ValueError):
-            PerceptualLoss(spatial_dims=2, network_type="medicalnet_resnet50_23datasets")
+            PerceptualLoss(spatial_dims=2, network_type=network_type)
 
 
 if __name__ == "__main__":
