
v0.5 to v0.6 migration guide

Migrating your v0.5 code to v0.6

In MONAI v0.6, we enhanced the design of the metrics and postprocessing transforms to provide more flexible and advanced features, which in turn introduces breaking changes.

Please check out What's new in 0.6 for details of the new features.

To help you smoothly migrate existing code from MONAI v0.5 to v0.6, this document shows the detailed steps with code examples.

Decollate batch-first Tensor to list of channel-first Tensors

  1. After the model forward pass and loss backward, call decollate_batch to convert the batch Tensor into a list of Tensors, so that the postprocessing transforms can be applied independently to every item of the batch.
  2. All the MONAI postprocessing transforms have been updated to handle channel-first Tensors instead of batch-first Tensors, so the preprocessing and postprocessing transforms now handle the same data shape. Simply execute the postprocessing transform for every item of the list.
  3. As all the postprocessing transforms expect Tensor input, we suggest adding an EnsureType or EnsureTyped transform to guarantee that the data after decollate_batch is a Tensor.
  4. Use the from_engine() utility to extract the expected data from the decollated list. Set first=True for scalar values: because scalars have no batch dimension, they are copied to every item of the decollated list, and only the first copy is needed.
  5. Code examples:

(1) If you are using the array-based postprocessing transforms, a v0.5 classification program can be:

pred_trans = Activations(softmax=True)
label_trans = AsDiscrete(to_onehot=True, n_classes=5)

pred = model(image)
pred = pred_trans(pred)
label = label_trans(label)

metric(y_pred=pred, y=label)

And the corresponding code of v0.6 can be:

from monai.data import decollate_batch

pred_trans = Compose([EnsureType(), Activations(softmax=True)])
label_trans = Compose([EnsureType(), AsDiscrete(to_onehot=True, n_classes=5)])

pred = model(image)
pred = [pred_trans(i) for i in decollate_batch(pred)]
label = [label_trans(i) for i in decollate_batch(label)]

metric(y_pred=pred, y=label)
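
To make the shape convention concrete, below is a minimal sketch (the tensor sizes are illustrative) of what decollate_batch returns for a plain batch-first Tensor:

import torch
from monai.data import decollate_batch

batch = torch.rand(8, 5)           # batch-first: 8 samples with 5 class logits each
items = decollate_batch(batch)     # a list of 8 channel-first Tensors
print(len(items), items[0].shape)  # prints: 8 torch.Size([5])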

(2) If you are using the dictionary-based postprocessing transforms, a v0.5 classification program can be:

postprocessing = Compose([
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])

data["pred"] = model(data["image"])
data = postprocessing(data)

metric(y_pred=data["pred"], y=data["label"])

And the corresponding code of v0.6 can be:

from monai.data import decollate_batch
from monai.handlers import from_engine

postprocessing = Compose([
    EnsureTyped(keys=["pred", "label"]),
    Activationsd(keys="pred", softmax=True),
    AsDiscreted(keys="label", to_onehot=True, n_classes=5),
])

data["pred"] = model(data["image"])
# decollate data into a list of dictionaries
data = [postprocessing(i) for i in decollate_batch(data)]

# extract the `pred` and `label` to compute metric
pred, label = from_engine(["pred", "label"])(data)
metric(y_pred=pred, y=label)

For a more detailed tutorial of decollate_batch, please check the decollate_batch tutorial.

Adopt the new metrics APIs, which automatically support data parallel

  1. Support both a batch-first Tensor and a list of channel-first Tensors as input.
  2. Support data parallel in multi-GPU or multi-node cases.
  3. Example code for validation during the training of a segmentation task is shown below.

A typical code example of v0.5:

dice_metric = DiceMetric(include_background=True, reduction="mean")

metric_sum = 0.0
metric_count = 0
for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = post_trans(sliding_window_inference(images, (96, 96, 96), 4, model))
    value, not_nans = dice_metric(y_pred=preds, y=labels)
    metric_count += not_nans.item()
    metric_sum += value.item() * not_nans.item()
metric = metric_sum / metric_count

And the corresponding code of v0.6 can be:

dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)

for val_data in val_loader:
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = sliding_window_inference(images, (96, 96, 96), 4, model)
    # decollate the prediction into a list and execute postprocessing for every item
    preds = [post_trans(i) for i in decollate_batch(preds)]
    # compute metric for the current iteration
    dice_metric(y_pred=preds, y=labels)

# aggregate and compute the final result of metric
metric = dice_metric.aggregate().item()
dice_metric.reset()
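
Because the metric buffers every iteration's result, the same loop also works under DistributedDataParallel: aggregate() synchronizes the buffered results across ranks when torch.distributed is initialized. A minimal sketch, assuming the process group is already initialized, val_loader uses a DistributedSampler, and model, device, post_trans and dice_metric are set up as above:

import torch.distributed as dist

for val_data in val_loader:  # each rank iterates over its own shard of the data
    images, labels = val_data["img"].to(device), val_data["seg"].to(device)
    preds = sliding_window_inference(images, (96, 96, 96), 4, model)
    preds = [post_trans(i) for i in decollate_batch(preds)]
    dice_metric(y_pred=preds, y=labels)

# aggregate() gathers the buffered results from all ranks before reducing,
# so every rank receives the same global mean dice
metric = dice_metric.aggregate().item()
dice_metric.reset()
if dist.get_rank() == 0:
    print("mean dice:", metric)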

For more details about computing metrics in multi-processing, please check the compute metrics example.

Update the batch_transform or output_transform of several event handlers

The batch_transform and output_transform args of v0.5 can be:

StatsHandler(tag_name="train_loss", output_transform=lambda x: x["loss"]),
SegmentationSaver(
    output_dir=root_dir,
    batch_transform=lambda batch: batch["image_meta_dict"],
    output_transform=lambda output: output["pred"],
)

And the corresponding code of v0.6 can be:

from monai.handlers import from_engine

StatsHandler(tag_name="train_loss", output_transform=from_engine("loss", first=True)),
SegmentationSaver(
    output_dir=root_dir,
    batch_transform=from_engine("image_meta_dict"),
    output_transform=from_engine("pred"),
),
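
For intuition: from_engine(keys, first=False) returns a callable that extracts the given keys from the decollated list of dictionaries, and with first=True it returns only the first item's value, which is what you want for scalars such as the loss (see item 4 above). A minimal sketch with illustrative values:

from monai.handlers import from_engine

decollated = [{"loss": 0.85, "pred": "p0"}, {"loss": 0.85, "pred": "p1"}]

from_engine("pred")(decollated)              # ["p0", "p1"]
from_engine("loss", first=True)(decollated)  # 0.85 (the scalar was copied to every item)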

Update all the post_transform args to postprocessing

In v0.6, the post_transform arg of several components has been renamed to postprocessing, for example, the arg of SupervisedTrainer:

trainer = SupervisedTrainer(
    device=device,
    max_epochs=5,
    train_data_loader=train_loader,
    network=net,
    optimizer=opt,
    loss_function=loss,
    inferer=SimpleInferer(),
    postprocessing=train_postprocessing,
    key_train_metric={"train_acc": Accuracy(output_transform=from_engine(["pred", "label"]))},
    train_handlers=train_handlers,
)

DynUNet

In v0.6, DynUNet has been updated; the previous version is still available:

from monai.networks.nets.dynunet_v1 import DynUNetV1 as DynUNet

dynunet_v1 will be removed in a future release.

UnetResBlock, UnetBasicBlock and UnetUpBlock

In v0.5, affine=True was used as the default when using instance norm. To maintain consistency, please specify norm_name=("instance", {"affine": True}) for these classes in v0.6.
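
As a minimal sketch (the channel and kernel arguments here are illustrative), the norm can be specified like this:

from monai.networks.blocks.dynunet_block import UnetResBlock

# explicitly request affine instance norm to match the v0.5 default behavior
block = UnetResBlock(
    spatial_dims=3,
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    stride=1,
    norm_name=("instance", {"affine": True}),
)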

Attachment: a diff file for the segmentation training pipelines:

diff --git a/3d_segmentation/torch/unet_training_dict.py b/3d_segmentation/torch/unet_training_dict.py
index 0c85cbf..febecb4 100644
--- a/3d_segmentation/torch/unet_training_dict.py
+++ b/3d_segmentation/torch/unet_training_dict.py
@@ -22,7 +22,7 @@ from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
 
 import monai
-from monai.data import create_test_image_3d, list_data_collate
+from monai.data import create_test_image_3d, list_data_collate, decollate_batch
 from monai.inferers import sliding_window_inference
 from monai.metrics import DiceMetric
 from monai.transforms import (
@@ -34,7 +34,8 @@ from monai.transforms import (
     RandCropByPosNegLabeld,
     RandRotate90d,
     ScaleIntensityd,
-    ToTensord,
+    EnsureTyped,
+    EnsureType,
 )
 from monai.visualize import plot_2d_or_3d_image
 
@@ -69,7 +70,7 @@ def main(tempdir):
                 keys=["img", "seg"], label_key="seg", spatial_size=[96, 96, 96], pos=1, neg=1, num_samples=4
             ),
             RandRotate90d(keys=["img", "seg"], prob=0.5, spatial_axes=[0, 2]),
-            ToTensord(keys=["img", "seg"]),
+            EnsureTyped(keys=["img", "seg"]),
         ]
     )
     val_transforms = Compose(
@@ -77,7 +78,7 @@ def main(tempdir):
             LoadImaged(keys=["img", "seg"]),
             AsChannelFirstd(keys=["img", "seg"], channel_dim=-1),
             ScaleIntensityd(keys="img"),
-            ToTensord(keys=["img", "seg"]),
+            EnsureTyped(keys=["img", "seg"]),
         ]
     )
 
@@ -102,8 +103,8 @@ def main(tempdir):
     # create a validation data loader
     val_ds = monai.data.Dataset(data=val_files, transform=val_transforms)
     val_loader = DataLoader(val_ds, batch_size=1, num_workers=4, collate_fn=list_data_collate)
-    dice_metric = DiceMetric(include_background=True, reduction="mean")
-    post_trans = Compose([Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
+    dice_metric = DiceMetric(include_background=True, reduction="mean", get_not_nans=False)
+    post_trans = Compose([EnsureType(), Activations(sigmoid=True), AsDiscrete(threshold_values=True)])
     # create UNet, DiceLoss and Adam optimizer
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     model = monai.networks.nets.UNet(
@@ -149,8 +150,6 @@ def main(tempdir):
         if (epoch + 1) % val_interval == 0:
             model.eval()
             with torch.no_grad():
-                metric_sum = 0.0
-                metric_count = 0
                 val_images = None
                 val_labels = None
                 val_outputs = None
@@ -159,11 +158,14 @@ def main(tempdir):
                     roi_size = (96, 96, 96)
                     sw_batch_size = 4
                     val_outputs = sliding_window_inference(val_images, roi_size, sw_batch_size, model)
-                    val_outputs = post_trans(val_outputs)
-                    value, _ = dice_metric(y_pred=val_outputs, y=val_labels)
-                    metric_count += len(value)
-                    metric_sum += value.item() * len(value)
-                metric = metric_sum / metric_count
+                    val_outputs = [post_trans(i) for i in decollate_batch(val_outputs)]
+                    # compute metric for current iteration
+                    dice_metric(y_pred=val_outputs, y=val_labels)
+                # aggregate the final mean dice result
+                metric = dice_metric.aggregate().item()
+                # reset the status for next validation round
+                dice_metric.reset()
+
                 metric_values.append(metric)
                 if metric > best_metric:
                     best_metric = metric