Skip to content

Commit 448b61a

Browse files
authored
17 unit tests for each mvp component and workflow (#83)
* unit test -- mean_dice handler * unit test -- stats handler * integration test -- sliding window
1 parent 5e12c5d commit 448b61a

File tree

7 files changed

+236
-11
lines changed

7 files changed

+236
-11
lines changed

monai/handlers/mean_dice.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,29 @@ class MeanDice(Metric):
2727
def __init__(
2828
self,
2929
include_background=True,
30-
to_onehot_y=False,
31-
logit_thresh=0.5,
30+
to_onehot_y=True,
31+
logit_thresh=None,
3232
add_sigmoid=False,
33-
mutually_exclusive=False,
33+
mutually_exclusive=True,
3434
output_transform: Callable = lambda x: x,
3535
device: Optional[Union[str, torch.device]] = None
3636
):
3737
"""
3838
3939
Args:
4040
include_background (Bool): whether to include dice computation on the first channel of the predicted output.
41-
to_onehot_y (Bool): whether to convert the output prediction into the one-hot format.
42-
logit_thresh (Float): the threshold value to round value to 0.0 and 1.0, default is 0.5.
41+
Defaults to True.
42+
to_onehot_y (Bool): whether to convert the output prediction into the one-hot format. Defaults to True.
43+
logit_thresh (Float): the threshold value to round value to 0.0 and 1.0. Defaults to None (no thresholding).
4344
add_sigmoid (Bool): whether to add sigmoid function to the output prediction before computing Dice.
45+
Defaults to False.
4446
mutually_exclusive (Bool): if True, the output prediction will be converted into a binary matrix using
45-
a combination of argmax and to_onehot.
47+
a combination of argmax and to_onehot. Defaults to True.
4648
output_transform (Callable): transform the ignite.engine.state.output into [y_pred, y] pair.
4749
device (torch.device): device specification in case of distributed computation usage.
50+
51+
See also:
52+
monai.metrics.compute_meandice.compute_meandice
4853
"""
4954
super(MeanDice, self).__init__(output_transform, device=device)
5055
self.include_background = include_background

monai/handlers/segmentation_saver.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class SegmentationSaver:
2323
"""
2424

2525
def __init__(self, output_path='./', dtype='float32', output_postfix='seg', output_ext='.nii.gz',
26-
output_transform=lambda x: x):
26+
output_transform=lambda x: x, name=None):
2727
"""
2828
Args:
2929
output_path (str): output image directory.
@@ -34,14 +34,19 @@ def __init__(self, output_path='./', dtype='float32', output_postfix='seg', outp
3434
ignite.engine.output into the form expected nifti image data.
3535
The first dimension of this transform's output will be treated as the
3636
batch dimension. Each item in the batch will be saved individually.
37+
name (str): identifier of logging.logger to use, defaulting to `engine.logger`.
3738
"""
3839
self.output_path = output_path
3940
self.dtype = dtype
4041
self.output_postfix = output_postfix
4142
self.output_ext = output_ext
4243
self.output_transform = output_transform
4344

45+
self.logger = None if name is None else logging.getLogger(name)
46+
4447
def attach(self, engine):
48+
if self.logger is None:
49+
self.logger = engine.logger
4550
return engine.add_event_handler(Events.ITERATION_COMPLETED, self)
4651

4752
@staticmethod
@@ -103,4 +108,4 @@ def __call__(self, engine):
103108
output_filename = self._create_file_basename(self.output_postfix, filename, self.output_path)
104109
output_filename = '{}{}'.format(output_filename, self.output_ext)
105110
write_nifti(seg_output, _affine, output_filename, _original_affine, dtype=seg_output.dtype)
106-
print('saved: {}'.format(output_filename))
111+
self.logger.info('saved: {}'.format(output_filename))

monai/metrics/compute_meandice.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,8 @@ def compute_meandice(y_pred,
2727
y_pred (torch.Tensor): input data to compute, typical segmentation model output.
2828
it must be One-Hot format and first dim is batch, example shape: [16, 3, 32, 32].
2929
y (torch.Tensor): ground truth to compute mean dice metric, the first dim is batch.
30+
example shape: [16, 3, 32, 32] for 3-class one-hot labels.
31+
alternative shape: [16, 1, 32, 32] and set `to_onehot_y=True` to convert it into [16, 3, 32, 32].
3032
include_background (Bool): whether to skip dice computation on the first channel of the predicted output.
3133
to_onehot_y (Bool): whether to convert `y` into the one-hot format.
3234
mutually_exclusive (Bool): if True, `y_pred` will be converted into a binary matrix using
@@ -44,8 +46,8 @@ def compute_meandice(y_pred,
4446
n_channels_y_pred = y_pred.shape[1]
4547

4648
if mutually_exclusive:
47-
if logit_thresh is not None:
48-
raise ValueError('`logit_thresh` is incompatible when mutually_exclusive is True.')
49+
if logit_thresh is not None or add_sigmoid:
50+
raise ValueError('`logit_thresh` and `add_sigmoid` are incompatible when mutually_exclusive is True.')
4951
y_pred = torch.argmax(y_pred, dim=1, keepdim=True)
5052
y_pred = one_hot(y_pred, n_channels_y_pred)
5153
else: # channel-wise thresholding
@@ -61,6 +63,9 @@ def compute_meandice(y_pred,
6163
y = y[:, 1:] if y.shape[1] > 1 else y
6264
y_pred = y_pred[:, 1:] if y_pred.shape[1] > 1 else y_pred
6365

66+
assert y.shape == y_pred.shape, ("Ground truth one-hot has differing shape (%r) from source (%r)" %
67+
(y.shape, y_pred.shape))
68+
6469
# reducing only spatial dimensions (not batch nor channels)
6570
reduce_axis = list(range(2, y_pred.dim()))
6671
intersection = torch.sum(y * y_pred, reduce_axis)

tests/integration_sliding_window.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import sys
14+
import tempfile
15+
16+
import nibabel as nib
17+
import torch
18+
from ignite.engine import Engine
19+
from torch.utils.data import DataLoader
20+
21+
from monai.data.nifti_reader import NiftiDataset
22+
from monai.data.synthetic import create_test_image_3d
23+
from monai.handlers.segmentation_saver import SegmentationSaver
24+
from monai.networks.nets.unet import UNet
25+
from monai.networks.utils import predict_segmentation
26+
from monai.transforms.transforms import AddChannel
27+
from monai.utils.sliding_window_inference import sliding_window_inference
28+
from tests.utils import make_nifti_image
29+
30+
31+
def run_test(batch_size=2, device=torch.device("cpu:0")):
32+
33+
im, seg = create_test_image_3d(25, 28, 63, rad_max=10, noise_max=1, num_objs=4, num_seg_classes=1)
34+
input_shape = im.shape
35+
img_name = make_nifti_image(im)
36+
seg_name = make_nifti_image(seg)
37+
ds = NiftiDataset([img_name], [seg_name], transform=AddChannel(), seg_transform=AddChannel(), image_only=False)
38+
loader = DataLoader(ds, batch_size=1, pin_memory=torch.cuda.is_available())
39+
40+
net = UNet(
41+
dimensions=3,
42+
in_channels=1,
43+
num_classes=1,
44+
channels=(4, 8, 16, 32),
45+
strides=(2, 2, 2),
46+
num_res_units=2,
47+
)
48+
roi_size = (16, 32, 48)
49+
sw_batch_size = batch_size
50+
51+
def _sliding_window_processor(_engine, batch):
52+
net.eval()
53+
img, seg, meta_data = batch
54+
with torch.no_grad():
55+
seg_probs = sliding_window_inference(img, roi_size, sw_batch_size, lambda x: net(x)[0], device)
56+
return predict_segmentation(seg_probs)
57+
58+
infer_engine = Engine(_sliding_window_processor)
59+
60+
with tempfile.TemporaryDirectory() as temp_dir:
61+
SegmentationSaver(output_path=temp_dir, output_ext='.nii.gz', output_postfix='seg').attach(infer_engine)
62+
63+
infer_engine.run(loader)
64+
65+
basename = os.path.basename(img_name)[:-len('.nii.gz')]
66+
saved_name = os.path.join(temp_dir, basename, '{}_seg.nii.gz'.format(basename))
67+
testing_shape = nib.load(saved_name).get_fdata().shape
68+
69+
if os.path.exists(img_name):
70+
os.remove(img_name)
71+
if os.path.exists(seg_name):
72+
os.remove(seg_name)
73+
74+
return testing_shape == input_shape
75+
76+
77+
if __name__ == "__main__":
78+
result = run_test()
79+
80+
sys.exit(0 if result else 1)

tests/test_handler_mean_dice.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import unittest
13+
14+
import torch
15+
from parameterized import parameterized
16+
17+
from monai.handlers.mean_dice import MeanDice
18+
19+
TEST_CASE_1 = [{'to_onehot_y': True, 'mutually_exclusive': True}, 0.75]
20+
TEST_CASE_2 = [{'include_background': False, 'to_onehot_y': False, 'mutually_exclusive': False}, 0.8333333]
21+
22+
TEST_CASE_3 = [{'mutually_exclusive': True, 'add_sigmoid': True}]
23+
24+
25+
class TestHandlerMeanDice(unittest.TestCase):
26+
# TODO test multi node averaged dice
27+
28+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
29+
def test_compute(self, input_params, expected_avg):
30+
dice_metric = MeanDice(**input_params)
31+
32+
y_pred = torch.Tensor([[0, 1], [1, 0]])
33+
y = torch.ones((2, 1))
34+
dice_metric.update([y_pred, y])
35+
36+
y_pred = torch.Tensor([[0, 1], [1, 0]])
37+
y = torch.Tensor([[1.], [0.]])
38+
dice_metric.update([y_pred, y])
39+
40+
avg_dice = dice_metric.compute()
41+
self.assertAlmostEqual(avg_dice, expected_avg)
42+
43+
@parameterized.expand([TEST_CASE_3])
44+
def test_misconfig(self, input_params):
45+
with self.assertRaisesRegex(ValueError, 'compatib'):
46+
dice_metric = MeanDice(**input_params)
47+
48+
y_pred = torch.Tensor([[0, 1], [1, 0]])
49+
y = torch.ones((2, 1))
50+
dice_metric.update([y_pred, y])
51+
52+
@parameterized.expand([TEST_CASE_1, TEST_CASE_2])
53+
def test_shape_mismatch(self, input_params, _expected):
54+
dice_metric = MeanDice(**input_params)
55+
with self.assertRaises((AssertionError, ValueError)):
56+
y_pred = torch.Tensor([[0, 1], [1, 0]])
57+
y = torch.ones((2, 3))
58+
dice_metric.update([y_pred, y])
59+
60+
with self.assertRaises((AssertionError, ValueError)):
61+
y_pred = torch.Tensor([[0, 1], [1, 0]])
62+
y = torch.ones((3, 2))
63+
dice_metric.update([y_pred, y])
64+
65+
66+
if __name__ == '__main__':
67+
unittest.main()

tests/test_handler_stats.py

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import logging
13+
import re
14+
import unittest
15+
from io import StringIO
16+
17+
from ignite.engine import Engine, Events
18+
19+
from monai.handlers.stats_handler import StatsHandler
20+
21+
22+
class TestHandlerStats(unittest.TestCase):
23+
24+
def test_metrics_print(self):
25+
log_stream = StringIO()
26+
logging.basicConfig(stream=log_stream, level=logging.INFO)
27+
key_to_handler = 'test_logging'
28+
key_to_print = 'testing_metric'
29+
30+
# set up engine
31+
def _train_func(engine, batch):
32+
pass
33+
34+
engine = Engine(_train_func)
35+
36+
# set up dummy metric
37+
@engine.on(Events.ITERATION_COMPLETED)
38+
def _update_metric(engine):
39+
current_metric = engine.state.metrics.get(key_to_print, 0.1)
40+
engine.state.metrics[key_to_print] = current_metric + 0.1
41+
42+
# set up testing handler
43+
stats_handler = StatsHandler(name=key_to_handler)
44+
stats_handler.attach(engine)
45+
46+
engine.run(range(3), max_epochs=2)
47+
48+
# check logging output
49+
output_str = log_stream.getvalue()
50+
grep = re.compile('.*{}.*'.format(key_to_handler))
51+
has_key_word = re.compile('.*{}.*'.format(key_to_print))
52+
matched = []
53+
for idx, line in enumerate(output_str.split('\n')):
54+
if grep.match(line):
55+
self.assertTrue(has_key_word.match(line))
56+
matched.append(idx)
57+
self.assertEqual(matched, [1, 2, 3, 5, 6, 7, 8, 10])
58+
59+
60+
if __name__ == '__main__':
61+
unittest.main()

tests/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,11 +28,13 @@ def skip_if_quick(obj):
2828
return unittest.skipIf(is_quick, "Skipping slow tests")(obj)
2929

3030

31-
def make_nifti_image(array, affine):
31+
def make_nifti_image(array, affine=None):
3232
"""
3333
Create a temporary nifti image on the disk and return the image name.
3434
User is responsible for deleting the temporary file when done with it.
3535
"""
36+
if affine is None:
37+
affine = np.eye(4)
3638
test_image = nib.Nifti1Image(array, affine)
3739

3840
_, image_name = tempfile.mkstemp(suffix='.nii.gz')

0 commit comments

Comments
 (0)