Skip to content

Commit

Permalink
[DLMED] add UNet example with dict based transforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Nic-Ma committed Mar 4, 2020
1 parent a65b182 commit f5cca26
Show file tree
Hide file tree
Showing 7 changed files with 355 additions and 8 deletions.
File renamed without changes.
179 changes: 179 additions & 0 deletions examples/unet_segmentation_3d_dict.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import sys
import tempfile
from glob import glob
import logging

import nibabel as nib
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader

# assumes the framework is found here, change as necessary
sys.path.append("..")

import monai
import monai.transforms.compose as transforms
from monai.utils.constants import DataElementKey as Dek
from monai.data.nifti_reader import NiftiDatasetd
from monai.transforms.composables import AddChanneld, RandRotate90d
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric

monai.config.print_config()

# Create a temporary directory and 50 random image, mask paris
tempdir = tempfile.mkdtemp()

for i in range(50):
im, seg = create_test_image_3d(128, 128, 128)

n = nib.Nifti1Image(im, np.eye(4))
nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

n = nib.Nifti1Image(seg, np.eye(4))
nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Define transforms for image and segmentation
transforms = transforms.Compose([
AddChanneld(keys=[Dek.IMAGE, Dek.LABEL]),
RandRotate90d(keys=[Dek.IMAGE, Dek.LABEL], prob=0.8, axes=[1, 3])
])

# Define nifti dataset, dataloader.
ds = NiftiDatasetd(images, segs, transform=transforms)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(loader)
print(check_data[Dek.IMAGE].shape, check_data[Dek.LABEL].shape)

lr = 1e-5

# Create UNet, DiceLoss and Adam optimizer.
net = monai.networks.nets.UNet(
dimensions=3,
in_channels=1,
num_classes=1,
channels=(16, 32, 64, 128, 256),
strides=(2, 2, 2, 2),
num_res_units=2,
)

loss = monai.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)

# Since network outputs logits and segmentation, we need a custom function.
def _loss_fn(i, j):
return loss(i[0], j)

# Create trainer
def prepare_batch(batch, device=None, non_blocking=False):
return _prepare_batch((batch[Dek.IMAGE], batch[Dek.LABEL]), device, non_blocking)

device = torch.device("cuda:0")
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
prepare_batch=prepare_batch,
output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
handler=checkpoint_handler,
to_save={'net': net, 'opt': opt})
train_stats_handler = StatsHandler()
train_stats_handler.attach(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
# log loss to tensorboard with second item of engine.state.output, loss.item() from output_transform
writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

# tensor of ones to use where for converting labels to zero and ones
ones = torch.ones(engine.state.batch[Dek.LABEL][0].shape, dtype=torch.int32)
first_output_tensor = engine.state.output[0][1][0].detach().cpu()
# log model output to tensorboard, as three dimensional tensor with no channels dimension
img2tensorboard.add_animated_gif_no_channels(writer, "first_output_final_batch", first_output_tensor, 64,
255, engine.state.epoch)
# get label tensor and convert to single class
first_label_tensor = torch.where(engine.state.batch[Dek.LABEL][0] > 0, ones, engine.state.batch[Dek.LABEL][0])
# log label tensor to tensorboard, there is a channel dimension when getting label from batch
img2tensorboard.add_animated_gif(writer, "first_label_final_batch", first_label_tensor, 64,
255, engine.state.epoch)
second_output_tensor = engine.state.output[0][1][1].detach().cpu()
img2tensorboard.add_animated_gif_no_channels(writer, "second_output_final_batch", second_output_tensor, 64,
255, engine.state.epoch)
second_label_tensor = torch.where(engine.state.batch[Dek.LABEL][1] > 0, ones, engine.state.batch[Dek.LABEL][1])
img2tensorboard.add_animated_gif(writer, "second_label_final_batch", second_label_tensor, 64,
255, engine.state.epoch)
third_output_tensor = engine.state.output[0][1][2].detach().cpu()
img2tensorboard.add_animated_gif_no_channels(writer, "third_output_final_batch", third_output_tensor, 64,
255, engine.state.epoch)
third_label_tensor = torch.where(engine.state.batch[Dek.LABEL][2] > 0, ones, engine.state.batch[Dek.LABEL][2])
img2tensorboard.add_animated_gif(writer, "third_label_final_batch", third_label_tensor, 64,
255, engine.state.epoch)
engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])

writer = SummaryWriter()

# Set parameters for validation
validation_every_n_epochs = 1
metric_name = 'Mean_Dice'

# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
prepare_batch=prepare_batch,
output_transform=lambda x, y, y_pred: (y_pred[0], y))

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
score_function=stopping_fn_from_metric(metric_name),
trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)

# create a validation data loader
val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=transforms)
val_loader = DataLoader(ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())


@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
evaluator.run(val_loader)

@evaluator.on(Events.EPOCH_COMPLETED)
def log_metrics_to_tensorboard(engine):
for _, value in engine.state.metrics.items():
writer.add_scalar('Metrics/' + metric_name, value, trainer.state.epoch)

# create a training data loader
logging.basicConfig(stream=sys.stdout, level=logging.INFO)

train_ds = NiftiDatasetd(images[:20], segs[:20], transform=transforms)
train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())

train_epochs = 30
state = trainer.run(train_loader, train_epochs)
81 changes: 75 additions & 6 deletions monai/data/nifti_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@

from torch.utils.data import Dataset
from torch.utils.data._utils.collate import np_str_obj_array_pattern

from monai.utils.constants import DataElementKey as Dek
from monai.utils.constants import ImageProperty as Prop
from monai.utils.module import export


Expand All @@ -40,14 +41,14 @@ def load_nifti(filename_or_obj, as_closest_canonical=False, image_only=True, dty
img = nib.load(filename_or_obj)

header = dict(img.header)
header['filename_or_obj'] = filename_or_obj
header['original_affine'] = img.affine
header['affine'] = img.affine
header['as_closest_canonical'] = as_closest_canonical
header[Prop.FILENAME_OR_OBJ] = filename_or_obj
header[Prop.ORIGINAL_AFFINE] = img.affine
header[Prop.AFFINE] = img.affine
header[Prop.AS_CLOSEST_CANONICAL] = as_closest_canonical

if as_closest_canonical:
img = nib.as_closest_canonical(img)
header['affine'] = img.affine
header[Prop.AFFINE] = img.affine

if dtype is not None:
dat = img.get_fdata(dtype=dtype)
Expand Down Expand Up @@ -131,3 +132,71 @@ def __getitem__(self, index):
continue
compatible_meta[meta_key] = meta_datum
return img, seg, compatible_meta


@export("monai.data")
class NiftiDatasetd(Dataset):
"""
Loads image/segmentation pairs of Nifti files from the given filename lists. Dict level transformations can be
specified for the dictionary data which is constructed by image, label and other metadata.
"""

def __init__(self, image_files, seg_files, as_closest_canonical=False, transform=None,
image_only=True, dtype=None):
"""
Initializes the dataset with the image and segmentation filename lists. The transform `transform` is applied
to the images and `seg_transform` to the segmentations.
Args:
image_files (list of str): list of image filenames.
seg_files (list of str): list of segmentation filenames.
as_closest_canonical (bool): if True, load the image as closest to canonical orientation.
transform (Callable, optional): dict transforms to excute operations on dictionary data.
image_only (bool): if True return only the image volume, other return image volume and header dict.
dtype (np.dtype, optional): if not None convert the loaded image to this data type.
"""

if len(image_files) != len(seg_files):
raise ValueError('Must have same number of image and segmentation files')

self.image_files = image_files
self.seg_files = seg_files
self.as_closest_canonical = as_closest_canonical
self.transform = transform
self.image_only = image_only
self.dtype = dtype

def __len__(self):
return len(self.image_files)

def __getitem__(self, index):
meta_data = None
if self.image_only:
img = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
image_only=self.image_only, dtype=self.dtype)
else:
img, meta_data = load_nifti(self.image_files[index], as_closest_canonical=self.as_closest_canonical,
image_only=self.image_only, dtype=self.dtype)
seg = load_nifti(self.seg_files[index])

compatible_meta = {}
if meta_data is not None:
assert isinstance(meta_data, dict), 'meta_data must be in dictionary format.'
for meta_key in meta_data:
meta_datum = meta_data[meta_key]
if type(meta_datum).__name__ == 'ndarray' \
and np_str_obj_array_pattern.search(meta_datum.dtype.str) is not None:
continue
compatible_meta[meta_key] = meta_datum

data = {
Dek.IMAGE: img,
Dek.LABEL: seg
}
if len(compatible_meta) > 0:
data.update(compatible_meta)

if self.transform is not None:
data = self.transform(data)

return data
24 changes: 23 additions & 1 deletion monai/transforms/composables.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import monai
from monai.transforms.compose import Randomizable, Transform
from monai.transforms.transforms import Rotate90
from monai.transforms.transforms import Rotate90, AddChannel
from monai.utils.misc import ensure_tuple

export = monai.utils.export("monai.transforms")
Expand Down Expand Up @@ -120,6 +120,28 @@ def __call__(self, data):
return d


@export
class AddChanneld(MapTransform):
"""
dictionary-based wrapper of AddChannel.
"""

def __init__(self, keys):
"""
Args:
keys (hashable items): keys of the corresponding items to be transformed.
See also: monai.transform.composables.MapTransform
"""
MapTransform.__init__(self, keys)
self.adder = AddChannel()

def __call__(self, data):
d = dict(data)
for key in self.keys:
d[key] = self.adder(d[key])
return d


# if __name__ == "__main__":
# import numpy as np
# data = {
Expand Down
2 changes: 1 addition & 1 deletion monai/transforms/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def __init__(self, k=1, axes=(1, 2)):
self.plane_axes = axes

def __call__(self, img):
return np.rot90(img, self.k, self.plane_axes)
return np.ascontiguousarray(np.rot90(img, self.k, self.plane_axes))


@export
Expand Down
40 changes: 40 additions & 0 deletions monai/utils/constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

class ActivationFunc:
"""Commonly used activation function names.
"""

SOFTMAX = "softmax"
LOG_SOFTMAX = "log_softmax"
SIGMOID = "sigmoid"
LINEAR = "linear"
TANH = "tanh"


class DataElementKey:
"""Data Element keys
"""

IMAGE = "image"
LABEL = "label"


class ImageProperty:
"""Key names for image properties.
"""

FILENAME_OR_OBJ = 'filename_or_obj'
AFFINE = 'affine' # image affine matrix
ORIGINAL_AFFINE = 'original_affine' # original affine matrix before transformation
SPACING = 'spacing' # itk naming convention for pixel/voxel size
AS_CLOSEST_CANONICAL = 'as_closest_canonical' # load the image as closest to canonical axis format
BACKGROUND_INDEX = 'background_index' # which index is background
Loading

0 comments on commit f5cca26

Please sign in to comment.