Skip to content

Commit 4a5f3ac

Browse files
authored
adds integration test classification (#282)
* adds integration test classification, ci tests to use update-to-date pip install
1 parent cd34d39 commit 4a5f3ac

File tree

2 files changed

+242
-1
lines changed

2 files changed

+242
-1
lines changed

.github/workflows/setupapp.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ jobs:
1818
which python
1919
python -m pip install --upgrade pip --no-cache-dir
2020
python -m pip uninstall -y torch torchvision
21-
python -m pip install -q -r requirements.txt --no-cache-dir
21+
python -m pip install --upgrade -q -r requirements.txt --no-cache-dir
2222
python -m pip list
2323
- name: Run unit tests report coverage
2424
run: |
Lines changed: 241 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,241 @@
1+
# Copyright 2020 MONAI Consortium
2+
# Licensed under the Apache License, Version 2.0 (the "License");
3+
# you may not use this file except in compliance with the License.
4+
# You may obtain a copy of the License at
5+
# http://www.apache.org/licenses/LICENSE-2.0
6+
# Unless required by applicable law or agreed to in writing, software
7+
# distributed under the License is distributed on an "AS IS" BASIS,
8+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9+
# See the License for the specific language governing permissions and
10+
# limitations under the License.
11+
12+
import os
13+
import shutil
14+
import subprocess
15+
import tarfile
16+
import tempfile
17+
import unittest
18+
19+
import numpy as np
20+
import torch
21+
from torch.utils.data import DataLoader
22+
23+
import monai
24+
from monai.metrics import compute_roc_auc
25+
from monai.networks.nets import densenet121
26+
from monai.transforms import (AddChannel, Compose, LoadPNG, RandFlip, RandRotate, RandZoom, Resize, ScaleIntensity,
27+
ToTensor)
28+
from tests.utils import skip_if_quick
29+
30+
TEST_DATA_URL = 'https://www.dropbox.com/s/5wwskxctvcxiuea/MedNIST.tar.gz'
31+
32+
33+
class MedNISTDataset(torch.utils.data.Dataset):
34+
35+
def __init__(self, image_files, labels, transforms):
36+
self.image_files = image_files
37+
self.labels = labels
38+
self.transforms = transforms
39+
40+
def __len__(self):
41+
return len(self.image_files)
42+
43+
def __getitem__(self, index):
44+
return self.transforms(self.image_files[index]), self.labels[index]
45+
46+
47+
def run_training_test(root_dir, train_x, train_y, val_x, val_y, device=torch.device("cuda:0")):
48+
49+
monai.config.print_config()
50+
# define transforms for image and classification
51+
train_transforms = Compose([
52+
LoadPNG(),
53+
AddChannel(),
54+
ScaleIntensity(),
55+
RandRotate(degrees=15, prob=0.5),
56+
RandFlip(spatial_axis=0, prob=0.5),
57+
RandZoom(min_zoom=0.9, max_zoom=1.1, prob=0.5),
58+
Resize(spatial_size=(64, 64), mode='constant'),
59+
ToTensor()
60+
])
61+
train_transforms.set_random_state(1234)
62+
val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])
63+
64+
# create train, val data loaders
65+
train_ds = MedNISTDataset(train_x, train_y, train_transforms)
66+
train_loader = DataLoader(train_ds, batch_size=300, shuffle=True, num_workers=10)
67+
68+
val_ds = MedNISTDataset(val_x, val_y, val_transforms)
69+
val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)
70+
71+
model = densenet121(
72+
spatial_dims=2,
73+
in_channels=1,
74+
out_channels=len(np.unique(train_y)),
75+
).to(device)
76+
loss_function = torch.nn.CrossEntropyLoss()
77+
optimizer = torch.optim.Adam(model.parameters(), 1e-5)
78+
epoch_num = 4
79+
val_interval = 1
80+
81+
# start training validation
82+
best_metric = -1
83+
best_metric_epoch = -1
84+
epoch_loss_values = list()
85+
metric_values = list()
86+
model_filename = os.path.join(root_dir, 'best_metric_model.pth')
87+
for epoch in range(epoch_num):
88+
print('-' * 10)
89+
print('Epoch {}/{}'.format(epoch + 1, epoch_num))
90+
model.train()
91+
epoch_loss = 0
92+
step = 0
93+
for batch_data in train_loader:
94+
step += 1
95+
inputs, labels = batch_data[0].to(device), batch_data[1].to(device)
96+
optimizer.zero_grad()
97+
outputs = model(inputs)
98+
loss = loss_function(outputs, labels)
99+
loss.backward()
100+
optimizer.step()
101+
epoch_loss += loss.item()
102+
epoch_loss /= step
103+
epoch_loss_values.append(epoch_loss)
104+
print("epoch %d average loss:%0.4f" % (epoch + 1, epoch_loss))
105+
106+
if (epoch + 1) % val_interval == 0:
107+
model.eval()
108+
with torch.no_grad():
109+
y_pred = torch.tensor([], dtype=torch.float32, device=device)
110+
y = torch.tensor([], dtype=torch.long, device=device)
111+
for val_data in val_loader:
112+
val_images, val_labels = val_data[0].to(device), val_data[1].to(device)
113+
y_pred = torch.cat([y_pred, model(val_images)], dim=0)
114+
y = torch.cat([y, val_labels], dim=0)
115+
auc_metric = compute_roc_auc(y_pred, y, to_onehot_y=True, add_softmax=True)
116+
metric_values.append(auc_metric)
117+
acc_value = torch.eq(y_pred.argmax(dim=1), y)
118+
acc_metric = acc_value.sum().item() / len(acc_value)
119+
if auc_metric > best_metric:
120+
best_metric = auc_metric
121+
best_metric_epoch = epoch + 1
122+
torch.save(model.state_dict(), model_filename)
123+
print('saved new best metric model')
124+
print("current epoch %d current AUC: %0.4f current accuracy: %0.4f best AUC: %0.4f at epoch %d" %
125+
(epoch + 1, auc_metric, acc_metric, best_metric, best_metric_epoch))
126+
print('train completed, best_metric: %0.4f at epoch: %d' % (best_metric, best_metric_epoch))
127+
return epoch_loss_values, best_metric, best_metric_epoch
128+
129+
130+
def run_inference_test(root_dir, test_x, test_y, device=torch.device("cuda:0")):
131+
# define transforms for image and classification
132+
val_transforms = Compose([LoadPNG(), AddChannel(), ScaleIntensity(), ToTensor()])
133+
val_ds = MedNISTDataset(test_x, test_y, val_transforms)
134+
val_loader = DataLoader(val_ds, batch_size=300, num_workers=10)
135+
136+
model = densenet121(
137+
spatial_dims=2,
138+
in_channels=1,
139+
out_channels=len(np.unique(test_y)),
140+
).to(device)
141+
142+
model_filename = os.path.join(root_dir, 'best_metric_model.pth')
143+
model.load_state_dict(torch.load(model_filename))
144+
model.eval()
145+
y_true = list()
146+
y_pred = list()
147+
with torch.no_grad():
148+
for test_data in val_loader:
149+
test_images, test_labels = test_data[0].to(device), test_data[1].to(device)
150+
pred = model(test_images).argmax(dim=1)
151+
for i in range(len(pred)):
152+
y_true.append(test_labels[i].item())
153+
y_pred.append(pred[i].item())
154+
tps = [np.sum((np.asarray(y_true) == idx) & (np.asarray(y_pred) == idx)) for idx in np.unique(test_y)]
155+
return tps
156+
157+
158+
class IntegrationClassification2D(unittest.TestCase):
159+
160+
def setUp(self):
161+
torch.backends.cudnn.deterministic = True
162+
torch.backends.cudnn.benchmark = False
163+
np.random.seed(0)
164+
self.data_dir = tempfile.mkdtemp()
165+
166+
# download
167+
subprocess.call(['wget', '-nv', '-P', self.data_dir, TEST_DATA_URL])
168+
dataset_file = os.path.join(self.data_dir, 'MedNIST.tar.gz')
169+
assert os.path.exists(dataset_file)
170+
171+
# extract tarfile
172+
datafile = tarfile.open(dataset_file)
173+
datafile.extractall(path=self.data_dir)
174+
datafile.close()
175+
176+
# find image files and labels
177+
data_dir = os.path.join(self.data_dir, 'MedNIST')
178+
class_names = sorted(os.listdir(data_dir))
179+
image_files = [[
180+
os.path.join(data_dir, class_name, x) for x in sorted(os.listdir(os.path.join(data_dir, class_name)))
181+
] for class_name in class_names]
182+
image_file_list, image_classes = [], []
183+
for i, class_name in enumerate(class_names):
184+
image_file_list.extend(image_files[i])
185+
image_classes.extend([i] * len(image_files[i]))
186+
187+
# split train, val, test
188+
valid_frac, test_frac = 0.1, 0.1
189+
self.train_x, self.train_y = [], []
190+
self.val_x, self.val_y = [], []
191+
self.test_x, self.test_y = [], []
192+
for i in range(len(image_classes)):
193+
rann = np.random.random()
194+
if rann < valid_frac:
195+
self.val_x.append(image_file_list[i])
196+
self.val_y.append(image_classes[i])
197+
elif rann < test_frac + valid_frac:
198+
self.test_x.append(image_file_list[i])
199+
self.test_y.append(image_classes[i])
200+
else:
201+
self.train_x.append(image_file_list[i])
202+
self.train_y.append(image_classes[i])
203+
204+
np.random.seed(seed=None)
205+
self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu:0')
206+
207+
def tearDown(self):
208+
shutil.rmtree(self.data_dir)
209+
210+
@skip_if_quick
211+
def test_training(self):
212+
repeated = []
213+
for i in range(2):
214+
torch.manual_seed(0)
215+
216+
repeated.append([])
217+
losses, best_metric, best_metric_epoch = \
218+
run_training_test(self.data_dir, self.train_x, self.train_y, self.val_x, self.val_y, device=self.device)
219+
220+
# check training properties
221+
np.testing.assert_allclose(
222+
losses, [0.8501208358129878, 0.18469145818121113, 0.08108749352158255, 0.04965383692342005], rtol=1e-3)
223+
repeated[i].extend(losses)
224+
print('best metric', best_metric)
225+
np.testing.assert_allclose(best_metric, 0.9999480167572079, rtol=1e-4)
226+
repeated[i].append(best_metric)
227+
np.testing.assert_allclose(best_metric_epoch, 4)
228+
model_file = os.path.join(self.data_dir, 'best_metric_model.pth')
229+
self.assertTrue(os.path.exists(model_file))
230+
231+
infer_metric = run_inference_test(self.data_dir, self.test_x, self.test_y, device=self.device)
232+
233+
# check inference properties
234+
np.testing.assert_allclose(np.asarray(infer_metric), [1036, 895, 982, 1033, 958, 1047])
235+
repeated[i].extend(infer_metric)
236+
237+
np.testing.assert_allclose(repeated[0], repeated[1])
238+
239+
240+
if __name__ == '__main__':
241+
unittest.main()

0 commit comments

Comments
 (0)