# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os
import sys
import tempfile
from glob import glob

import nibabel as nib
import numpy as np
import torch
from ignite.engine import Events, create_supervised_trainer, create_supervised_evaluator, _prepare_batch
from ignite.handlers import ModelCheckpoint, EarlyStopping
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

# assumes the framework is found here, change as necessary
sys.path.append("..")

import monai
import monai.transforms.compose as transforms
from monai.utils.constants import DataElementKey as Dek
from monai.data.nifti_reader import NiftiDatasetd
from monai.transforms.composables import AddChanneld, RandRotate90d
from monai.handlers.stats_handler import StatsHandler
from monai.handlers.mean_dice import MeanDice
from monai.visualize import img2tensorboard
from monai.data.synthetic import create_test_image_3d
from monai.handlers.utils import stopping_fn_from_metric

monai.config.print_config()

# Create a temporary directory and 50 random image/mask pairs
tempdir = tempfile.mkdtemp()

for i in range(50):
    im, seg = create_test_image_3d(128, 128, 128)

    n = nib.Nifti1Image(im, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'im%i.nii.gz' % i))

    n = nib.Nifti1Image(seg, np.eye(4))
    nib.save(n, os.path.join(tempdir, 'seg%i.nii.gz' % i))

images = sorted(glob(os.path.join(tempdir, 'im*.nii.gz')))
segs = sorted(glob(os.path.join(tempdir, 'seg*.nii.gz')))

# Define transforms for image and segmentation
data_transforms = transforms.Compose([
    AddChanneld(keys=[Dek.IMAGE, Dek.LABEL]),
    RandRotate90d(keys=[Dek.IMAGE, Dek.LABEL], prob=0.8, axes=[1, 3])
])

# Define nifti dataset, dataloader.
ds = NiftiDatasetd(images, segs, transform=data_transforms)
loader = DataLoader(ds, batch_size=10, num_workers=2, pin_memory=torch.cuda.is_available())
check_data = monai.utils.misc.first(loader)
print(check_data[Dek.IMAGE].shape, check_data[Dek.LABEL].shape)
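# with batch_size=10 and AddChanneld applied, both shapes should print as
# torch.Size([10, 1, 128, 128, 128]) (batch, channel, spatial dims)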

lr = 1e-5

# Create UNet, DiceLoss and Adam optimizer.
net = monai.networks.nets.UNet(
    dimensions=3,
    in_channels=1,
    num_classes=1,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
)

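# do_sigmoid=True makes DiceLoss apply a sigmoid to the raw network logits before computing the overlap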
loss = monai.losses.DiceLoss(do_sigmoid=True)
opt = torch.optim.Adam(net.parameters(), lr)

# Since the network outputs both logits and a segmentation, a custom loss wrapper selects the logits
def _loss_fn(i, j):
    return loss(i[0], j)


# The dataset yields dictionaries, so extract the (image, label) pair that ignite expects
def prepare_batch(batch, device=None, non_blocking=False):
    return _prepare_batch((batch[Dek.IMAGE], batch[Dek.LABEL]), device, non_blocking)

# Create trainer
device = torch.device("cuda:0")
trainer = create_supervised_trainer(net, opt, _loss_fn, device, False,
                                    prepare_batch=prepare_batch,
                                    output_transform=lambda x, y, y_pred, loss: [y_pred, loss.item(), y])
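# with this output_transform, engine.state.output holds [y_pred, loss.item(), y] after every
# iteration, which the logging handlers below rely on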

# adding checkpoint handler to save models (network params and optimizer stats) during training
checkpoint_handler = ModelCheckpoint('./runs/', 'net', n_saved=10, require_empty=False)
trainer.add_event_handler(event_name=Events.EPOCH_COMPLETED,
                          handler=checkpoint_handler,
                          to_save={'net': net, 'opt': opt})
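# the network and optimizer states are saved under ./runs/ with the 'net' prefix after every
# epoch; n_saved=10 keeps only the ten most recent checkpoints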
train_stats_handler = StatsHandler()
train_stats_handler.attach(trainer)

@trainer.on(Events.EPOCH_COMPLETED)
def log_training_loss(engine):
    # log the loss to tensorboard; the second item of engine.state.output is loss.item() from the output_transform
    writer.add_scalar('Loss/train', engine.state.output[1], engine.state.epoch)

    # tensor of ones used by torch.where to convert labels to zeros and ones
    ones = torch.ones(engine.state.batch[Dek.LABEL][0].shape, dtype=torch.int32)

    # log the first three samples of the final batch as animated gifs
    for idx, name in enumerate(['first', 'second', 'third']):
        # engine.state.output[0] is y_pred; index [1] selects the network's segmentation output,
        # which is logged as a three dimensional tensor with no channel dimension
        output_tensor = engine.state.output[0][1][idx].detach().cpu()
        img2tensorboard.add_animated_gif_no_channels(writer, "%s_output_final_batch" % name, output_tensor,
                                                     64, 255, engine.state.epoch)
        # get the label tensor and convert it to a single binary class; labels taken from the batch
        # still have a channel dimension
        label = engine.state.batch[Dek.LABEL][idx]
        label_tensor = torch.where(label > 0, ones, label)
        img2tensorboard.add_animated_gif(writer, "%s_label_final_batch" % name, label_tensor,
                                         64, 255, engine.state.epoch)

    engine.logger.info("Epoch[%s] Loss: %s", engine.state.epoch, engine.state.output[1])

writer = SummaryWriter()
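# SummaryWriter writes TensorBoard event files to the default ./runs directory when no log_dir is given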

# Set parameters for validation
validation_every_n_epochs = 1
metric_name = 'Mean_Dice'

# add evaluation metric to the evaluator engine
val_metrics = {metric_name: MeanDice(add_sigmoid=True, to_onehot_y=False)}
evaluator = create_supervised_evaluator(net, val_metrics, device, True,
                                        prepare_batch=prepare_batch,
                                        output_transform=lambda x, y, y_pred: (y_pred[0], y))
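# the evaluator's output_transform forwards only the logits (y_pred[0]) and the labels to MeanDice,
# which applies its own sigmoid since add_sigmoid=True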

# Add stats event handler to print validation stats via evaluator
logging.basicConfig(stream=sys.stdout, level=logging.INFO)
val_stats_handler = StatsHandler()
val_stats_handler.attach(evaluator)

# Add early stopping handler to evaluator.
early_stopper = EarlyStopping(patience=4,
                              score_function=stopping_fn_from_metric(metric_name),
                              trainer=trainer)
evaluator.add_event_handler(event_name=Events.EPOCH_COMPLETED, handler=early_stopper)
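# training is stopped early if Mean_Dice fails to improve over 4 consecutive validation runs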

# create a validation data loader
val_ds = NiftiDatasetd(images[-20:], segs[-20:], transform=data_transforms)
val_loader = DataLoader(val_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())


@trainer.on(Events.EPOCH_COMPLETED(every=validation_every_n_epochs))
def run_validation(engine):
    evaluator.run(val_loader)

@evaluator.on(Events.EPOCH_COMPLETED)
def log_metrics_to_tensorboard(engine):
    # write each validation metric to tensorboard under its own name
    for name, value in engine.state.metrics.items():
        writer.add_scalar('Metrics/' + name, value, trainer.state.epoch)
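# trainer.state.epoch is used as the global step so validation points line up with the training
# loss curve; the evaluator's own epoch counter restarts on every run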

# create a training data loader
train_ds = NiftiDatasetd(images[:20], segs[:20], transform=data_transforms)
train_loader = DataLoader(train_ds, batch_size=5, num_workers=8, pin_memory=torch.cuda.is_available())
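# note: of the 50 generated volumes, the first 20 are used for training and the last 20 for
# validation; the remaining 10 are not used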

train_epochs = 30
state = trainer.run(train_loader, train_epochs)
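# `state` is the trainer's final ignite State (epoch and iteration counts) once training completes
# or early stopping terminates the run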