Skip to content

Commit f5cca26

Browse files
committed
[DLMED] add UNet example with dict based transforms
1 parent a65b182 commit f5cca26

File tree

7 files changed

+355
-8
lines changed

7 files changed

+355
-8
lines changed

examples/unet_segmentation_3d_dict.py

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,179 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import sys
14+
import tempfile
15+
from glob import glob
16+
import logging
17+
18+
import nibabel as nib
19+
import numpy as np
20+
import torch
21+
from torch.utils.tensorboard import SummaryWriter
22+
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
23+
from ignite.handlers import ModelCheckpoint, EarlyStopping
24+
from torch.utils.data import DataLoader
25+
26+
# assumes the framework is found here, change as necessary
27+
sys.path.append("..")
28+
29+
import monai
30+
import monai.transforms.compose as transforms
31+
from monai.utils.constants import DataElementKey as Dek
32+
from monai.data.nifti_reader import NiftiDatasetd
33+
from monai.transforms.composables import AddChanneld, RandRotate90d
34+
from monai.handlers.stats_handler import StatsHandler
35+
from monai.handlers.mean_dice import MeanDice
36+
from monai.visualize import img2tensorboard
37+
from monai.data.synthetic import create_test_image_3d
38+
from monai.handlers.utils import stopping_fn_from_metric
39+
40+
monai.config.print_config()
41+
42+
# Create a temporary directory and 50 random image, mask paris
43+
tempdir = tempfile.mkdtemp()
44+
45+
for i in range(50):
46+
im, seg = create_test_image_3d(128, 128, 128)
47+
48+
n = nib.Nifti1Image(im, np.eye(4))
49+
nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))
50+
51+
n = nib.Nifti1Image(seg, np.eye(4))
52+
nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))
53+
54+
images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
55+
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))
56+
57+
# Define transforms for image and segmentation
58+
transforms = transforms.Compose([
59+
AddChanneld(keys=[Dek.IMAGE, Dek.LABEL]),
60+
RandRotate90d(keys=[Dek.IMAGE, Dek.LABEL], prob=0.8, axes=[1, 3])
61+
])
62+
63+
# Define nifti dataset, dataloader.
64+
ds = NiftiDatasetd(images, segs, transform=transforms)
65+
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
66+
check_data = monai.utils.misc.first(loader)
67+
print(check_data[Dek.IMAGE].shape, check_data[Dek.LABEL].shape)
68+
69+
lr = 1e-5
70+
71+
# Create UNet, DiceLoss and Adam optimizer.
72+
net = monai.networks.nets.UNet(
73+
dimensions=3,
74+
in_channels=1,
75+
num_classes=1,
76+
channels=(16, 32, 64, 128, 256),
77+
strides=(2, 2, 2, 2),
78+
num_res_units=2,
79+
)
80+
81+
loss = monai.losses.DiceLoss(do_sigmoid=True)
82+
opt = torch.optim.Adam(net.parameters(), lr)
83+
84+
# Since network outputs logits and segmentation, we need a custom function.
85+
def _loss_fn(i, j):
86+
return loss(i[0], j)
87+
88+
# Create trainer
89+
def prepare_batch(batch, device=None, non_blocking=False):
90+
return _prepare_batch((batch[Dek.IMAGE], batch[Dek.LABEL]), device, non_blocking)
91+
92+
device = torch.device("cuda:0")
93+
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
94+
prepare_batch=prepare_batch,
95+
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
96+
97+
# adding checkpoint handler to save models (network params and optimizer stats) during training
98+
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
99+
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
100+
handler=checkpoint_handler,
101+
to_save={'net': net, 'opt': opt})
102+
train_stats_handler = StatsHandler()
103+
train_stats_handler.attach(trainer)
104+
105+
@trainer.on(Events.EPOCH_COMPLETED)
106+
def log_training_loss(engine):
107+
# log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform
108+
writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)
109+
110+
# tensor of ones to use where for converting labels to zero and ones
111+
ones = torch.ones(engine.state.batch[Dek.LABEL][0].shape, dtype=torch.int32)
112+
first_output_tensor = engine.state.output[0][1][0].detach().cpu()
113+
# log model output to tensorboard, as three dimensional tensor with no channels dimension
114+
img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
115+
255, engine.state.epoch)
116+
# get label tensor and convert to single class
117+
first_label_tensor = torch.where(engine.state.batch[Dek.LABEL][0] > 0, ones, engine.state.batch[Dek.LABEL][0])
118+
# log label tensor to tensorboard, there is a channel dimension when getting label from batch
119+
img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64,
120+
255, engine.state.epoch)
121+
second_output_tensor = engine.state.output[0][1][1].detach().cpu()
122+
img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64,
123+
255, engine.state.epoch)
124+
second_label_tensor = torch.where(engine.state.batch[Dek.LABEL][1] > 0, ones, engine.state.batch[Dek.LABEL][1])
125+
img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64,
126+
255, engine.state.epoch)
127+
third_output_tensor = engine.state.output[0][1][2].detach().cpu()
128+
img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
129+
255, engine.state.epoch)
130+
third_label_tensor = torch.where(engine.state.batch[Dek.LABEL][2] > 0, ones, engine.state.batch[Dek.LABEL][2])
131+
img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
132+
255, engine.state.epoch)
133+
engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])
134+
135+
writer = SummaryWriter()
136+
137+
# Set parameters for validation
138+
validation_every_n_epochs = 1
139+
metric_name = 'Mean_Dice'
140+
141+
# add evaluation metric to the evaluator engine
142+
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
143+
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
144+
prepare_batch=prepare_batch,
145+
output_transform=lambda x, y, y_pred: (y_pred[0], y))
146+
147+
# Add stats event handler to print validation stats via evaluator
148+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
149+
val_stats_handler = StatsHandler()
150+
val_stats_handler.attach(evaluator)
151+
152+
# Add early stopping handler to evaluator.
153+
early_stopper = EarlyStopping(patience=4,
154+
score_function=stopping_fn_from_metric(metric_name),
155+
trainer=trainer)
156+
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
157+
158+
# create a validation data loader
159+
val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms)
160+
val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
161+
162+
163+
@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
164+
def run_validation(engine):
165+
evaluator.run(val_loader)
166+
167+
@evaluator.on(Events.EPOCH_COMPLETED)
168+
def log_metrics_to_tensorboard(engine):
169+
for _, value in engine.state.metrics.items():
170+
writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch)
171+
172+
# create a training data loader
173+
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
174+
175+
train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms)
176+
train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
177+
178+
train_epochs = 30
179+
state = trainer.run(train_loader, train_epochs)

monai/data/nifti_reader.py

Lines changed: 75 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,8 @@
1414

1515
from torch.utils.data import Dataset
1616
from torch.utils.data._utils.collate import np_str_obj_array_pattern
17-
17+
from monai.utils.constants import DataElementKey as Dek
18+
from monai.utils.constants import ImageProperty as Prop
1819
from monai.utils.module import export
1920

2021

@@ -40,14 +41,14 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
4041
img = nib.load(filename_or_obj)
4142

4243
header = dict(img.header)
43-
header['filename_or_obj'] = filename_or_obj
44-
header['original_affine'] = img.affine
45-
header['affine'] = img.affine
46-
header['as_closest_canonical'] = as_closest_canonical
44+
header[Prop.FILENAME_OR_OBJ] = filename_or_obj
45+
header[Prop.ORIGINAL_AFFINE] = img.affine
46+
header[Prop.AFFINE] = img.affine
47+
header[Prop.AS_CLOSEST_CANONICAL] = as_closest_canonical
4748

4849
if as_closest_canonical:
4950
img = nib.as_closest_canonical(img)
50-
header['affine'] = img.affine
51+
header[Prop.AFFINE] = img.affine
5152

5253
if dtype is not None:
5354
dat = img.get_fdata(dtype=dtype)
@@ -131,3 +132,71 @@ def __getitem__(self, index):
131132
continue
132133
compatible_meta[meta_key] = meta_datum
133134
return img, seg, compatible_meta
135+
136+
137+
@export("monai.data")
138+
class NiftiDatasetd(Dataset):
139+
"""
140+
Loads image/segmentation pairs of Nifti files from the given filename lists. Dict level transformations can be
141+
specified for the dictionary data which is constructed by image, label and other metadata.
142+
"""
143+
144+
def __init__(self, image_files, seg_files, as_closest_canonical=False, transform=None,
145+
image_only=True, dtype=None):
146+
"""
147+
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
148+
to the images and `seg_transform` to the segmentations.
149+
150+
Args:
151+
image_files (list of str): list of image filenames.
152+
seg_files (list of str): list of segmentation filenames.
153+
as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
154+
transform (Callable, optional): dict transforms to excute operations on dictionary data.
155+
image_only (bool): if True return only the image volume, other return image volume and header dict.
156+
dtype (np.dtype, optional): if not None convert the loaded image to this data type.
157+
"""
158+
159+
if len(image_files) != len(seg_files):
160+
raise ValueError('Must have same number of image and segmentation files')
161+
162+
self.image_files = image_files
163+
self.seg_files = seg_files
164+
self.as_closest_canonical = as_closest_canonical
165+
self.transform = transform
166+
self.image_only = image_only
167+
self.dtype = dtype
168+
169+
def __len__(self):
170+
return len(self.image_files)
171+
172+
def __getitem__(self, index):
173+
meta_data = None
174+
if self.image_only:
175+
img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
176+
image_only=self.image_only, dtype=self.dtype)
177+
else:
178+
img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
179+
image_only=self.image_only, dtype=self.dtype)
180+
seg = load_nifti(self.seg_files[index])
181+
182+
compatible_meta = {}
183+
if meta_data is not None:
184+
assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
185+
for meta_key in meta_data:
186+
meta_datum = meta_data[meta_key]
187+
if type(meta_datum).__name__ == 'ndarray' \
188+
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
189+
continue
190+
compatible_meta[meta_key] = meta_datum
191+
192+
data = {
193+
Dek.IMAGE: img,
194+
Dek.LABEL: seg
195+
}
196+
if len(compatible_meta) > 0:
197+
data.update(compatible_meta)
198+
199+
if self.transform is not None:
200+
data = self.transform(data)
201+
202+
return data

monai/transforms/composables.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import monai
1919
from monai.transforms.compose import Randomizable, Transform
20-
from monai.transforms.transforms import Rotate90
20+
from monai.transforms.transforms import Rotate90, AddChannel
2121
from monai.utils.misc import ensure_tuple
2222

2323
export = monai.utils.export("monai.transforms")
@@ -120,6 +120,28 @@ def __call__(self, data):
120120
return d
121121

122122

123+
@export
124+
class AddChanneld(MapTransform):
125+
"""
126+
dictionary-based wrapper of AddChannel.
127+
"""
128+
129+
def __init__(self, keys):
130+
"""
131+
Args:
132+
keys (hashable items): keys of the corresponding items to be transformed.
133+
See also: monai.transform.composables.MapTransform
134+
"""
135+
MapTransform.__init__(self, keys)
136+
self.adder = AddChannel()
137+
138+
def __call__(self, data):
139+
d = dict(data)
140+
for key in self.keys:
141+
d[key] = self.adder(d[key])
142+
return d
143+
144+
123145
# if __name__ == "__main__":
124146
# import numpy as np
125147
# data = {

monai/transforms/transforms.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,7 @@ def __init__(self, k=1, axes=(1, 2)):
167167
self.plane_axes = axes
168168

169169
def __call__(self, img):
170-
return np.rot90(img, self.k, self.plane_axes)
170+
return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes))
171171

172172

173173
@export

monai/utils/constants.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
class ActivationFunc:
13+
"""Commonly used activation function names.
14+
"""
15+
16+
SOFTMAX = "softmax"
17+
LOG_SOFTMAX = "log_softmax"
18+
SIGMOID = "sigmoid"
19+
LINEAR = "linear"
20+
TANH = "tanh"
21+
22+
23+
class DataElementKey:
24+
"""Data Element keys
25+
"""
26+
27+
IMAGE = "image"
28+
LABEL = "label"
29+
30+
31+
class ImageProperty:
32+
"""Key names for image properties.
33+
"""
34+
35+
FILENAME_OR_OBJ = 'filename_or_obj'
36+
AFFINE = 'affine' # image affine matrix
37+
ORIGINAL_AFFINE = 'original_affine' # original affine matrix before transformation
38+
SPACING = 'spacing' # itk naming convention for pixel/voxel size
39+
AS_CLOSEST_CANONICAL = 'as_closest_canonical' # load the image as closest to canonical axis format
40+
BACKGROUND_INDEX = 'background_index' # which index is background

0 commit comments

Comments
 (0)