# Copyright 2020 MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
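
"""Integration test: 2D classification on the MedNIST dataset.

Downloads and unpacks MedNIST, trains a DenseNet-121 classifier for a few
epochs with ROC AUC validation, then runs inference with the best saved
model and checks that losses, metrics, and per-class true-positive counts
are reproducible across two seeded runs.
"""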

import os
import shutil
import subprocess
import tarfile
import tempfile
import unittest

import numpy as np
import torch
from torch.utils.data import DataLoader

import monai
from monai.metrics import compute_roc_auc
from monai.networks.nets import densenet121
from monai.transforms import (AddChannel, Compose, LoadPNG, RandFlip, RandRotate, RandZoom, Resize, ScaleIntensity,
                              ToTensor)
from tests.utils import skip_if_quick

TEST_DATA_URL = 'https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz'


class MedNISTDataset(torch.utils.data.Dataset):
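    """Minimal file-list dataset: applies `transforms` to each image path and
    returns the transformed image together with its integer class label."""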

    def __init__(self, image_files, labels, transforms):
        self.image_files = image_files
        self.labels = labels
        self.transforms = transforms

    def __len__(self):
        return len(self.image_files)

    def __getitem__(self, index):
        return self.transforms(self.image_files[index]), self.labels[index]


def run_training_test(root_dir, train_x, train_y, val_x, val_y, device=torch.device("cuda:0")):
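    """Train a DenseNet-121 MedNIST classifier and return
    (per-epoch losses, best validation AUC, epoch of the best AUC)."""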

    monai.config.print_config()
    # define transforms for image and classification
    train_transforms = Compose([
        LoadPNG(),
        AddChannel(),
        ScaleIntensity(),
        RandRotate(degrees=15, prob=0.5),
        RandFlip(spatial_axis=0, prob=0.5),
        RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
        Resize(spatial_size=(64, 64), mode='constant'),
        ToTensor()
    ])
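    # seed the random augmentations so training is deterministic across runs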
    train_transforms.set_random_state(1234)
    val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])

    # create train, val data loaders
    train_ds = MedNISTDataset(train_x, train_y, train_transforms)
    train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)

    val_ds = MedNISTDataset(val_x, val_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)

    model = densenet121(
        spatial_dims=2,
        in_channels=1,
        out_channels=len(np.unique(train_y)),
    ).to(device)
    loss_function = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), 1e-5)
    epoch_num = 4
    val_interval = 1

    # start training and validation
    best_metric = -1
    best_metric_epoch = -1
    epoch_loss_values = list()
    metric_values = list()
    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    for epoch in range(epoch_num):
        print('-' * 10)
        print('Epoch {}/{}'.format(epoch + 1, epoch_num))
        model.train()
        epoch_loss = 0
        step = 0
        for batch_data in train_loader:
            step += 1
            inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = loss_function(outputs, labels)
            loss.backward()
            optimizer.step()
            epoch_loss += loss.item()
        epoch_loss /= step
        epoch_loss_values.append(epoch_loss)
        print("epoch %d average loss: %0.4f" % (epoch + 1, epoch_loss))

        if (epoch + 1) % val_interval == 0:
            model.eval()
            with torch.no_grad():
                y_pred = torch.tensor([], dtype=torch.float32, device=device)
                y = torch.tensor([], dtype=torch.long, device=device)
                for val_data in val_loader:
                    val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
                    y_pred = torch.cat([y_pred, model(val_images)], dim=0)
                    y = torch.cat([y, val_labels], dim=0)
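                # one-hot encode the labels and softmax the logits before computing ROC AUC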
                auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, add_softmax=True)
                metric_values.append(auc_metric)
                acc_value = torch.eq(y_pred.argmax(dim=1), y)
                acc_metric = acc_value.sum().item() / len(acc_value)
                if auc_metric > best_metric:
                    best_metric = auc_metric
                    best_metric_epoch = epoch + 1
                    torch.save(model.state_dict(), model_filename)
                    print('saved new best metric model')
                print("current epoch %d current AUC: %0.4f current accuracy: %0.4f best AUC: %0.4f at epoch %d" %
                      (epoch + 1, auc_metric, acc_metric, best_metric, best_metric_epoch))
    print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
    return epoch_loss_values, best_metric, best_metric_epoch


def run_inference_test(root_dir, test_x, test_y, device=torch.device("cuda:0")):
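    """Load the best saved model and return per-class true-positive counts
    on the test set."""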
    # define transforms for image and classification
    val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])
    val_ds = MedNISTDataset(test_x, test_y, val_transforms)
    val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)

    model = densenet121(
        spatial_dims=2,
        in_channels=1,
        out_channels=len(np.unique(test_y)),
    ).to(device)

    model_filename = os.path.join(root_dir, 'best_metric_model.pth')
    model.load_state_dict(torch.load(model_filename))
    model.eval()
    y_true = list()
    y_pred = list()
    with torch.no_grad():
        for test_data in val_loader:
            test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
            pred = model(test_images).argmax(dim=1)
            for i in range(len(pred)):
                y_true.append(test_labels[i].item())
                y_pred.append(pred[i].item())
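    # count true positives per class: predictions that equal the ground-truth label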
    tps = [np.sum((np.asarray(y_true) == idx) & (np.asarray(y_pred) == idx)) for idx in np.unique(test_y)]
    return tps


class IntegrationClassification2D(unittest.TestCase):
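    """Downloads MedNIST in setUp, runs training and inference twice with fixed
    seeds, and asserts that both runs reproduce the expected metrics."""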

    def setUp(self):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        np.random.seed(0)
        self.data_dir = tempfile.mkdtemp()

        # download
        subprocess.call(['wget', '-nv', '-P', self.data_dir, TEST_DATA_URL])
        dataset_file = os.path.join(self.data_dir, 'MedNIST.tar.gz')
        assert os.path.exists(dataset_file)

        # extract tarfile
        datafile = tarfile.open(dataset_file)
        datafile.extractall(path=self.data_dir)
        datafile.close()

        # find image files and labels
        data_dir = os.path.join(self.data_dir, 'MedNIST')
        class_names = sorted(os.listdir(data_dir))
        image_files = [[
            os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))
        ] for class_name in class_names]
        image_file_list, image_classes = [], []
        for i, class_name in enumerate(class_names):
            image_file_list.extend(image_files[i])
            image_classes.extend([i] * len(image_files[i]))

        # split train, val, test
        valid_frac, test_frac = 0.1, 0.1
        self.train_x, self.train_y = [], []
        self.val_x, self.val_y = [], []
        self.test_x, self.test_y = [], []
        for i in range(len(image_classes)):
            rann = np.random.random()
            if rann < valid_frac:
                self.val_x.append(image_file_list[i])
                self.val_y.append(image_classes[i])
            elif rann < test_frac + valid_frac:
                self.test_x.append(image_file_list[i])
                self.test_y.append(image_classes[i])
            else:
                self.train_x.append(image_file_list[i])
                self.train_y.append(image_classes[i])

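        # restore nondeterministic numpy seeding now that the split is fixed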
        np.random.seed(seed=None)
        self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')

    def tearDown(self):
        shutil.rmtree(self.data_dir)

    @skip_if_quick
    def test_training(self):
        repeated = []
        for i in range(2):
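            # fix the torch seed so both repetitions produce identical results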
            torch.manual_seed(0)

            repeated.append([])
            losses, best_metric, best_metric_epoch = \
                run_training_test(self.data_dir, self.train_x, self.train_y, self.val_x, self.val_y, device=self.device)

            # check training properties
            np.testing.assert_allclose(
                losses, [0.8501208358129878, 0.18469145818121113, 0.08108749352158255, 0.04965383692342005], rtol=1e-3)
            repeated[i].extend(losses)
            print('best metric', best_metric)
            np.testing.assert_allclose(best_metric, 0.9999480167572079, rtol=1e-4)
            repeated[i].append(best_metric)
            np.testing.assert_allclose(best_metric_epoch, 4)
            model_file = os.path.join(self.data_dir, 'best_metric_model.pth')
            self.assertTrue(os.path.exists(model_file))

            infer_metric = run_inference_test(self.data_dir, self.test_x, self.test_y, device=self.device)

            # check inference properties
            np.testing.assert_allclose(np.asarray(infer_metric), [1036, 895, 982, 1033, 958, 1047])
            repeated[i].extend(infer_metric)

        np.testing.assert_allclose(repeated[0], repeated[1])


if __name__ == '__main__':
    unittest.main()