Skip to content

Commit

Permalink
Browse files Browse the repository at this point in the history
Fixed Architecture and Dependencies
  • Loading branch information
abhi4ssj committed Nov 9, 2018
1 parent f3792dc commit f9f2fcf
Show file tree
Hide file tree
Showing 10 changed files with 176 additions and 51 deletions.
39 changes: 25 additions & 14 deletions few_shot_segmentor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch
import torch.nn as nn
from nn_common_modules import modules as sm
from data_utils import split_batch


class Conditioner(nn.Module):
Expand All @@ -19,7 +20,7 @@ def __init__(self, params):
params['num_channels'] = 64
self.genblock2 = sm.GenericBlock(params)
self.genblock3 = sm.GenericBlock(params)
self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'], return_indices=True)
self.maxpool = nn.MaxPool2d(kernel_size=params['pool'], stride=params['stride_pool'])
self.tanh = nn.Tanh()

def forward(self, input):
Expand All @@ -29,11 +30,14 @@ def forward(self, input):
o4 = self.maxpool(o3)
o5 = self.genblock3(o4)
batch_size, num_channels, H, W = o1.size()
o6 = self.tanh(o1.view(batch_size, num_channels, -1).mean(dim=2))
o6 = o1.view(batch_size, num_channels, -1).mean(dim=2)
# o6 = self.tanh(o1.view(batch_size, num_channels, -1).mean(dim=2))
batch_size, num_channels, H, W = o3.size()
o7 = self.tanh(o3.view(batch_size, num_channels, -1).mean(dim=2))
o7 = o3.view(batch_size, num_channels, -1).mean(dim=2)
# o7 = self.tanh(o3.view(batch_size, num_channels, -1).mean(dim=2))
batch_size, num_channels, H, W = o5.size()
o8 = self.tanh(o5.view(batch_size, num_channels, -1).mean(dim=2))
o8 = o5.view(batch_size, num_channels, -1).mean(dim=2)
# o8 = self.tanh(o5.view(batch_size, num_channels, -1).mean(dim=2))
return o6, o7, o8


Expand All @@ -58,6 +62,7 @@ class Segmentor(nn.Module):

def __init__(self, params):
super(Segmentor, self).__init__()
params['num_channels'] = 1
self.encode1 = sm.EncoderBlock(params)
params['num_channels'] = 64
self.encode2 = sm.EncoderBlock(params)
Expand Down Expand Up @@ -94,7 +99,7 @@ class FewShotSegmentor(nn.Module):
'''

def __init__(self, params):
super(FewShotSegmentor).__init__()
super(FewShotSegmentor, self).__init__()
self.conditioner = Conditioner(params)
self.segmentor = Segmentor(params)

Expand Down Expand Up @@ -128,27 +133,33 @@ def save(self, path):
print('Saving model... %s' % path)
torch.save(self, path)

def predict(self, X, device=0, enable_dropout=False):
def predict(self, X, y, query_label, device=0, enable_dropout=False):
"""
Predicts the outout after the model is trained.
Inputs:
- X: Volume to be predicted
"""
self.eval()

if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
input1, input2, y2 = split_batch(X, y, query_label)
input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device)

if enable_dropout:
self.enable_test_dropout()

with torch.no_grad():
out = self.forward(X)
out = self.forward(input1, input2)

max_val, idx = torch.max(out, 1)
# max_val, idx = torch.max(out, 1)
idx = out > 0.5
idx = idx.data.cpu().numpy()
prediction = np.squeeze(idx)
del X, out, idx, max_val
del X, out, idx
return prediction


def to_cuda(X, device):
if type(X) is np.ndarray:
X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
elif type(X) is torch.Tensor and not X.is_cuda:
X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
return X
18 changes: 10 additions & 8 deletions run.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
import torch

import utils.evaluator as eu
from few_shot_segmentor import QuickNat
from few_shot_segmentor import FewShotSegmentor
from settings import Settings
from solver import Solver
from utils.data_utils import get_imdb_dataset
from utils.log_utils import LogWriter
from utils.shot_batch_sampler import OneShotBatchSampler

torch.set_default_tensor_type('torch.FloatTensor')

Expand All @@ -24,14 +25,15 @@ def load_data(data_params):
def train(train_params, common_params, data_params, net_params):
train_data, test_data = load_data(data_params)

train_loader = torch.utils.data.DataLoader(train_data, batch_size=train_params['train_batch_size'], shuffle=True,
num_workers=4, pin_memory=True)
val_loader = torch.utils.data.DataLoader(test_data, batch_size=train_params['val_batch_size'], shuffle=False,
num_workers=4, pin_memory=True)
train_sampler = OneShotBatchSampler(train_data.y, 'train', 5, iteration=100)
test_sampler = OneShotBatchSampler(test_data.y, 'val', 5, iteration=100)

quicknat_model = QuickNat(net_params)
train_loader = torch.utils.data.DataLoader(train_data, batch_sampler=train_sampler)
val_loader = torch.utils.data.DataLoader(test_data, batch_sampler=test_sampler)

solver = Solver(quicknat_model,
few_shot_model = FewShotSegmentor(net_params)

solver = Solver(few_shot_model,
device=common_params['device'],
num_class=net_params['num_class'],
optim_args={"lr": train_params['learning_rate'],
Expand All @@ -51,7 +53,7 @@ def train(train_params, common_params, data_params, net_params):

solver.train(train_loader, val_loader)
final_model_path = os.path.join(common_params['save_model_dir'], train_params['final_model_file'])
quicknat_model.save(final_model_path)
few_shot_model.save(final_model_path)
print("final model saved @ " + str(final_model_path))


Expand Down
2 changes: 1 addition & 1 deletion settings.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ final_model_file = "few_shot_segmentor.pth.tar"
learning_rate = 5e-4
train_batch_size = 2
val_batch_size = 2
log_nth = 50
log_nth = 10
num_epochs = 15
optim_betas = (0.9, 0.999)
optim_eps = 1e-8
Expand Down
57 changes: 38 additions & 19 deletions solver.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import os
import numpy as np
import torch
from nn_additional_losses import losses as additional_losses
# from nn_additional_losses import losses as additional_losses
from torch.optim import lr_scheduler

import torch.nn as nn
from utils.log_utils import LogWriter
import utils.common_utils as common_utils
from data_utils import split_batch

# plt.interactive(True)

CHECKPOINT_FILE_NAME = 'checkpoint.pth.tar'

Expand All @@ -20,7 +22,7 @@ def __init__(self,
num_class,
optim=torch.optim.Adam,
optim_args={},
loss_func=additional_losses.CombinedLoss(),
loss_func=nn.BCELoss(),
model_name='FewShotSegmentor',
labels=None,
num_epochs=10,
Expand Down Expand Up @@ -59,10 +61,6 @@ def __init__(self,
if use_last_checkpoint:
self.load_checkpoint()





# TODO:Need to correct the CM and dice score calculation.
def train(self, train_loader, val_loader):
"""
Expand All @@ -85,6 +83,7 @@ def train(self, train_loader, val_loader):
print('START TRAINING. : model name = %s, device = %s' % (
self.model_name, torch.cuda.get_device_name(self.device)))
current_iteration = self.start_iteration

for epoch in range(self.start_epoch, self.num_epochs + 1):
print("\n==== Epoch [ %d / %d ] START ====" % (epoch, self.num_epochs))
for phase in ['train', 'val']:
Expand All @@ -97,18 +96,24 @@ def train(self, train_loader, val_loader):
scheduler.step()
else:
model.eval()

for i_batch, sample_batched in enumerate(dataloaders[phase]):
X = sample_batched[0].type(torch.FloatTensor)
y = sample_batched[1].type(torch.LongTensor)
w = sample_batched[2].type(torch.FloatTensor)

# TODO: split x,y in input1,input2, label
query_label = int(dataloaders[phase].batch_sampler.query_label)
input1, input2, y2 = split_batch(X, y, query_label)
# plot_img(input2[0].squeeze().data, input1[0,0,:,:].squeeze().data, y2[0].squeeze().data, input1[0,1,:,:].squeeze().data)
if model.is_cuda:
X, y, w = X.cuda(self.device, non_blocking=True), y.cuda(self.device,
non_blocking=True), w.cuda(self.device,
non_blocking=True)
input1, input2, y2 = input1.cuda(self.device, non_blocking=True), input2.cuda(self.device,
non_blocking=True), y2.cuda(
self.device, non_blocking=True)

output = model(X)
loss = self.loss_func(output, y, w)
output = model(input1, input2)
# TODO: add weights
loss = self.loss_func(output, y2)
if phase == 'train':
optim.zero_grad()
loss.backward()
Expand All @@ -119,10 +124,12 @@ def train(self, train_loader, val_loader):

loss_arr.append(loss.item())

_, batch_output = torch.max(output, dim=1)
out_list.append(batch_output.cpu())
y_list.append(y.cpu())
# _, batch_output = torch.max(output, dim=1)
batch_output = output > 0.5

out_list.append(batch_output.cpu())
y_list.append(y2.cpu())
# del X, y, w, output, loss
del X, y, w, output, batch_output, loss
torch.cuda.empty_cache()
if phase == 'val':
Expand All @@ -134,10 +141,8 @@ def train(self, train_loader, val_loader):
with torch.no_grad():
out_arr, y_arr = torch.cat(out_list), torch.cat(y_list)
self.logWriter.loss_per_epoch(loss_arr, phase, epoch)
index = np.random.choice(len(dataloaders[phase].dataset.X), 3, replace=False)
self.logWriter.image_per_epoch(model.predict(dataloaders[phase].dataset.X[index], self.device),
dataloaders[phase].dataset.y[index], phase, epoch)
self.logWriter.cm_per_epoch(phase, out_arr, y_arr, epoch)
index = np.random.choice(len(out_arr), 3, replace=False)
self.logWriter.image_per_epoch(out_arr[index], y_arr[index], phase, epoch)
self.logWriter.dice_score_per_epoch(phase, out_arr, y_arr, epoch)

print("==== Epoch [" + str(epoch) + " / " + str(self.num_epochs) + "] DONE ====")
Expand Down Expand Up @@ -174,3 +179,17 @@ def load_checkpoint(self):
print("=> loaded checkpoint '{}' (epoch {})".format(self.checkpoint_path, checkpoint['epoch']))
else:
print("=> no checkpoint found at '{}'".format(self.checkpoint_path))

# def plot_img(data1=None, data2=None, label1=None, label2=None):
# print(plt.get_backend())
# fig = plt.figure()
# plt.subplot(1,4,1)
# plt.imshow(data1, cmap='gray')
# plt.subplot(1, 4, 2)
# plt.imshow(data2, cmap='gray')
# plt.subplot(1,4,3)
# plt.imshow(label1)
# plt.subplot(1, 4, 4)
# plt.imshow(label2)
# plt.ioff()
# plt.show()
2 changes: 1 addition & 1 deletion utils/convert_h5.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
- python utils/convert_h5.py -dd /home/masterthesis/shayan/nas_drive/Data_Neuro/IXI/IXI_FS -ld /home/masterthesis/shayan/nas_drive/Data_Neuro/IXI/IXI_FS -ds 98,2 -rc FS -o COR -df datasets/IXI/coronal
- python3.6 utils/convert_h5.py -dd /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral -ld /home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral -trv datasets/train_volumes.txt -tev datasets/test_volumes.txt -rc WholeBody -o SAG -df datasets/coronal
"""

import argparse
Expand Down
14 changes: 12 additions & 2 deletions utils/data_utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import os

import h5py
import nibabel as nb
import numpy as np
import torch
import torch.utils.data as data
import scipy.io as sio

# import preprocessor as preprocessor
import utils.preprocessor as preprocessor


Expand Down Expand Up @@ -129,3 +128,14 @@ def load_file_paths(data_dir, label_dir, volumes_txt_file=None):
file_paths = [os.path.join(data_dir, vol) for vol in volumes_to_use]

return file_paths


def split_batch(X, y, query_label):
batch_size = len(X) // 2
input1 = X[0:batch_size, :, :, :]
input2 = X[batch_size:, :, :, :]
y1 = (y[0:batch_size, :, :] == query_label).type(torch.FloatTensor)
y2 = (y[batch_size:, :, :] == query_label).type(torch.FloatTensor)
y2 = y2.unsqueeze(1)
input1 = torch.cat([input1, y1.unsqueeze(1)], dim=1)
return input1, input2, y2
12 changes: 12 additions & 0 deletions utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,18 @@
import utils.data_utils as du


def dice_score_binary(vol_output, ground_truth, no_samples=10, mode='train'):
ground_truth = ground_truth.type(torch.FloatTensor)
vol_output = vol_output.type(torch.FloatTensor)
if mode == 'train':
samples = np.random.choice(len(vol_output), no_samples)
vol_output, ground_truth = vol_output[samples], ground_truth[samples]
inter = 2*torch.sum(torch.mul(ground_truth, vol_output))
union = torch.sum(ground_truth) + torch.sum(vol_output) + 0.0001

return torch.mean(torch.div(inter, union))


def dice_confusion_matrix(vol_output, ground_truth, num_classes, no_samples=10, mode='train'):
dice_cm = torch.zeros(num_classes, num_classes)
if mode == 'train':
Expand Down
12 changes: 7 additions & 5 deletions utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import matplotlib.pyplot as plt
import numpy as np
from tensorboardX import SummaryWriter

import torch
import evaluator as eu

plt.switch_backend('agg')
Expand Down Expand Up @@ -86,8 +86,10 @@ def plot_cm(self, caption, phase, cm, step=None):

def dice_score_per_epoch(self, phase, output, correct_labels, epoch):
print("Dice Score...", end='', flush=True)
ds = eu.dice_score_perclass(output, correct_labels, self.num_class)
self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch)
# TODO: multiclass vs binary
ds = eu.dice_score_binary(output, correct_labels, self.num_class)
print('Dice score is '+str(ds))
# self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch)

print("DONE", flush=True)

Expand Down Expand Up @@ -123,10 +125,10 @@ def image_per_epoch(self, prediction, ground_truth, phase, epoch):
fig, ax = plt.subplots(nrows=nrows, ncols=ncols, figsize=(10, 20))

for i in range(nrows):
ax[i][0].imshow(prediction[i], cmap='CMRmap', vmin=0, vmax=self.num_class - 1)
ax[i][0].imshow(torch.squeeze(prediction[i]), cmap='CMRmap', vmin=0, vmax=1)
ax[i][0].set_title("Predicted", fontsize=10, color="blue")
ax[i][0].axis('off')
ax[i][1].imshow(ground_truth[i], cmap='CMRmap', vmin=0, vmax=self.num_class - 1)
ax[i][1].imshow(torch.squeeze(ground_truth[i]), cmap='CMRmap', vmin=0, vmax=1)
ax[i][1].set_title("Ground Truth", fontsize=10, color="blue")
ax[i][1].axis('off')
fig.set_tight_layout(True)
Expand Down
2 changes: 1 addition & 1 deletion utils/preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def rotate_orientation(volume_data, volume_label, orientation=ORIENTATION['coron
return volume_data.transpose((2, 0, 1)), volume_label.transpose((2, 0, 1))
elif orientation == ORIENTATION['axial']:
return volume_data.transpose((1, 2, 0)), volume_label.transpose((1, 2, 0))
elif orientation == ORIENTATION['sagital']:
elif orientation == ORIENTATION['sagittal']:
return volume_data, volume_label
else:
raise ValueError("Invalid value for orientation. Pleas see help")
Expand Down
Loading

0 comments on commit f9f2fcf

Please sign in to comment.