From ec534f34a0ed2477059b46f9a643fa8426bfe37b Mon Sep 17 00:00:00 2001 From: Abhijit Guha Roy Date: Tue, 22 Jan 2019 11:39:43 +0100 Subject: [PATCH] Reverts to SGD --- datasets/eval_query.txt | 22 +- datasets/eval_support.txt | 2 +- few_shot_segmentor_model6.py | 15 +- ...ot_segmentor_sne_position_all_type_both.py | 351 ++++++++++++++++++ ...segmentor_sne_position_all_type_channel.py | 279 ++++++++++++++ ...segmentor_sne_position_all_type_spatial.py | 268 +++++++++++++ ..._segmentor_sne_position_bn_type_channel.py | 240 ++++++++++++ ..._segmentor_sne_position_bn_type_spatial.py | 238 ++++++++++++ ...entor_sne_position_decoder_type_channel.py | 252 +++++++++++++ ...entor_sne_position_decoder_type_spatial.py | 251 +++++++++++++ ...entor_sne_position_encoder_type_channel.py | 287 ++++++++++++++ ...entor_sne_position_encoder_type_spatial.py | 247 ++++++++++++ run_oneshot.py | 70 ++-- settings.ini | 12 +- solver_oneshot_multiOpti_auto.py | 8 +- utils/convert_h5.py | 4 +- utils/data_utils.py | 36 +- utils/evaluator.py | 242 +++++++++--- utils/evaluator_kshot.py | 75 ++-- utils/evaluator_multi_volume_support.py | 150 +++++--- utils/evaluator_slow.py | 212 ++--------- utils/log_utils.py | 1 + utils/preprocessor.py | 21 +- utils/shot_batch_sampler.py | 10 +- 24 files changed, 2928 insertions(+), 365 deletions(-) create mode 100644 few_shot_segmentor_sne_position_all_type_both.py create mode 100644 few_shot_segmentor_sne_position_all_type_channel.py create mode 100644 few_shot_segmentor_sne_position_all_type_spatial.py create mode 100644 few_shot_segmentor_sne_position_bn_type_channel.py create mode 100644 few_shot_segmentor_sne_position_bn_type_spatial.py create mode 100644 few_shot_segmentor_sne_position_decoder_type_channel.py create mode 100644 few_shot_segmentor_sne_position_decoder_type_spatial.py create mode 100644 few_shot_segmentor_sne_position_encoder_type_channel.py create mode 100644 few_shot_segmentor_sne_position_encoder_type_spatial.py diff --git a/datasets/eval_query.txt b/datasets/eval_query.txt index fa57bcc..aa5bf2f 100644 --- a/datasets/eval_query.txt +++ b/datasets/eval_query.txt @@ -1,19 +1,19 @@ 10000100_1_CTce_ThAb.mat -10000111_1_CTce_ThAb.mat -10000131_1_CTce_ThAb.mat 10000104_1_CTce_ThAb.mat -10000112_1_CTce_ThAb.mat -10000132_1_CTce_ThAb.mat +10000105_1_CTce_ThAb.mat +10000106_1_CTce_ThAb.mat +10000108_1_CTce_ThAb.mat 10000109_1_CTce_ThAb.mat +10000110_1_CTce_ThAb.mat +10000111_1_CTce_ThAb.mat +10000112_1_CTce_ThAb.mat 10000113_1_CTce_ThAb.mat -10000133_1_CTce_ThAb.mat -10000106_1_CTce_ThAb.mat 10000127_1_CTce_ThAb.mat -10000134_1_CTce_ThAb.mat -10000108_1_CTce_ThAb.mat 10000128_1_CTce_ThAb.mat -10000135_1_CTce_ThAb.mat 10000129_1_CTce_ThAb.mat +10000130_1_CTce_ThAb.mat +10000131_1_CTce_ThAb.mat +10000133_1_CTce_ThAb.mat +10000134_1_CTce_ThAb.mat +10000135_1_CTce_ThAb.mat 10000136_1_CTce_ThAb.mat -10000110_1_CTce_ThAb.mat -10000105_1_CTce_ThAb.mat diff --git a/datasets/eval_support.txt b/datasets/eval_support.txt index ea11171..4ba647b 100644 --- a/datasets/eval_support.txt +++ b/datasets/eval_support.txt @@ -1 +1 @@ -10000130_1_CTce_ThAb.mat +10000132_1_CTce_ThAb.mat diff --git a/few_shot_segmentor_model6.py b/few_shot_segmentor_model6.py index 4f96daf..3a7fd16 100644 --- a/few_shot_segmentor_model6.py +++ b/few_shot_segmentor_model6.py @@ -100,32 +100,33 @@ def forward(self, input): num_batch, ch = e_c1.size() e_c1 = e_c1.view(num_batch, ch, 1, 1) - # e_w1 = self.sigmoid(self.squeeze_conv_e1(e1)) e2, out2, ind2 = self.encode2(e1) num_batch, ch, _, _ = out2.size() e_c2 = 
self.sigmoid(self.channel_conv_e2(out2.view(num_batch, ch, -1).mean(dim=2))) num_batch, ch = e_c2.size() e_c2 = e_c2.view(num_batch, ch, 1, 1) - # e_w2 = self.sigmoid(self.squeeze_conv_e2(e2)) + e3, _, ind3 = self.encode3(e2) e_w3 = self.sigmoid(self.squeeze_conv_e3(e3)) - e4, _, ind4 = self.encode3(e3) + e4, _, ind4 = self.encode4(e3) e_w4 = self.sigmoid(self.squeeze_conv_e4(e4)) bn = self.bottleneck(e4) bn_w = self.squeeze_conv_bn(bn) - d4 = self.decode1(bn, None, ind4) + d4 = self.decode4(bn, None, ind4) d_w4 = self.sigmoid(self.squeeze_conv_d4(d4)) - d3 = self.decode1(d4, None, ind3) + + d3 = self.decode3(d4, None, ind3) d_w3 = self.sigmoid(self.squeeze_conv_d3(d3)) + d2 = self.decode2(d3, None, ind2) num_batch, ch, _, _ = d2.size() d_c2 = self.sigmoid(self.channel_conv_d2(d2.view(num_batch, ch, -1).mean(dim=2))) num_batch, ch = d_c2.size() d_c2 = d_c2.view(num_batch, ch, 1, 1) - # d_w2 = self.sigmoid(self.squeeze_conv_d2(d2)) - d1 = self.decode3(d2, None, ind1) + + d1 = self.decode1(d2, None, ind1) num_batch, ch, _, _ = d1.size() d_c1 = self.sigmoid(self.channel_conv_d1(d1.view(num_batch, ch, -1).mean(dim=2))) num_batch, ch = d_c1.size() diff --git a/few_shot_segmentor_sne_position_all_type_both.py b/few_shot_segmentor_sne_position_all_type_both.py new file mode 100644 index 0000000..8648023 --- /dev/null +++ b/few_shot_segmentor_sne_position_all_type_both.py @@ -0,0 +1,351 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.channel_conv_e1 = nn.Linear(params['num_filters'], 64, bias=True) + params['num_channels'] = 16 + self.encode2 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.channel_conv_e2 = nn.Linear(params['num_filters'], 64, bias=True) + self.encode3 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.encode4 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.bottleneck = sm.GenericBlock(params) + self.squeeze_conv_bn = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + params['num_channels'] = 16 + self.decode1 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.channel_conv_d1 = nn.Linear(params['num_filters'], 64, bias=True) + self.decode2 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) 
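+        # Documentation note: each conditioner stage can emit two kinds of gates
+        # for the segmentor. squeeze_conv_* gives a spatial attention map (a 1x1
+        # convolution over the stage's feature map); channel_conv_* gives a
+        # channel-wise vector (spatially pooled features through Linear(16, 64),
+        # matching the segmentor's 64 filters). forward() below selects which
+        # stages use which kind of gate.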
+ self.channel_conv_d2 = nn.Linear(params['num_filters'], 64, bias=True) + self.decode3 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode4 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + params['num_channels'] = 16 + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = nn.Sigmoid() + + def forward(self, input): + # e1, out1, ind1 = self.encode1(input) + # e_w1 = self.sigmoid(self.squeeze_conv_e1(e1)) + # e2, out2, ind2 = self.encode2(e1) + # e_w2 = self.sigmoid(self.squeeze_conv_e2(e2)) + # e3, out3, ind3 = self.encode3(e2) + # e_w3 = self.sigmoid(self.squeeze_conv_e3(e3)) + # + # bn = self.bottleneck(e3) + # bn_w = self.sigmoid(self.squeeze_conv_bn(bn)) + # + # d3 = self.decode1(bn, out3, ind3) + # d_w3 = self.sigmoid(self.squeeze_conv_d3(d3)) + # d2 = self.decode2(d3, out2, ind2) + # d_w2 = self.sigmoid(self.squeeze_conv_d2(d2)) + # d1 = self.decode3(d2, out1, ind1) + # d_w1 = self.sigmoid(self.squeeze_conv_d1(d1)) + # logit = self.classifier.forward(d1) + # cls_w = self.sigmoid(logit) + + e1, out1, ind1 = self.encode1(input) + num_batch, ch, _, _ = out1.size() + e_c1 = self.sigmoid(self.channel_conv_e1(out1.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c1.size() + e_c1 = e_c1.view(num_batch, ch, 1, 1) + + e2, out2, ind2 = self.encode2(e1) + num_batch, ch, _, _ = out2.size() + e_c2 = self.sigmoid(self.channel_conv_e2(out2.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c2.size() + e_c2 = e_c2.view(num_batch, ch, 1, 1) + + e3, _, ind3 = self.encode3(e2) + e_w3 = self.sigmoid(self.squeeze_conv_e3(e3)) + + e4, _, ind4 = self.encode4(e3) + e_w4 = self.sigmoid(self.squeeze_conv_e4(e4)) + + bn = self.bottleneck(e4) + bn_w = self.squeeze_conv_bn(bn) + + d4 = self.decode4(bn, None, ind4) + d_w4 = self.sigmoid(self.squeeze_conv_d4(d4)) + + d3 = self.decode3(d4, None, ind3) + d_w3 = self.sigmoid(self.squeeze_conv_d3(d3)) + + d2 = self.decode2(d3, None, ind2) + num_batch, ch, _, _ = d2.size() + d_c2 = self.sigmoid(self.channel_conv_d2(d2.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c2.size() + d_c2 = d_c2.view(num_batch, ch, 1, 1) + + d1 = self.decode1(d2, None, ind1) + num_batch, ch, _, _ = d1.size() + d_c1 = self.sigmoid(self.channel_conv_d1(d1.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c1.size() + d_c1 = d_c1.view(num_batch, ch, 1, 1) + # d_w1 = self.sigmoid(self.squeeze_conv_d1(d1)) + # logit = self.classifier.forward(d1) + # cls_w = logit + + # return e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w + # return e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w + + space_weights = (None, None, e_w3, e_w4, bn_w, d_w4, d_w3, None, None, None) + channel_weights = (e_c1, e_c2, d_c2, d_c1) + + return space_weights, channel_weights + + +class SDnetSegmentor(nn.Module): + """ + Segmentor Code + + param ={ + 'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':1 + 'se_block': True, + 'drop_out':0 + } + + """ + + def __init__(self, params): + super(SDnetSegmentor, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 1 + params['num_filters'] = 64 + self.encode1 = sm.SDnetEncoderBlock(params) + params['num_channels'] = 64 + self.encode2 = sm.SDnetEncoderBlock(params) + self.encode3 
= sm.SDnetEncoderBlock(params)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.bottleneck = sm.GenericBlock(params)
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 128
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 64
+        self.classifier = sm.ClassifierBlock(params)
+        self.soft_max = nn.Softmax2d()
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inpt, weights=None):
+        space_weights, channel_weights = weights
+        # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else (
+        #     None, None, None, None, None, None, None, None)
+
+        e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else (
+            None, None, None, None, None, None, None, None, None, None)
+        e_c1, e_c2, d_c2, d_c1 = channel_weights  # unpacked in the conditioner's return order
+        # if weights is not None:
+        #     bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50
+
+        # e1, out1, ind1 = self.encode1(inpt)
+        # if e_w1 is not None:
+        #     e1 = torch.mul(e1, e_w1)
+        # e2, out2, ind2 = self.encode2(e1)
+        # if e_w2 is not None:
+        #     e2 = torch.mul(e2, e_w2)
+        # e3, out3, ind3 = self.encode3(e2)
+        # if e_w3 is not None:
+        #     e3 = torch.mul(e3, e_w3)
+        #
+        # e4, out4, ind4 = self.encode4(e3)
+        # if e_w4 is not None:
+        #     e4 = torch.mul(e4, e_w4)
+        #
+        # bn = self.bottleneck(e4)
+        # if bn_w is not None:
+        #     bn = torch.mul(bn, bn_w)
+        #
+        # d4 = self.decode4(bn, out4, ind4)
+        # if d_w4 is not None:
+        #     d4 = torch.mul(d4, d_w4)
+        #
+        # d3 = self.decode1(d4, out3, ind3)
+        # if d_w3 is not None:
+        #     d3 = torch.mul(d3, d_w3)
+        #
+        # d2 = self.decode2(d3, out2, ind2)
+        # if d_w2 is not None:
+        #     d2 = torch.mul(d2, d_w2)
+        #
+        # d1 = self.decode3(d2, out1, ind1)
+        # if d_w1 is not None:
+        #     d1 = torch.mul(d1, d_w1)
+
+        e1, _, ind1 = self.encode1(inpt)
+        e1 = torch.mul(e1, e_c1)
+        if e_w1 is not None:
+            e1 = torch.mul(e1, e_w1)
+        e2, _, ind2 = self.encode2(e1)
+        e2 = torch.mul(e2, e_c2)
+        if e_w2 is not None:
+            e2 = torch.mul(e2, e_w2)
+        e3, _, ind3 = self.encode3(e2)
+        if e_w3 is not None:
+            e3 = torch.mul(e3, e_w3)
+
+        e4, out4, ind4 = self.encode4(e3)
+        if e_w4 is not None:
+            e4 = torch.mul(e4, e_w4)
+
+        bn = self.bottleneck(e4)
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+        d2 = torch.mul(d2, d_c2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+        d1 = torch.mul(d1, d_c1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class Combining Conditioner and Segmentor for few shot learning
+    '''
+
+    def __init__(self, params):
+        super(FewShotSegmentorDoubleSDnet, self).__init__()
+        self.conditioner = SDnetConditioner(params)
+        self.segmentor = SDnetSegmentor(params)
+
+    def forward(self, input1, input2):
+        weights = self.conditioner(input1)
+        segment = self.segmentor(input2, weights)
+        return segment
+
+    def enable_test_dropout(self):
+        attr_dict = self.__dict__['_modules']
+        for i in range(1, 5):
+            encode_block, decode_block = 
attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with "*.model". + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, y, query_label, device=0, enable_dropout=False): + """ + Predicts the outout after the model is trained. + Inputs: + - X: Volume to be predicted + """ + self.eval() + input1, input2, y2 = split_batch(X, y, query_label) + input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(input1, input2) + + # max_val, idx = torch.max(out, 1) + idx = out > 0.5 + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx + return prediction + + +def to_cuda(X, device): + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + return X + diff --git a/few_shot_segmentor_sne_position_all_type_channel.py b/few_shot_segmentor_sne_position_all_type_channel.py new file mode 100644 index 0000000..2feac6b --- /dev/null +++ b/few_shot_segmentor_sne_position_all_type_channel.py @@ -0,0 +1,279 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + self.channel_conv_e1 = nn.Linear(params['num_filters'], 64, bias=True) + + params['num_channels'] = 16 + + self.encode2 = sm.SDnetEncoderBlock(params) + self.channel_conv_e2 = nn.Linear(params['num_filters'], 64, bias=True) + + self.encode3 = sm.SDnetEncoderBlock(params) + self.channel_conv_e3 = nn.Linear(params['num_filters'], 64, bias=True) + + self.encode4 = sm.SDnetEncoderBlock(params) + self.channel_conv_e4 = nn.Linear(params['num_filters'], 64, bias=True) + + self.bottleneck = sm.GenericBlock(params) + self.channel_conv_bn = nn.Linear(params['num_filters'], 64, bias=True) + params['num_channels'] = 16 + + self.decode1 = sm.SDnetDecoderBlock(params) + self.channel_conv_d1 = nn.Linear(params['num_filters'], 64, bias=True) + self.decode2 = sm.SDnetDecoderBlock(params) + self.channel_conv_d2 = nn.Linear(params['num_filters'], 64, bias=True) + self.decode3 = sm.SDnetDecoderBlock(params) + self.channel_conv_d3 = nn.Linear(params['num_filters'], 64, bias=True) + self.decode4 = sm.SDnetDecoderBlock(params) + self.channel_conv_d4 = nn.Linear(params['num_filters'], 64, bias=True) + params['num_channels'] = 16 + + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = 
nn.Sigmoid() + + def forward(self, input): + e1, out1, ind1 = self.encode1(input) + num_batch, ch, _, _ = out1.size() + e_c1 = self.sigmoid(self.channel_conv_e1(out1.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c1.size() + e_c1 = e_c1.view(num_batch, ch, 1, 1) + + e2, out2, ind2 = self.encode2(e1) + num_batch, ch, _, _ = out2.size() + e_c2 = self.sigmoid(self.channel_conv_e2(out2.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c2.size() + e_c2 = e_c2.view(num_batch, ch, 1, 1) + + e3, out3, ind3 = self.encode3(e2) + num_batch, ch, _, _ = out3.size() + e_c3 = self.sigmoid(self.channel_conv_e3(out3.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c3.size() + e_c3 = e_c3.view(num_batch, ch, 1, 1) + + e4, out4, ind4 = self.encode4(e3) + num_batch, ch, _, _ = out4.size() + e_c4 = self.sigmoid(self.channel_conv_e4(out4.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c4.size() + e_c4 = e_c4.view(num_batch, ch, 1, 1) + + bn = self.bottleneck(e4) + num_batch, ch, _, _ = bn.size() + bn_c = self.sigmoid(self.channel_conv_bn(bn.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = bn_c.size() + bn_c = bn_c.view(num_batch, ch, 1, 1) + + d4 = self.decode4(bn, None, ind4) + num_batch, ch, _, _ = d4.size() + d_c4 = self.sigmoid(self.channel_conv_d4(d4.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c4.size() + d_c4 = d_c4.view(num_batch, ch, 1, 1) + + d3 = self.decode3(d4, None, ind3) + num_batch, ch, _, _ = d3.size() + d_c3 = self.sigmoid(self.channel_conv_d3(d3.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c3.size() + d_c3 = d_c3.view(num_batch, ch, 1, 1) + + d2 = self.decode2(d3, None, ind2) + num_batch, ch, _, _ = d2.size() + d_c2 = self.sigmoid(self.channel_conv_d2(d2.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c2.size() + d_c2 = d_c2.view(num_batch, ch, 1, 1) + + d1 = self.decode1(d2, None, ind1) + num_batch, ch, _, _ = d1.size() + d_c1 = self.sigmoid(self.channel_conv_d1(d1.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = d_c1.size() + d_c1 = d_c1.view(num_batch, ch, 1, 1) + + space_weights = (e_c1, e_c2, e_c3, e_c4, bn_c, d_c4, d_c3, d_c2, d_c1, None) + channel_weights = (e_c1, e_c2, e_c3, e_c4) + + return space_weights, channel_weights + + +class SDnetSegmentor(nn.Module): + """ + Segmentor Code + + param ={ + 'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':1 + 'se_block': True, + 'drop_out':0 + } + + """ + + def __init__(self, params): + super(SDnetSegmentor, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 1 + params['num_filters'] = 64 + self.encode1 = sm.SDnetEncoderBlock(params) + params['num_channels'] = 64 + self.encode2 = sm.SDnetEncoderBlock(params) + self.encode3 = sm.SDnetEncoderBlock(params) + self.encode4 = sm.SDnetEncoderBlock(params) + self.bottleneck = sm.GenericBlock(params) + + self.decode1 = sm.SDnetDecoderBlock(params) + self.decode2 = sm.SDnetDecoderBlock(params) + self.decode3 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 128 + self.decode4 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 64 + self.classifier = sm.ClassifierBlock(params) + self.soft_max = nn.Softmax2d() + # self.sigmoid = nn.Sigmoid() + + def forward(self, inpt, weights=None): + space_weights, channel_weights = weights + # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else ( + # None, None, None, None, None, None, None, None) + + e_w1, e_w2, e_w3, e_w4, 
bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else ( + None, None, None, None, None, None, None, None, None, None) + e_c1, e_c2, e_c3, e_c4 = channel_weights + # if weights is not None: + # bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50 + + e1, out1, ind1 = self.encode1(inpt) + if e_w1 is not None: + e1 = torch.mul(e1, e_w1) + e2, out2, ind2 = self.encode2(e1) + if e_w2 is not None: + e2 = torch.mul(e2, e_w2) + e3, out3, ind3 = self.encode3(e2) + if e_w3 is not None: + e3 = torch.mul(e3, e_w3) + + e4, out4, ind4 = self.encode4(e3) + if e_w4 is not None: + e4 = torch.mul(e4, e_w4) + + bn = self.bottleneck(e4) + if bn_w is not None: + bn = torch.mul(bn, bn_w) + + d4 = self.decode4(bn, out4, ind4) + if d_w4 is not None: + d4 = torch.mul(d4, d_w4) + + d3 = self.decode3(d4, None, ind3) + if d_w3 is not None: + d3 = torch.mul(d3, d_w3) + + d2 = self.decode2(d3, None, ind2) + if d_w2 is not None: + d2 = torch.mul(d2, d_w2) + + d1 = self.decode1(d2, None, ind1) + if d_w1 is not None: + d1 = torch.mul(d1, d_w1) + + # d1_1 = torch.cat((d1, inpt), dim=1) + logit = self.classifier.forward(d1) + if cls_w is not None: + logit = torch.mul(logit, cls_w) + logit = self.soft_max(logit) + + return logit + + +class FewShotSegmentorDoubleSDnet(nn.Module): + ''' + Class Combining Conditioner and Segmentor for few shot learning + ''' + + def __init__(self, params): + super(FewShotSegmentorDoubleSDnet, self).__init__() + self.conditioner = SDnetConditioner(params) + self.segmentor = SDnetSegmentor(params) + + def forward(self, input1, input2): + weights = self.conditioner(input1) + segment = self.segmentor(input2, weights) + return segment + + def enable_test_dropout(self): + attr_dict = self.__dict__['_modules'] + for i in range(1, 5): + encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with "*.model". + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, y, query_label, device=0, enable_dropout=False): + """ + Predicts the outout after the model is trained. 
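+        The support image/label pair conditions the segmentor, and the soft
+        prediction for the query input is thresholded at 0.5.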
+ Inputs: + - X: Volume to be predicted + """ + self.eval() + input1, input2, y2 = split_batch(X, y, query_label) + input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(input1, input2) + + # max_val, idx = torch.max(out, 1) + idx = out > 0.5 + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx + return prediction + + +def to_cuda(X, device): + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + return X + diff --git a/few_shot_segmentor_sne_position_all_type_spatial.py b/few_shot_segmentor_sne_position_all_type_spatial.py new file mode 100644 index 0000000..cb3f148 --- /dev/null +++ b/few_shot_segmentor_sne_position_all_type_spatial.py @@ -0,0 +1,268 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + params['num_channels'] = 16 + self.encode2 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.encode3 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.encode4 = sm.SDnetEncoderBlock(params) + self.squeeze_conv_e4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.bottleneck = sm.GenericBlock(params) + self.squeeze_conv_bn = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + params['num_channels'] = 16 + self.decode1 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode2 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode3 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode4 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + params['num_channels'] = 16 + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = nn.Sigmoid() + + def forward(self, input): + + e1, _, ind1 = self.encode1(input) + e_w1 = self.sigmoid(self.squeeze_conv_e1(e1)) + e2, out2, ind2 = 
self.encode2(e1)
+        e_w2 = self.sigmoid(self.squeeze_conv_e2(e2))
+        e3, _, ind3 = self.encode3(e2)
+        e_w3 = self.sigmoid(self.squeeze_conv_e3(e3))
+        e4, _, ind4 = self.encode4(e3)
+        e_w4 = self.sigmoid(self.squeeze_conv_e4(e4))
+
+        bn = self.bottleneck(e4)
+        bn_w4 = self.sigmoid(self.squeeze_conv_bn(bn))
+        d4 = self.decode4(bn, None, ind4)
+        d_w4 = self.sigmoid(self.squeeze_conv_d4(d4))
+        d3 = self.decode3(d4, None, ind3)
+        d_w3 = self.sigmoid(self.squeeze_conv_d3(d3))
+        d2 = self.decode2(d3, None, ind2)
+        d_w2 = self.sigmoid(self.squeeze_conv_d2(d2))
+        d1 = self.decode1(d2, None, ind1)
+        d_w1 = self.sigmoid(self.squeeze_conv_d1(d1))
+
+        space_weights = (e_w1, e_w2, e_w3, e_w4, bn_w4, d_w4, d_w3, d_w2, d_w1, None)
+        channel_weights = (None, None, None, None)
+
+        return space_weights, channel_weights
+
+
+class SDnetSegmentor(nn.Module):
+    """
+    Segmentor Code
+
+    param ={
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':1
+        'se_block': True,
+        'drop_out':0
+    }
+
+    """
+
+    def __init__(self, params):
+        super(SDnetSegmentor, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 1
+        params['num_filters'] = 64
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        params['num_channels'] = 64
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.bottleneck = sm.GenericBlock(params)
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 128
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 64
+        self.classifier = sm.ClassifierBlock(params)
+        self.soft_max = nn.Softmax2d()
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inpt, weights=None):
+        space_weights, channel_weights = weights
+        # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else (
+        #     None, None, None, None, None, None, None, None)
+
+        e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else (
+            None, None, None, None, None, None, None, None, None, None)
+        e_c1, e_c2, d_c1, d_c2 = channel_weights
+        # if weights is not None:
+        #     bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50
+
+        e1, _, ind1 = self.encode1(inpt)
+        if e_w1 is not None:
+            e1 = torch.mul(e1, e_w1)
+
+        e2, _, ind2 = self.encode2(e1)
+        if e_w2 is not None:
+            e2 = torch.mul(e2, e_w2)
+
+        e3, _, ind3 = self.encode3(e2)
+        if e_w3 is not None:
+            e3 = torch.mul(e3, e_w3)
+
+        e4, out4, ind4 = self.encode4(e3)
+        if e_w4 is not None:
+            e4 = torch.mul(e4, e_w4)
+
+        bn = self.bottleneck(e4)
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class Combining Conditioner and Segmentor for few shot learning
+    '''
+
+    def __init__(self, 
params): + super(FewShotSegmentorDoubleSDnet, self).__init__() + self.conditioner = SDnetConditioner(params) + self.segmentor = SDnetSegmentor(params) + + def forward(self, input1, input2): + weights = self.conditioner(input1) + segment = self.segmentor(input2, weights) + return segment + + def enable_test_dropout(self): + attr_dict = self.__dict__['_modules'] + for i in range(1, 5): + encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with "*.model". + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, y, query_label, device=0, enable_dropout=False): + """ + Predicts the outout after the model is trained. + Inputs: + - X: Volume to be predicted + """ + self.eval() + input1, input2, y2 = split_batch(X, y, query_label) + input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(input1, input2) + + # max_val, idx = torch.max(out, 1) + idx = out > 0.5 + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx + return prediction + + +def to_cuda(X, device): + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + return X + diff --git a/few_shot_segmentor_sne_position_bn_type_channel.py b/few_shot_segmentor_sne_position_bn_type_channel.py new file mode 100644 index 0000000..537ba91 --- /dev/null +++ b/few_shot_segmentor_sne_position_bn_type_channel.py @@ -0,0 +1,240 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + + params['num_channels'] = 16 + self.encode2 = sm.SDnetEncoderBlock(params) + + self.encode3 = sm.SDnetEncoderBlock(params) + + self.encode4 = sm.SDnetEncoderBlock(params) + + self.bottleneck = sm.GenericBlock(params) + self.channel_conv_bn = nn.Linear(params['num_filters'], 64, bias=True) + params['num_channels'] = 16 + self.decode1 = sm.SDnetDecoderBlock(params) + self.decode2 = sm.SDnetDecoderBlock(params) + self.decode3 = sm.SDnetDecoderBlock(params) + self.decode4 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 16 + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = nn.Sigmoid() + + def forward(self, input): + + e1, _, ind1 = self.encode1(input) + + e2, out2, ind2 = self.encode2(e1) + + e3, _, ind3 = self.encode3(e2) + + e4, _, ind4 = 
self.encode4(e3) + + bn = self.bottleneck(e4) + num_batch, ch, _, _ = bn.size() + bn_c = self.sigmoid(self.channel_conv_bn(bn.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = bn_c.size() + bn_c = bn_c.view(num_batch, ch, 1, 1) + + d4 = self.decode4(bn, None, ind4) + + d3 = self.decode3(d4, None, ind3) + + d2 = self.decode2(d3, None, ind2) + num_batch, ch, _, _ = d2.size() + d1 = self.decode1(d2, None, ind1) + + # NOTE: For ease putting channel weight in space. + space_weights = (None, None, None, None, bn_c, None, None, None, None, None) + channel_weights = (None, None, None, None) + + return space_weights, channel_weights + + +class SDnetSegmentor(nn.Module): + """ + Segmentor Code + + param ={ + 'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':1 + 'se_block': True, + 'drop_out':0 + } + + """ + + def __init__(self, params): + super(SDnetSegmentor, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 1 + params['num_filters'] = 64 + self.encode1 = sm.SDnetEncoderBlock(params) + params['num_channels'] = 64 + self.encode2 = sm.SDnetEncoderBlock(params) + self.encode3 = sm.SDnetEncoderBlock(params) + self.encode4 = sm.SDnetEncoderBlock(params) + self.bottleneck = sm.GenericBlock(params) + + self.decode1 = sm.SDnetDecoderBlock(params) + self.decode2 = sm.SDnetDecoderBlock(params) + self.decode3 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 128 + self.decode4 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 64 + self.classifier = sm.ClassifierBlock(params) + self.soft_max = nn.Softmax2d() + # self.sigmoid = nn.Sigmoid() + + def forward(self, inpt, weights=None): + space_weights, channel_weights = weights + # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else ( + # None, None, None, None, None, None, None, None) + + e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else ( + None, None, None, None, None, None, None, None, None, None) + e_c1, e_c2, d_c1, d_c2 = channel_weights + # if weights is not None: + # bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50 + + e1, _, ind1 = self.encode1(inpt) + if e_w1 is not None: + e1 = torch.mul(e1, e_w1) + + e2, _, ind2 = self.encode2(e1) + if e_w2 is not None: + e2 = torch.mul(e2, e_w2) + + e3, _, ind3 = self.encode3(e2) + if e_w3 is not None: + e3 = torch.mul(e3, e_w3) + + e4, out4, ind4 = self.encode4(e3) + if e_w4 is not None: + e4 = torch.mul(e4, e_w4) + + bn = self.bottleneck(e4) + if bn_w is not None: + bn = torch.mul(bn, bn_w) + + d4 = self.decode4(bn, out4, ind4) + if d_w4 is not None: + d4 = torch.mul(d4, d_w4) + + d3 = self.decode3(d4, None, ind3) + if d_w3 is not None: + d3 = torch.mul(d3, d_w3) + + d2 = self.decode2(d3, None, ind2) + if d_w2 is not None: + d2 = torch.mul(d2, d_w2) + + d1 = self.decode1(d2, None, ind1) + if d_w1 is not None: + d1 = torch.mul(d1, d_w1) + + # d1_1 = torch.cat((d1, inpt), dim=1) + logit = self.classifier.forward(d1) + if cls_w is not None: + logit = torch.mul(logit, cls_w) + logit = self.soft_max(logit) + + return logit + + +class FewShotSegmentorDoubleSDnet(nn.Module): + ''' + Class Combining Conditioner and Segmentor for few shot learning + ''' + + def __init__(self, params): + super(FewShotSegmentorDoubleSDnet, self).__init__() + self.conditioner = SDnetConditioner(params) + self.segmentor = SDnetSegmentor(params) + + def forward(self, 
input1, input2):
+        weights = self.conditioner(input1)
+        segment = self.segmentor(input2, weights)
+        return segment
+
+    def enable_test_dropout(self):
+        attr_dict = self.__dict__['_modules']
+        for i in range(1, 5):
+            encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
+            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
+            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)
+
+    @property
+    def is_cuda(self):
+        """
+        Check if model parameters are allocated on the GPU.
+        """
+        return next(self.parameters()).is_cuda
+
+    def save(self, path):
+        """
+        Save model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        torch.save(self, path)
+
+    def predict(self, X, y, query_label, device=0, enable_dropout=False):
+        """
+        Predicts the output after the model is trained.
+        Inputs:
+        - X: Volume to be predicted
+        """
+        self.eval()
+        input1, input2, y2 = split_batch(X, y, query_label)
+        input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device)
+
+        if enable_dropout:
+            self.enable_test_dropout()
+
+        with torch.no_grad():
+            out = self.forward(input1, input2)
+
+        # max_val, idx = torch.max(out, 1)
+        idx = out > 0.5
+        idx = idx.data.cpu().numpy()
+        prediction = np.squeeze(idx)
+        del X, out, idx
+        return prediction
+
+
+def to_cuda(X, device):
+    if type(X) is np.ndarray:
+        X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
+    elif type(X) is torch.Tensor and not X.is_cuda:
+        X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
+    return X
+
diff --git a/few_shot_segmentor_sne_position_bn_type_spatial.py b/few_shot_segmentor_sne_position_bn_type_spatial.py
new file mode 100644
index 0000000..bc4fe6c
--- /dev/null
+++ b/few_shot_segmentor_sne_position_bn_type_spatial.py
@@ -0,0 +1,238 @@
+"""Few-Shot_learning Segmentation"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from nn_common_modules import modules as sm
+from data_utils import split_batch
+# import torch.nn.functional as F
+from squeeze_and_excitation import squeeze_and_excitation as se
+
+
+class SDnetConditioner(nn.Module):
+    """
+    A conditional branch of few shot learning regressing the parameters for the segmentor
+    """
+
+    def __init__(self, params):
+        super(SDnetConditioner, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 2
+        params['num_filters'] = 16
+        self.encode1 = sm.SDnetEncoderBlock(params)
+
+        params['num_channels'] = 16
+        self.encode2 = sm.SDnetEncoderBlock(params)
+
+        self.encode3 = sm.SDnetEncoderBlock(params)
+
+        self.encode4 = sm.SDnetEncoderBlock(params)
+
+        self.bottleneck = sm.GenericBlock(params)
+        self.squeeze_conv_bn = nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
+                                         kernel_size=(1, 1),
+                                         padding=(0, 0),
+                                         stride=1)
+        params['num_channels'] = 16
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 16
+        self.classifier = sm.ClassifierBlock(params)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input):
+
+        e1, _, ind1 = self.encode1(input)
+
+        e2, out2, ind2 = self.encode2(e1)
+
+        e3, _, ind3 = self.encode3(e2)
+
+        e4, _, ind4 = self.encode4(e3)
+
+        bn = self.bottleneck(e4)
+        bn_w4 = self.sigmoid(self.squeeze_conv_bn(bn))
+        d4 = 
self.decode4(bn, None, ind4) + + d3 = self.decode3(d4, None, ind3) + + d2 = self.decode2(d3, None, ind2) + num_batch, ch, _, _ = d2.size() + d1 = self.decode1(d2, None, ind1) + + space_weights = (None, None, None, None, bn_w4, None, None, None, None, None) + channel_weights = (None, None, None, None) + + return space_weights, channel_weights + + +class SDnetSegmentor(nn.Module): + """ + Segmentor Code + + param ={ + 'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':1 + 'se_block': True, + 'drop_out':0 + } + + """ + + def __init__(self, params): + super(SDnetSegmentor, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 1 + params['num_filters'] = 64 + self.encode1 = sm.SDnetEncoderBlock(params) + params['num_channels'] = 64 + self.encode2 = sm.SDnetEncoderBlock(params) + self.encode3 = sm.SDnetEncoderBlock(params) + self.encode4 = sm.SDnetEncoderBlock(params) + self.bottleneck = sm.GenericBlock(params) + + self.decode1 = sm.SDnetDecoderBlock(params) + self.decode2 = sm.SDnetDecoderBlock(params) + self.decode3 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 128 + self.decode4 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 64 + self.classifier = sm.ClassifierBlock(params) + self.soft_max = nn.Softmax2d() + # self.sigmoid = nn.Sigmoid() + + def forward(self, inpt, weights=None): + space_weights, channel_weights = weights + # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else ( + # None, None, None, None, None, None, None, None) + + e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else ( + None, None, None, None, None, None, None, None, None, None) + e_c1, e_c2, d_c1, d_c2 = channel_weights + # if weights is not None: + # bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50 + + e1, _, ind1 = self.encode1(inpt) + if e_w1 is not None: + e1 = torch.mul(e1, e_w1) + + e2, _, ind2 = self.encode2(e1) + if e_w2 is not None: + e2 = torch.mul(e2, e_w2) + + e3, _, ind3 = self.encode3(e2) + if e_w3 is not None: + e3 = torch.mul(e3, e_w3) + + e4, out4, ind4 = self.encode4(e3) + if e_w4 is not None: + e4 = torch.mul(e4, e_w4) + + bn = self.bottleneck(e4) + if bn_w is not None: + bn = torch.mul(bn, bn_w) + + d4 = self.decode4(bn, out4, ind4) + if d_w4 is not None: + d4 = torch.mul(d4, d_w4) + + d3 = self.decode3(d4, None, ind3) + if d_w3 is not None: + d3 = torch.mul(d3, d_w3) + + d2 = self.decode2(d3, None, ind2) + if d_w2 is not None: + d2 = torch.mul(d2, d_w2) + + d1 = self.decode1(d2, None, ind1) + if d_w1 is not None: + d1 = torch.mul(d1, d_w1) + + # d1_1 = torch.cat((d1, inpt), dim=1) + logit = self.classifier.forward(d1) + if cls_w is not None: + logit = torch.mul(logit, cls_w) + logit = self.soft_max(logit) + + return logit + + +class FewShotSegmentorDoubleSDnet(nn.Module): + ''' + Class Combining Conditioner and Segmentor for few shot learning + ''' + + def __init__(self, params): + super(FewShotSegmentorDoubleSDnet, self).__init__() + self.conditioner = SDnetConditioner(params) + self.segmentor = SDnetSegmentor(params) + + def forward(self, input1, input2): + weights = self.conditioner(input1) + segment = self.segmentor(input2, weights) + return segment + + def enable_test_dropout(self): + attr_dict = self.__dict__['_modules'] + for i in range(1, 5): + encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' 
+ str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with "*.model". + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, y, query_label, device=0, enable_dropout=False): + """ + Predicts the outout after the model is trained. + Inputs: + - X: Volume to be predicted + """ + self.eval() + input1, input2, y2 = split_batch(X, y, query_label) + input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(input1, input2) + + # max_val, idx = torch.max(out, 1) + idx = out > 0.5 + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx + return prediction + + +def to_cuda(X, device): + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + return X + diff --git a/few_shot_segmentor_sne_position_decoder_type_channel.py b/few_shot_segmentor_sne_position_decoder_type_channel.py new file mode 100644 index 0000000..c4ccfa6 --- /dev/null +++ b/few_shot_segmentor_sne_position_decoder_type_channel.py @@ -0,0 +1,252 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + + params['num_channels'] = 16 + + self.encode2 = sm.SDnetEncoderBlock(params) + + self.encode3 = sm.SDnetEncoderBlock(params) + + self.encode4 = sm.SDnetEncoderBlock(params) + + self.bottleneck = sm.GenericBlock(params) + + params['num_channels'] = 16 + + self.decode1 = sm.SDnetDecoderBlock(params) + self.channel_conv_d1 = nn.Linear(params['num_filters'], 64, bias=True) + + self.decode2 = sm.SDnetDecoderBlock(params) + self.channel_conv_d2 = nn.Linear(params['num_filters'], 64, bias=True) + + self.decode3 = sm.SDnetDecoderBlock(params) + self.channel_conv_d3 = nn.Linear(params['num_filters'], 64, bias=True) + + self.decode4 = sm.SDnetDecoderBlock(params) + self.channel_conv_d4 = nn.Linear(params['num_filters'], 64, bias=True) + + params['num_channels'] = 16 + + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = nn.Sigmoid() + + def forward(self, input): + e1, out1, ind1 = self.encode1(input) + + e2, out2, ind2 = self.encode2(e1) + + e3, out3, ind3 = self.encode3(e2) + + e4, out4, ind4 = self.encode4(e3) + + bn = self.bottleneck(e4) + + d4 = self.decode4(bn, None, ind4) + num_batch, ch, _, _ = d4.size() + d_c4 = self.sigmoid(self.channel_conv_d4(d4.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = 
d_c4.size()
+        d_c4 = d_c4.view(num_batch, ch, 1, 1)
+
+        d3 = self.decode3(d4, None, ind3)
+        num_batch, ch, _, _ = d3.size()
+        d_c3 = self.sigmoid(self.channel_conv_d3(d3.view(num_batch, ch, -1).mean(dim=2)))
+        num_batch, ch = d_c3.size()
+        d_c3 = d_c3.view(num_batch, ch, 1, 1)
+
+        d2 = self.decode2(d3, None, ind2)
+        num_batch, ch, _, _ = d2.size()
+        d_c2 = self.sigmoid(self.channel_conv_d2(d2.view(num_batch, ch, -1).mean(dim=2)))
+        num_batch, ch = d_c2.size()
+        d_c2 = d_c2.view(num_batch, ch, 1, 1)
+
+        d1 = self.decode1(d2, None, ind1)
+        num_batch, ch, _, _ = d1.size()
+        d_c1 = self.sigmoid(self.channel_conv_d1(d1.view(num_batch, ch, -1).mean(dim=2)))
+        num_batch, ch = d_c1.size()
+        d_c1 = d_c1.view(num_batch, ch, 1, 1)
+
+        # NOTE: the channel vectors ride in the spatial slots; torch.mul broadcasts them.
+        space_weights = (None, None, None, None, None, d_c4, d_c3, d_c2, d_c1, None)
+        channel_weights = (None, None, None, None)
+
+        return space_weights, channel_weights
+
+
+class SDnetSegmentor(nn.Module):
+    """
+    Segmentor Code
+
+    param ={
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':1
+        'se_block': True,
+        'drop_out':0
+    }
+
+    """
+
+    def __init__(self, params):
+        super(SDnetSegmentor, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 1
+        params['num_filters'] = 64
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        params['num_channels'] = 64
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.bottleneck = sm.GenericBlock(params)
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 128
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 64
+        self.classifier = sm.ClassifierBlock(params)
+        self.soft_max = nn.Softmax2d()
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inpt, weights=None):
+        space_weights, channel_weights = weights
+        # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else (
+        #     None, None, None, None, None, None, None, None)
+
+        e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else (
+            None, None, None, None, None, None, None, None, None, None)
+        e_c1, e_c2, e_c3, e_c4 = channel_weights
+
+        e1, _, ind1 = self.encode1(inpt)
+
+        e2, _, ind2 = self.encode2(e1)
+
+        e3, _, ind3 = self.encode3(e2)
+
+        e4, out4, ind4 = self.encode4(e3)
+
+        bn = self.bottleneck(e4)
+
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class Combining Conditioner and Segmentor for few shot learning
+    '''
+
+    def __init__(self, params):
+        super(FewShotSegmentorDoubleSDnet, self).__init__()
+        self.conditioner = SDnetConditioner(params)
+        self.segmentor = SDnetSegmentor(params)
+
+    def forward(self, input1, input2):
+        weights = self.conditioner(input1)
+        segment = 
self.segmentor(input2, weights) + return segment + + def enable_test_dropout(self): + attr_dict = self.__dict__['_modules'] + for i in range(1, 5): + encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)] + encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train) + decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train) + + @property + def is_cuda(self): + """ + Check if model parameters are allocated on the GPU. + """ + return next(self.parameters()).is_cuda + + def save(self, path): + """ + Save model with its parameters to the given path. Conventionally the + path should end with "*.model". + + Inputs: + - path: path string + """ + print('Saving model... %s' % path) + torch.save(self, path) + + def predict(self, X, y, query_label, device=0, enable_dropout=False): + """ + Predicts the outout after the model is trained. + Inputs: + - X: Volume to be predicted + """ + self.eval() + input1, input2, y2 = split_batch(X, y, query_label) + input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device) + + if enable_dropout: + self.enable_test_dropout() + + with torch.no_grad(): + out = self.forward(input1, input2) + + # max_val, idx = torch.max(out, 1) + idx = out > 0.5 + idx = idx.data.cpu().numpy() + prediction = np.squeeze(idx) + del X, out, idx + return prediction + + +def to_cuda(X, device): + if type(X) is np.ndarray: + X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True) + elif type(X) is torch.Tensor and not X.is_cuda: + X = X.type(torch.FloatTensor).cuda(device, non_blocking=True) + return X + diff --git a/few_shot_segmentor_sne_position_decoder_type_spatial.py b/few_shot_segmentor_sne_position_decoder_type_spatial.py new file mode 100644 index 0000000..ad6e423 --- /dev/null +++ b/few_shot_segmentor_sne_position_decoder_type_spatial.py @@ -0,0 +1,251 @@ +"""Few-Shot_learning Segmentation""" + +import numpy as np +import torch +import torch.nn as nn +from nn_common_modules import modules as sm +from data_utils import split_batch +# import torch.nn.functional as F +from squeeze_and_excitation import squeeze_and_excitation as se + + +class SDnetConditioner(nn.Module): + """ + A conditional branch of few shot learning regressing the parameters for the segmentor + """ + + def __init__(self, params): + super(SDnetConditioner, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 2 + params['num_filters'] = 16 + self.encode1 = sm.SDnetEncoderBlock(params) + + params['num_channels'] = 16 + self.encode2 = sm.SDnetEncoderBlock(params) + + self.encode3 = sm.SDnetEncoderBlock(params) + + self.encode4 = sm.SDnetEncoderBlock(params) + + self.bottleneck = sm.GenericBlock(params) + params['num_channels'] = 16 + self.decode1 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode2 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode3 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + self.decode4 = sm.SDnetDecoderBlock(params) + self.squeeze_conv_d4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1, + kernel_size=(1, 1), + padding=(0, 0), + stride=1) + 
params['num_channels'] = 16
+        self.classifier = sm.ClassifierBlock(params)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input):
+
+        e1, _, ind1 = self.encode1(input)
+
+        e2, out2, ind2 = self.encode2(e1)
+
+        e3, _, ind3 = self.encode3(e2)
+
+        e4, _, ind4 = self.encode4(e3)
+
+        bn = self.bottleneck(e4)
+
+        d4 = self.decode4(bn, None, ind4)
+        d_w4 = self.sigmoid(self.squeeze_conv_d4(d4))
+        d3 = self.decode3(d4, None, ind3)
+        d_w3 = self.sigmoid(self.squeeze_conv_d3(d3))
+        d2 = self.decode2(d3, None, ind2)
+        d_w2 = self.sigmoid(self.squeeze_conv_d2(d2))
+        d1 = self.decode1(d2, None, ind1)
+        d_w1 = self.sigmoid(self.squeeze_conv_d1(d1))
+
+        space_weights = (None, None, None, None, None, d_w4, d_w3, d_w2, d_w1, None)
+        channel_weights = (None, None, None, None)
+
+        return space_weights, channel_weights
+
+
+class SDnetSegmentor(nn.Module):
+    """
+    Segmentor Code
+
+    param ={
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':1
+        'se_block': True,
+        'drop_out':0
+    }
+
+    """
+
+    def __init__(self, params):
+        super(SDnetSegmentor, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 1
+        params['num_filters'] = 64
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        params['num_channels'] = 64
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.bottleneck = sm.GenericBlock(params)
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 128
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 64
+        self.classifier = sm.ClassifierBlock(params)
+        self.soft_max = nn.Softmax2d()
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inpt, weights=None):
+        space_weights, channel_weights = weights
+        # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else (
+        #     None, None, None, None, None, None, None, None)
+
+        e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else (
+            None, None, None, None, None, None, None, None, None, None)
+        e_c1, e_c2, d_c1, d_c2 = channel_weights
+        # if weights is not None:
+        #     bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50
+
+        e1, _, ind1 = self.encode1(inpt)
+        if e_w1 is not None:
+            e1 = torch.mul(e1, e_w1)
+
+        e2, _, ind2 = self.encode2(e1)
+        if e_w2 is not None:
+            e2 = torch.mul(e2, e_w2)
+
+        e3, _, ind3 = self.encode3(e2)
+        if e_w3 is not None:
+            e3 = torch.mul(e3, e_w3)
+
+        e4, out4, ind4 = self.encode4(e3)
+        if e_w4 is not None:
+            e4 = torch.mul(e4, e_w4)
+
+        bn = self.bottleneck(e4)
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class Combining Conditioner and Segmentor for few shot learning
+    '''
+
+    def 
+    def __init__(self, params):
+        super(FewShotSegmentorDoubleSDnet, self).__init__()
+        self.conditioner = SDnetConditioner(params)
+        self.segmentor = SDnetSegmentor(params)
+
+    def forward(self, input1, input2):
+        weights = self.conditioner(input1)
+        segment = self.segmentor(input2, weights)
+        return segment
+
+    def enable_test_dropout(self):
+        attr_dict = self.segmentor.__dict__['_modules']  # the encode/decode blocks live on the segmentor sub-network
+        for i in range(1, 5):
+            encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
+            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
+            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)
+
+    @property
+    def is_cuda(self):
+        """
+        Check if model parameters are allocated on the GPU.
+        """
+        return next(self.parameters()).is_cuda
+
+    def save(self, path):
+        """
+        Save model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        torch.save(self, path)
+
+    def predict(self, X, y, query_label, device=0, enable_dropout=False):
+        """
+        Predicts the output after the model is trained.
+        Inputs:
+        - X: Volume to be predicted
+        """
+        self.eval()
+        input1, input2, y2 = split_batch(X, y, query_label)
+        input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device)
+
+        if enable_dropout:
+            self.enable_test_dropout()
+
+        with torch.no_grad():
+            out = self.forward(input1, input2)
+
+        # max_val, idx = torch.max(out, 1)
+        idx = out > 0.5
+        idx = idx.data.cpu().numpy()
+        prediction = np.squeeze(idx)
+        del X, out, idx
+        return prediction
+
+
+def to_cuda(X, device):
+    if type(X) is np.ndarray:
+        X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
+    elif type(X) is torch.Tensor and not X.is_cuda:
+        X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
+    return X
+
diff --git a/few_shot_segmentor_sne_position_encoder_type_channel.py b/few_shot_segmentor_sne_position_encoder_type_channel.py
new file mode 100644
index 0000000..66ca197
--- /dev/null
+++ b/few_shot_segmentor_sne_position_encoder_type_channel.py
@@ -0,0 +1,287 @@
+"""Few-Shot_learning Segmentation"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from nn_common_modules import modules as sm
+from data_utils import split_batch
+# import torch.nn.functional as F
+from squeeze_and_excitation import squeeze_and_excitation as se
+
+
+class SDnetConditioner(nn.Module):
+    """
+    A conditioner branch for few-shot learning that regresses the parameters for the segmentor
+    """
+
+    def __init__(self, params):
+        super(SDnetConditioner, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 2
+        params['num_filters'] = 16
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        self.channel_conv_e1 = nn.Linear(params['num_filters'], 64, bias=True)
+
+        params['num_channels'] = 16
+
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.channel_conv_e2 = nn.Linear(params['num_filters'], 64, bias=True)
+
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.channel_conv_e3 = nn.Linear(params['num_filters'], 64, bias=True)
+
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.channel_conv_e4 = nn.Linear(params['num_filters'], 64, bias=True)
+
+        self.bottleneck = sm.GenericBlock(params)
+
+        params['num_channels'] = 16
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+
+        self.decode2 = sm.SDnetDecoderBlock(params)
+
+        self.decode3 = sm.SDnetDecoderBlock(params)
+
+        self.decode4 = sm.SDnetDecoderBlock(params)
+
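+        # Illustrative note (a sketch, not from the original sources): each
+        # channel_conv_eN is the channel-squeeze counterpart of the spatial
+        # variant. The conditioner's 16-channel encoder output is globally
+        # average pooled and mapped to 64 sigmoid gates, one per segmentor
+        # channel, as done in forward() below:
+        #     pooled = out1.view(num_batch, 16, -1).mean(dim=2)      # (N, 16)
+        #     e_c1 = self.sigmoid(self.channel_conv_e1(pooled))      # (N, 64)
+        #     e_c1 = e_c1.view(num_batch, 64, 1, 1)                  # broadcast over H, W
+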
params['num_channels'] = 16 + + self.classifier = sm.ClassifierBlock(params) + self.sigmoid = nn.Sigmoid() + + def forward(self, input): + e1, out1, ind1 = self.encode1(input) + num_batch, ch, _, _ = out1.size() + e_c1 = self.sigmoid(self.channel_conv_e1(out1.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c1.size() + e_c1 = e_c1.view(num_batch, ch, 1, 1) + + e2, out2, ind2 = self.encode2(e1) + num_batch, ch, _, _ = out2.size() + e_c2 = self.sigmoid(self.channel_conv_e2(out2.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c2.size() + e_c2 = e_c2.view(num_batch, ch, 1, 1) + + e3, out3, ind3 = self.encode3(e2) + num_batch, ch, _, _ = out3.size() + e_c3 = self.sigmoid(self.channel_conv_e3(out3.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c3.size() + e_c3 = e_c3.view(num_batch, ch, 1, 1) + + e4, out4, ind4 = self.encode4(e3) + num_batch, ch, _, _ = out4.size() + e_c4 = self.sigmoid(self.channel_conv_e4(out4.view(num_batch, ch, -1).mean(dim=2))) + num_batch, ch = e_c4.size() + e_c4 = e_c4.view(num_batch, ch, 1, 1) + + bn = self.bottleneck(e4) + + d4 = self.decode4(bn, None, ind4) + d3 = self.decode3(d4, None, ind3) + d2 = self.decode2(d3, None, ind2) + d1 = self.decode1(d2, None, ind1) + + space_weights = (None, None, None, None, None, None, None, None, None, None) + channel_weights = (e_c1, e_c2, e_c3, e_c4) + + return space_weights, channel_weights + + +class SDnetSegmentor(nn.Module): + """ + Segmentor Code + + param ={ + 'num_channels':1, + 'num_filters':64, + 'kernel_h':5, + 'kernel_w':5, + 'stride_conv':1, + 'pool':2, + 'stride_pool':2, + 'num_classes':1 + 'se_block': True, + 'drop_out':0 + } + + """ + + def __init__(self, params): + super(SDnetSegmentor, self).__init__() + se_block_type = se.SELayer.SSE + params['num_channels'] = 1 + params['num_filters'] = 64 + self.encode1 = sm.SDnetEncoderBlock(params) + params['num_channels'] = 64 + self.encode2 = sm.SDnetEncoderBlock(params) + self.encode3 = sm.SDnetEncoderBlock(params) + self.encode4 = sm.SDnetEncoderBlock(params) + self.bottleneck = sm.GenericBlock(params) + + self.decode1 = sm.SDnetDecoderBlock(params) + self.decode2 = sm.SDnetDecoderBlock(params) + self.decode3 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 128 + self.decode4 = sm.SDnetDecoderBlock(params) + params['num_channels'] = 64 + self.classifier = sm.ClassifierBlock(params) + self.soft_max = nn.Softmax2d() + # self.sigmoid = nn.Sigmoid() + + def forward(self, inpt, weights=None): + space_weights, channel_weights = weights + # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else ( + # None, None, None, None, None, None, None, None) + + e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else ( + None, None, None, None, None, None, None, None, None, None) + e_c1, e_c2, e_c3, e_c4 = channel_weights + # if weights is not None: + # bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50 + + # e1, out1, ind1 = self.encode1(inpt) + # if e_w1 is not None: + # e1 = torch.mul(e1, e_w1) + # e2, out2, ind2 = self.encode2(e1) + # if e_w2 is not None: + # e2 = torch.mul(e2, e_w2) + # e3, out3, ind3 = self.encode3(e2) + # if e_w3 is not None: + # e3 = torch.mul(e3, e_w3) + # + # e4, out4, ind4 = self.encode4(e3) + # if e_w4 is not None: + # e4 = torch.mul(e4, e_w4) + # + # bn = self.bottleneck(e4) + # if bn_w is not None: + # bn = torch.mul(bn, bn_w) + # + # d4 = self.decode4(bn, out4, ind4) + # if d_w4 is not 
None:
+        #     d4 = torch.mul(d4, d_w4)
+        #
+        # d3 = self.decode1(d4, out3, ind3)
+        # if d_w3 is not None:
+        #     d3 = torch.mul(d3, d_w3)
+        #
+        # d2 = self.decode2(d3, out2, ind2)
+        # if d_w2 is not None:
+        #     d2 = torch.mul(d2, d_w2)
+        #
+        # d1 = self.decode3(d2, out1, ind1)
+        # if d_w1 is not None:
+        #     d1 = torch.mul(d1, d_w1)
+
+        e1, _, ind1 = self.encode1(inpt)
+        e1 = torch.mul(e1, e_c1)
+        e2, _, ind2 = self.encode2(e1)
+        e2 = torch.mul(e2, e_c2)
+
+        e3, _, ind3 = self.encode3(e2)
+        e3 = torch.mul(e3, e_c3)
+
+        e4, out4, ind4 = self.encode4(e3)
+        e4 = torch.mul(e4, e_c4)
+
+        bn = self.bottleneck(e4)
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class combining the conditioner and the segmentor for few-shot learning
+    '''
+
+    def __init__(self, params):
+        super(FewShotSegmentorDoubleSDnet, self).__init__()
+        self.conditioner = SDnetConditioner(params)
+        self.segmentor = SDnetSegmentor(params)
+
+    def forward(self, input1, input2):
+        weights = self.conditioner(input1)
+        segment = self.segmentor(input2, weights)
+        return segment
+
+    def enable_test_dropout(self):
+        attr_dict = self.segmentor.__dict__['_modules']  # the encode/decode blocks live on the segmentor sub-network
+        for i in range(1, 5):
+            encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
+            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
+            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)
+
+    @property
+    def is_cuda(self):
+        """
+        Check if model parameters are allocated on the GPU.
+        """
+        return next(self.parameters()).is_cuda
+
+    def save(self, path):
+        """
+        Save model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        torch.save(self, path)
+
+    def predict(self, X, y, query_label, device=0, enable_dropout=False):
+        """
+        Predicts the output after the model is trained.
+        Inputs:
+        - X: Volume to be predicted
+        """
+        self.eval()
+        input1, input2, y2 = split_batch(X, y, query_label)
+        input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device)
+
+        if enable_dropout:
+            self.enable_test_dropout()
+
+        with torch.no_grad():
+            out = self.forward(input1, input2)
+
+        # max_val, idx = torch.max(out, 1)
+        idx = out > 0.5
+        idx = idx.data.cpu().numpy()
+        prediction = np.squeeze(idx)
+        del X, out, idx
+        return prediction
+
+
+def to_cuda(X, device):
+    if type(X) is np.ndarray:
+        X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
+    elif type(X) is torch.Tensor and not X.is_cuda:
+        X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
+    return X
+
diff --git a/few_shot_segmentor_sne_position_encoder_type_spatial.py b/few_shot_segmentor_sne_position_encoder_type_spatial.py
new file mode 100644
index 0000000..17b092f
--- /dev/null
+++ b/few_shot_segmentor_sne_position_encoder_type_spatial.py
@@ -0,0 +1,247 @@
+"""Few-Shot_learning Segmentation"""
+
+import numpy as np
+import torch
+import torch.nn as nn
+from nn_common_modules import modules as sm
+from data_utils import split_batch
+# import torch.nn.functional as F
+from squeeze_and_excitation import squeeze_and_excitation as se
+
+
+class SDnetConditioner(nn.Module):
+    """
+    A conditioner branch for few-shot learning that regresses the parameters for the segmentor
+    """
+
+    def __init__(self, params):
+        super(SDnetConditioner, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 2
+        params['num_filters'] = 16
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        self.squeeze_conv_e1 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
+                                         kernel_size=(1, 1),
+                                         padding=(0, 0),
+                                         stride=1)
+        params['num_channels'] = 16
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.squeeze_conv_e2 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
+                                         kernel_size=(1, 1),
+                                         padding=(0, 0),
+                                         stride=1)
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.squeeze_conv_e3 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
+                                         kernel_size=(1, 1),
+                                         padding=(0, 0),
+                                         stride=1)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.squeeze_conv_e4 = nn.Conv2d(in_channels=params['num_filters'], out_channels=1,
+                                         kernel_size=(1, 1),
+                                         padding=(0, 0),
+                                         stride=1)
+        self.bottleneck = sm.GenericBlock(params)
+        params['num_channels'] = 16
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 16
+        self.classifier = sm.ClassifierBlock(params)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, input):
+
+        e1, _, ind1 = self.encode1(input)
+        e_w1 = self.sigmoid(self.squeeze_conv_e1(e1))
+        e2, out2, ind2 = self.encode2(e1)
+        e_w2 = self.sigmoid(self.squeeze_conv_e2(e2))
+        e3, _, ind3 = self.encode3(e2)
+        e_w3 = self.sigmoid(self.squeeze_conv_e3(e3))
+        e4, _, ind4 = self.encode4(e3)
+        e_w4 = self.sigmoid(self.squeeze_conv_e4(e4))
+
+        bn = self.bottleneck(e4)
+
+        d4 = self.decode4(bn, None, ind4)
+
+        d3 = self.decode3(d4, None, ind3)
+
+        d2 = self.decode2(d3, None, ind2)
+        num_batch, ch, _, _ = d2.size()
+        d1 = self.decode1(d2, None, ind1)
+
+        space_weights = (e_w1, e_w2, e_w3, e_w4, None, None, None, None, None, None)
+        channel_weights = (None, None, None, None)
+
+        return space_weights, channel_weights
+
+
+class SDnetSegmentor(nn.Module):
+    """
+    Segmentor Code
+
+    param ={
+        'num_channels':1,
+        'num_filters':64,
+        'kernel_h':5,
+        'kernel_w':5,
+        'stride_conv':1,
+        'pool':2,
+        'stride_pool':2,
+        'num_classes':1,
+        'se_block': True,
+        'drop_out':0
+    }
+
+    """
+
+    def __init__(self, params):
+        super(SDnetSegmentor, self).__init__()
+        se_block_type = se.SELayer.SSE
+        params['num_channels'] = 1
+        params['num_filters'] = 64
+        self.encode1 = sm.SDnetEncoderBlock(params)
+        params['num_channels'] = 64
+        self.encode2 = sm.SDnetEncoderBlock(params)
+        self.encode3 = sm.SDnetEncoderBlock(params)
+        self.encode4 = sm.SDnetEncoderBlock(params)
+        self.bottleneck = sm.GenericBlock(params)
+
+        self.decode1 = sm.SDnetDecoderBlock(params)
+        self.decode2 = sm.SDnetDecoderBlock(params)
+        self.decode3 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 128
+        self.decode4 = sm.SDnetDecoderBlock(params)
+        params['num_channels'] = 64
+        self.classifier = sm.ClassifierBlock(params)
+        self.soft_max = nn.Softmax2d()
+        # self.sigmoid = nn.Sigmoid()
+
+    def forward(self, inpt, weights=None):
+        space_weights, channel_weights = weights
+        # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights if weights is not None else (
+        #     None, None, None, None, None, None, None, None)
+
+        e_w1, e_w2, e_w3, e_w4, bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = space_weights if space_weights is not None else (
+            None, None, None, None, None, None, None, None, None, None)
+        e_c1, e_c2, d_c1, d_c2 = channel_weights
+        # if weights is not None:
+        #     bn_w, d_w4, d_w3, d_w2, d_w1, cls_w = bn_w * 50, d_w4 * 50, d_w3 * 50, d_w2 * 50, d_w1 * 50, cls_w * 50
+
+        e1, _, ind1 = self.encode1(inpt)
+        if e_w1 is not None:
+            e1 = torch.mul(e1, e_w1)
+
+        e2, _, ind2 = self.encode2(e1)
+        if e_w2 is not None:
+            e2 = torch.mul(e2, e_w2)
+
+        e3, _, ind3 = self.encode3(e2)
+        if e_w3 is not None:
+            e3 = torch.mul(e3, e_w3)
+
+        e4, out4, ind4 = self.encode4(e3)
+        if e_w4 is not None:
+            e4 = torch.mul(e4, e_w4)
+
+        bn = self.bottleneck(e4)
+        if bn_w is not None:
+            bn = torch.mul(bn, bn_w)
+
+        d4 = self.decode4(bn, out4, ind4)
+        if d_w4 is not None:
+            d4 = torch.mul(d4, d_w4)
+
+        d3 = self.decode3(d4, None, ind3)
+        if d_w3 is not None:
+            d3 = torch.mul(d3, d_w3)
+
+        d2 = self.decode2(d3, None, ind2)
+        if d_w2 is not None:
+            d2 = torch.mul(d2, d_w2)
+
+        d1 = self.decode1(d2, None, ind1)
+        if d_w1 is not None:
+            d1 = torch.mul(d1, d_w1)
+
+        # d1_1 = torch.cat((d1, inpt), dim=1)
+        logit = self.classifier.forward(d1)
+        if cls_w is not None:
+            logit = torch.mul(logit, cls_w)
+        logit = self.soft_max(logit)
+
+        return logit
+
+
+class FewShotSegmentorDoubleSDnet(nn.Module):
+    '''
+    Class combining the conditioner and the segmentor for few-shot learning
+    '''
+
+    def __init__(self, params):
+        super(FewShotSegmentorDoubleSDnet, self).__init__()
+        self.conditioner = SDnetConditioner(params)
+        self.segmentor = SDnetSegmentor(params)
+
+    def forward(self, input1, input2):
+        weights = self.conditioner(input1)
+        segment = self.segmentor(input2, weights)
+        return segment
+
+    def enable_test_dropout(self):
+        attr_dict = self.segmentor.__dict__['_modules']  # the encode/decode blocks live on the segmentor sub-network
+        for i in range(1, 5):
+            encode_block, decode_block = attr_dict['encode' + str(i)], attr_dict['decode' + str(i)]
+            encode_block.drop_out = encode_block.drop_out.apply(nn.Module.train)
+            decode_block.drop_out = decode_block.drop_out.apply(nn.Module.train)
+
+    @property
+    def is_cuda(self):
+        """
+        Check if model parameters are allocated on the GPU.
+        """
+        return next(self.parameters()).is_cuda
+
+    def save(self, path):
+        """
+        Save model with its parameters to the given path. Conventionally the
+        path should end with "*.model".
+
+        Inputs:
+        - path: path string
+        """
+        print('Saving model... %s' % path)
+        torch.save(self, path)
+
+    def predict(self, X, y, query_label, device=0, enable_dropout=False):
+        """
+        Predicts the output after the model is trained.
+        Inputs:
+        - X: Volume to be predicted
+        """
+        self.eval()
+        input1, input2, y2 = split_batch(X, y, query_label)
+        input1, input2, y2 = to_cuda(input1, device), to_cuda(input2, device), to_cuda(y2, device)
+
+        if enable_dropout:
+            self.enable_test_dropout()
+
+        with torch.no_grad():
+            out = self.forward(input1, input2)
+
+        # max_val, idx = torch.max(out, 1)
+        idx = out > 0.5
+        idx = idx.data.cpu().numpy()
+        prediction = np.squeeze(idx)
+        del X, out, idx
+        return prediction
+
+
+def to_cuda(X, device):
+    if type(X) is np.ndarray:
+        X = torch.tensor(X, requires_grad=False).type(torch.FloatTensor).cuda(device, non_blocking=True)
+    elif type(X) is torch.Tensor and not X.is_cuda:
+        X = X.type(torch.FloatTensor).cuda(device, non_blocking=True)
+    return X
+
diff --git a/run_oneshot.py b/run_oneshot.py
index 4ff3154..85125ab 100644
--- a/run_oneshot.py
+++ b/run_oneshot.py
@@ -16,6 +16,7 @@
 import few_shot_segmentor_model9 as fs9
 import few_shot_segmentor_model10 as fs10
 import few_shot_segmentor_model11 as fs11
+import few_shot_segmentor_sne_position_all_type_both as fs
 from settings import Settings
 # from solver_oneshot_singleOpti import Solver
 from solver_oneshot_multiOpti_auto import Solver
@@ -45,9 +46,8 @@ def forward(self, x):
 
 def train(train_params, common_params, data_params, net_params):
     train_data, test_data = load_data(data_params)
-
-    folds = ['fold1']
-    model_prefix = 'model6_wholebody_condch16_e4Skip_inter_e3e4bnd4d3d2_allSseSeg_DiceLoss_'
+    model_prefix = 'sne_position_all_type_both_try2_'
+    folds = ['fold2', 'fold3', 'fold4']
 
     for fold in folds:
         final_model_path = os.path.join(common_params['save_model_dir'], model_prefix + fold + '.pth.tar')
@@ -69,7 +69,7 @@ def train(train_params, common_params, data_params, net_params):
         # for param in segmentor_pretrained.parameters():
         #     param.requires_grad = False
 
-        few_shot_model = fs6.FewShotSegmentorDoubleSDnet(net_params)
+        few_shot_model = fs.FewShotSegmentorDoubleSDnet(net_params)
 
         # few_shot_model = segmentor_pretrained
         # few_shot_model.conditioner = conditioner_pretrained
@@ -122,31 +122,31 @@ def evaluate(eval_params, net_params, data_params, common_params, train_params):
 
     # model_name = 'model6_Dice_L2_loss_target_fold1.pth.tar'
     folds = ['fold1']
-    eval_model_path1 = "saved_models/model6_wholebody_condch16_e4Skip_inter_e3e4bnd4d3d2_allSseSeg_DiceLoss_fold1.pth.tar"
-    eval_model_path2 = "saved_models/model6_coronal_wholebody_condch16_e4Skip_inter_e3e4bnd4d3d2_allSseSeg_DiceLoss_fold1.pth.tar"
+    eval_model_path1 = "saved_models/sne_position_all_type_both_try2_fold1.pth.tar"
+    eval_model_path2 = "saved_models/model6_coronal_wholebody_condch16_e4Skip_inter_e3e4bnd4d3_ch_e1e2d1d2_noSseSeg_DiceLoss_lowrate_fold2.pth.tar"
     # eval_model_path3 = "saved_models/model6_sagittal_fold1.pth.tar"
     orientaion1 = 'AXI'
-    # orientaion2 = 'COR'
+    orientaion2 = 'COR'
 
     for fold in folds:
-        # eval_model_path = os.path.join('saved_models', model_name + '_' + fold + '.pth.tar')
+        # eval_model_path = os.path.join('saved_models', eval_model_path1 + '_' + fold + '.pth.tar')
        query_labels = get_lab_list('val', fold)
        num_classes = len(fold)
-        #avg_dice_score = eu.evaluate_dice_score(eval_model_path1,
-        #                                        num_classes,
-        #                                        query_labels,
-        #                                        data_dir,
-        #                                        query_txt_file,
-        #                                        
support_txt_file, - # remap_config, - # orientaion1, - # prediction_path, - # device, - # logWriter, fold=fold) - - #avg_dice_score = eu.evaluate_dice_score_3view(eval_model_path1, + avg_dice_score = eu.evaluate_dice_score(eval_model_path1, + num_classes, + query_labels, + data_dir, + query_txt_file, + support_txt_file, + remap_config, + orientaion1, + prediction_path, + device, + logWriter, fold=fold) + + # avg_dice_score = eu.evaluate_dice_score_3view(eval_model_path1, # eval_model_path2, # eval_model_path3, # num_classes, @@ -159,18 +159,18 @@ def evaluate(eval_params, net_params, data_params, common_params, train_params): # prediction_path, # device, # logWriter, fold=fold) - avg_dice_score = eu.evaluate_dice_score_2view(eval_model_path1, - eval_model_path2, - num_classes, - query_labels, - data_dir, - query_txt_file, - support_txt_file, - remap_config, - orientaion1, - prediction_path, - device, - logWriter, fold=fold) + # avg_dice_score = eu.evaluate_dice_score_2view(eval_model_path1, + # eval_model_path2, + # num_classes, + # query_labels, + # data_dir, + # query_txt_file, + # support_txt_file, + # remap_config, + # orientaion1, + # prediction_path, + # device, + # logWriter, fold=fold) logWriter.log(avg_dice_score) logWriter.close() @@ -180,12 +180,16 @@ def evaluate(eval_params, net_params, data_params, common_params, train_params): parser = argparse.ArgumentParser() parser.add_argument('--mode', '-m', required=True, help='run mode, valid values are train and eval') + parser.add_argument('--device', '-d', required=False, help='device to run on') args = parser.parse_args() settings = Settings() common_params, data_params, net_params, train_params, eval_params = settings['COMMON'], settings['DATA'], settings[ 'NETWORK'], settings['TRAINING'], settings['EVAL'] + if args.device is not None: + common_params['device'] = args.device + if args.mode == 'train': train(train_params, common_params, data_params, net_params) elif args.mode == 'eval': diff --git a/settings.ini b/settings.ini index 5f4a757..a883c73 100644 --- a/settings.ini +++ b/settings.ini @@ -33,15 +33,15 @@ se_block = "NONE" drop_out = 0 [TRAINING] -fold = 'fold2' -exp_name = "model6_wholebody_condch16_e4Skip_inter_e3e4bnd4d3d2_allSseSeg_DiceLoss_fold1" +fold = 'fold1' +exp_name = "sne_position_all_type_both_try2_fold1" final_model_file = "model6_wholebody_axial_fold1.pth.tar" learning_rate = 1e-1 momentum = 0.95 train_batch_size = 10 val_batch_size = 5 log_nth = 10 -num_epochs = 11 +num_epochs = 10 optim_betas = (0.9, 0.999) optim_eps = 1e-8 optim_weight_decay = 0.00001 @@ -49,13 +49,13 @@ lr_scheduler_step_size = 2 lr_scheduler_gamma = 0.5 iterations=100 test_iterations=100 -pre_trained_path = "saved_models/model6_Focal_loss_noClsLastDec_fold1.pth.tar" +pre_trained_path = "saved_models/sne_position_decoder_type_spatial_fold1.pth.tar" #Uses the last checkpoint file from the exp_dir_name folder use_last_checkpoint = True [EVAL] -eval_model_path = "saved_models/model6_Dice_L2_loss_target_fold1.pth.tar" +eval_model_path = "saved_models/sne_position_all_type_both_try2_fold1.pth.tar" data_dir = "/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral" label_dir = "/home/deeplearning/Abhijit/nas_drive/Abhijit/WholeBody/CT_ce/Data/Visceral" volumes_txt_file = "datasets/MALC/test_volumes.txt" @@ -65,4 +65,4 @@ support_txt_file = "datasets/eval_support.txt" remap_config = "WholeBody" #Valid options : COR, AXI, SAG orientation = "AXI" -save_predictions_dir = "predictions2view_new" \ No newline at end of file 
+save_predictions_dir = "predictions_slow_multivolsupport" \ No newline at end of file diff --git a/solver_oneshot_multiOpti_auto.py b/solver_oneshot_multiOpti_auto.py index 7171c30..b9b0d90 100644 --- a/solver_oneshot_multiOpti_auto.py +++ b/solver_oneshot_multiOpti_auto.py @@ -59,10 +59,10 @@ def __init__(self, # self.scheduler = lr_scheduler.StepLR(self.optim, step_size=5, # gamma=0.1) - self.scheduler_s = lr_scheduler.StepLR(self.optim_s, step_size=4, - gamma=0.5) - self.scheduler_c = lr_scheduler.StepLR(self.optim_c, step_size=4, + self.scheduler_s = lr_scheduler.StepLR(self.optim_s, step_size=10, gamma=0.1) + self.scheduler_c = lr_scheduler.StepLR(self.optim_c, step_size=10, + gamma=0.001) exp_dir_path = os.path.join(exp_dir, exp_name) common_utils.create_if_not(exp_dir_path) @@ -104,7 +104,7 @@ def train(self, train_loader, test_loader): self.logWriter.log('START TRAINING. : model name = %s, device = %s' % ( self.model_name, torch.cuda.get_device_name(self.device))) current_iteration = self.start_iteration - warm_up_epoch = 10 + warm_up_epoch = 15 val_old = 0 change_model = False current_model = 'seg' diff --git a/utils/convert_h5.py b/utils/convert_h5.py index 24e6406..bf16f80 100644 --- a/utils/convert_h5.py +++ b/utils/convert_h5.py @@ -49,8 +49,8 @@ def convert_h5(data_dir, label_dir, data_split, train_volumes, test_volumes, f, if data_split: train_file_paths, test_file_paths = apply_split(data_split, data_dir, label_dir) elif train_volumes and test_volumes: - train_file_paths = du.load_file_paths(data_dir, label_dir, train_volumes) - test_file_paths = du.load_file_paths(data_dir, label_dir, test_volumes) + train_file_paths = du.load_file_paths_brain(data_dir, label_dir, train_volumes) + test_file_paths = du.load_file_paths_brain(data_dir, label_dir, test_volumes) else: raise ValueError('You must either provide the split ratio or a train, train dataset list') diff --git a/utils/data_utils.py b/utils/data_utils.py index e994664..9e39115 100644 --- a/utils/data_utils.py +++ b/utils/data_utils.py @@ -6,6 +6,7 @@ import torch.utils.data as data import scipy.io as sio import preprocessor as preprocessor +import nibabel as nb import math from torchvision import transforms @@ -32,8 +33,6 @@ def __init__(self, X, y, w, transforms=None): self.w = w self.transforms = transforms - - def __getitem__(self, index): img = torch.from_numpy(self.X[index]) label = torch.from_numpy(self.y[index]) @@ -103,6 +102,15 @@ def load_and_preprocess(file_path, orientation, remap_config, reduce_slices=Fals return volume, labelmap, class_weights, weights +def load_data(file_path, orientation): + print(file_path[0], file_path[1]) + volume_nifty, labelmap_nifty = nb.load(file_path[0]), nb.load(file_path[1]) + volume, labelmap = volume_nifty.get_fdata(), labelmap_nifty.get_fdata() + volume = (volume - np.min(volume)) / (np.max(volume) - np.min(volume)) + volume, labelmap = preprocessor.rotate_orientation(volume, labelmap, orientation) + return volume, labelmap, volume_nifty.header + + def load_data_mat(file_path, orientation): data = sio.loadmat(file_path) volume = data['DatVol'] @@ -128,6 +136,30 @@ def preprocess(volume, labelmap, remap_config, reduce_slices=False, remove_black return volume, labelmap, None, None +def load_file_paths_brain(data_dir, label_dir, volumes_txt_file=None): + """ + This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label. 
+    It should be modified to suit the needs of the project
+    :param data_dir: Directory which contains the data files
+    :param label_dir: Directory which contains the label files
+    :param volumes_txt_file: (Optional) Path to a csv file; when provided, only these data points will be read
+    :return: list of file paths as string
+    """
+
+    volume_exclude_list = ['IXI290', 'IXI423']
+    if volumes_txt_file:
+        with open(volumes_txt_file) as file_handle:
+            volumes_to_use = file_handle.read().splitlines()
+    else:
+        volumes_to_use = [name for name in os.listdir(data_dir) if name not in volume_exclude_list]
+
+    file_paths = [
+        [os.path.join(data_dir, vol, 'mri/orig.mgz'), os.path.join(label_dir, vol + '_glm.mgz')]
+        for
+        vol in volumes_to_use]
+    return file_paths
+
+
 def load_file_paths(data_dir, label_dir, volumes_txt_file=None):
     """
     This function returns the file paths combined as a list where each element is a 2 element tuple, 0th being data and 1st being label.
diff --git a/utils/evaluator.py b/utils/evaluator.py
index 2504385..3cd8459 100644
--- a/utils/evaluator.py
+++ b/utils/evaluator.py
@@ -38,6 +38,15 @@ def dice_confusion_matrix(vol_output, ground_truth, num_classes, no_samples=10,
     return avg_dice, dice_cm
 
 
+def get_range(volume):
+    batch, _, _ = volume.size()
+    slice_with_class = torch.sum(volume.view(batch, -1), dim=1) > 10  # slices with more than 10 labelled pixels
+    index = slice_with_class[:-1] - slice_with_class[1:] > 0  # change points of the indicator bound the labelled slab
+    seq = torch.Tensor(range(batch - 1))
+    range_index = seq[index].type(torch.LongTensor)
+    return range_index
+
+
 def dice_score_perclass(vol_output, ground_truth, num_classes, no_samples=10, mode='train'):
     dice_perclass = torch.zeros(num_classes)
     if mode == 'train':
@@ -54,8 +63,15 @@ def dice_score_perclass(vol_output, ground_truth, num_classes, no_samples=10, mo
 
 def binarize_label(volume, groud_truth, class_label):
     groud_truth = (groud_truth == class_label).type(torch.FloatTensor)
-    condition_input = torch.mul(volume, groud_truth.unsqueeze(1))
-    return condition_input
+    batch, _, _ = groud_truth.size()
+    slice_with_class = torch.sum(groud_truth.view(batch, -1), dim=1) > 10
+    index = slice_with_class[:-1] - slice_with_class[1:] > 0
+    seq = torch.Tensor(range(batch - 1))
+    range_index = seq[index].type(torch.LongTensor)
+    groud_truth = groud_truth[slice_with_class]
+    volume = volume[slice_with_class]
+    condition_input = torch.cat((volume, groud_truth.unsqueeze(1)), dim=1)  # image plus binary mask as a 2-channel support input
+    return condition_input, range_index.cpu().numpy()
 
 
 def evaluate_dice_score(model_path,
@@ -69,9 +85,8 @@ def evaluate_dice_score(model_path,
                         prediction_path,
                         device=0,
                         logWriter=None, mode='eval', fold=None):
    print("**Starting evaluation. 
Please check tensorboard for plots if a logWriter is provided in arguments**") print("Loading model => " + model_path) - batch_size = 10 - # Num_support = 6 - + batch_size = 20 + Num_support = 15 with open(query_txt_file) as file_handle: volumes_query = file_handle.read().splitlines() @@ -96,87 +111,212 @@ def evaluate_dice_score(model_path, all_query_dice_score_list = [] for query_label in query_labels: volume_dice_score_list = [] - for vol_idx, file_path in enumerate(support_file_paths): - # Loading support - support_volume, support_labelmap, _, _ = du.load_and_preprocess(file_path, - orientation=orientation, - remap_config=remap_config) - support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, - :] - support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), torch.tensor( - support_labelmap).type(torch.LongTensor) - support_volume = binarize_label(support_volume, support_labelmap, query_label) - # sz = support_volume.size() - # slice_gap = sz[0] // Num_support + # + # support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], + # orientation=orientation, + # remap_config=remap_config) + # + # support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, :] + # + # support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), torch.tensor( + # support_labelmap).type(torch.LongTensor) + # support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) + # support_volume = support_volume[range_index[0]: range_index[1]] + + # Loading support + support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], + orientation=orientation, + remap_config=remap_config) + support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, + :] + support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), \ + torch.tensor(support_labelmap).type(torch.LongTensor) + + support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) + + slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) + + support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] + + if len(support_slice_indexes) < Num_support: + support_slice_indexes.append(len(support_volume) - 1) for vol_idx, file_path in enumerate(query_file_paths): + query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path, orientation=orientation, remap_config=remap_config) query_volume = query_volume if len(query_volume.shape) == 4 else query_volume[:, np.newaxis, :, :] - query_volume, query_labelmap = torch.tensor(query_volume).type(torch.FloatTensor), torch.tensor( - query_labelmap).type(torch.LongTensor) + query_volume, query_labelmap = torch.tensor(query_volume).type(torch.FloatTensor), \ + torch.tensor(query_labelmap).type(torch.LongTensor) query_labelmap = query_labelmap == query_label - support_batch_x = [] - k = 2 + range_query = get_range(query_labelmap) + query_volume = query_volume[range_query[0]: range_query[1] + 1] + query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1] + + slice_gap_query = int(np.ceil((len(query_volume) / Num_support))) + + query_slice_indexes = [i for i in range(0, len(query_volume), slice_gap_query)] + if len(query_slice_indexes) < Num_support: + query_slice_indexes.append(len(query_volume) - 1) + volume_prediction = [] - for i in range(0, 
len(query_volume), batch_size): - query_batch_x = query_volume[i: i + batch_size] - if k % 2 == 0: - support_batch_x = support_volume[i: i + batch_size] - sz = query_batch_x.size() - support_batch_x = support_batch_x[batch_size - 1].repeat(sz[0], 1, 1, 1) - k += 1 + + # for i in range(0, len(query_volume), batch_size): + support_current_slice = 0 + query_current_slice = 0 + + for i, query_start_slice in enumerate(query_slice_indexes): + if query_start_slice == query_slice_indexes[-1]: + query_batch_x = query_volume[query_slice_indexes[i]:] + else: + query_batch_x = query_volume[query_slice_indexes[i]:query_slice_indexes[i + 1]] + + support_batch_x = support_volume[support_slice_indexes[i]] + + support_batch_x = support_batch_x.repeat(len(query_batch_x), 1, 1, 1) if cuda_available: query_batch_x = query_batch_x.cuda(device) support_batch_x = support_batch_x.cuda(device) weights = model.conditioner(support_batch_x) - # space_w, channel_w = weights - # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = space_w - # e_c1, e_c2, e_c3, bn_c, d_c3, d_c2, d_c1, cls_c = channel_w - # # e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights - # space_w = [e_w1, e_w2, None, None, None, d_w2, d_w1, cls_w] - # channel_w = [e_c1, e_c2, e_c3, bn_c, d_c3, d_c2, d_c1, cls_c] - # weights = (space_w, channel_w) - e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, d_w1, cls_w = weights - weights = [e_w1, e_w2, e_w3, bn_w, d_w3, d_w2, None, None] out = model.segmentor(query_batch_x, weights) _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) volume_prediction.append(batch_output) + query_current_slice += slice_gap_query + support_current_slice += slice_gap_support + + # query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path, orientation=orientation, + # remap_config=remap_config) + # query_labelmap = query_labelmap == query_label + # range_query = get_range(query_labelmap) + # query_volume = query_volume[range_query[0]: range_query[1]] + # + # query_volume = query_volume if len(query_volume.shape) == 4 else query_volume[:, np.newaxis, :, :] + # query_volume, query_labelmap = torch.tensor(query_volume).type(torch.FloatTensor), torch.tensor( + # query_labelmap).type(torch.LongTensor) + # + # support_batch_x = [] + # + # volume_prediction = [] + # + # support_current_slice = 0 + # query_current_slice = 0 + # support_slice_left = support_volume[range_index[0]] + + # for i in range(0, range_index[0], batch_size): + # end_index_query = query_current_slice + batch_size + # end_index_query = end_index_query if end_index_query < range_index[0] else range_index[0] + # + # query_batch_x = query_volume[i: end_index_query] + # + # support_batch_x = support_slice_left.repeat(query_batch_x.size()[0], 1, 1, 1) + # + # if cuda_available: + # query_batch_x = query_batch_x.cuda(device) + # support_batch_x = support_batch_x.cuda(device) + # + # weights = model.conditioner(support_batch_x) + # out = model.segmentor(query_batch_x, weights) + # + # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) + # volume_prediction.append(batch_output) + # query_current_slice = end_index_query + # support_current_slice = query_current_slice + # + # for i in range(range_index[0], range_index[1] + 1, batch_size): + # end_index_query = query_current_slice + batch_size + # end_index_query = end_index_query if end_index_query < range_index[1] + 1 else range_index[1] + 1 + # + # query_batch_x = query_volume[i: end_index_query] + # + # # end_index_support = support_current_slice + batch_size + # # end_index_support = end_index_support 
if end_index_support < len(range_index[1] + 1) else len( + # # range_index[1] + 1) + # # print(len(support_volume)) + # # print(support_current_slice, end_index_query) + # support_batch_x = support_volume[support_current_slice: end_index_query] + # + # query_current_slice = end_index_query + # support_current_slice = query_current_slice + # + # support_batch_x = support_batch_x[0].repeat(query_batch_x.size()[0], 1, 1, 1) + # + # # k += 1 + # if cuda_available: + # query_batch_x = query_batch_x.cuda(device) + # support_batch_x = support_batch_x.cuda(device) + # + # weights = model.conditioner(support_batch_x) + # out = model.segmentor(query_batch_x, weights) + # + # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) + # volume_prediction.append(batch_output) + # + # support_slice_right = support_volume[range_index[1]] + # for i in range(range_index[1] + 1, len(support_volume), batch_size): + # end_index_query = query_current_slice + batch_size + # end_index_query = end_index_query if end_index_query < len(support_volume) else len(support_volume) + # + # query_batch_x = query_volume[i: end_index_query] + # + # support_batch_x = support_slice_right.repeat(query_batch_x.size()[0], 1, 1, 1) + # + # if cuda_available: + # query_batch_x = query_batch_x.cuda(device) + # support_batch_x = support_batch_x.cuda(device) + # + # weights = model.conditioner(support_batch_x) + # out = model.segmentor(query_batch_x, weights) + # + # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) + # volume_prediction.append(batch_output) + # query_current_slice = end_index_query + # support_current_slice = query_current_slice volume_prediction = torch.cat(volume_prediction) - volume_dice_score = dice_score_binary(volume_prediction, query_labelmap.cuda(device), phase=mode) + + # batch, _, _ = query_labelmap.size() + # slice_with_class = torch.sum(query_labelmap.view(batch, -1), dim=1) > 10 + # index = slice_with_class[:-1] - slice_with_class[1:] > 0 + # seq = torch.Tensor(range(batch - 1)) + # range_index_gt = seq[index].type(torch.LongTensor) + + volume_dice_score = dice_score_binary(volume_prediction[:len(query_labelmap)], query_labelmap.cuda(device), phase=mode) volume_prediction = (volume_prediction.cpu().numpy()).astype('float32') nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4)) nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold + str('.mgz'))) + # + # # # Save Input + # # nifti_img = nib.MGHImage(np.squeeze(query_volume.cpu().numpy()), np.eye(4)) + # # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz'))) - # # Save Input - # nifti_img = nib.MGHImage(np.squeeze(query_volume.cpu().numpy()), np.eye(4)) - # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz'))) # # # Condition Input # nifti_img = nib.MGHImage(np.squeeze(support_volume.cpu().numpy()), np.eye(4)) # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz'))) - # # # Cond GT - # nifti_img = nib.MGHImage(np.squeeze(support_labelmap.cpu().numpy()).astype('float32'), np.eye(4)) - # nib.save(nifti_img, - # os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz'))) - # # # # Save Ground Truth - # nifti_img = nib.MGHImage(np.squeeze(query_labelmap.cpu().numpy()), np.eye(4)) - # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + str('.mgz'))) + # # Cond GT + nifti_img = 
nib.MGHImage(np.squeeze(support_labelmap.cpu().numpy()).astype('float32'), np.eye(4)) + nib.save(nifti_img, + os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz'))) + + # # # Save Ground Truth + nifti_img = nib.MGHImage(np.squeeze(query_labelmap.cpu().numpy()), np.eye(4)) + nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + fold + + str('.mgz'))) # if logWriter: # logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx], # vol_idx) - volume_dice_score = volume_dice_score.cpu().numpy() + volume_dice_score = volume_dice_score.item() volume_dice_score_list.append(volume_dice_score) print(volume_dice_score) + print(volume_dice_score_list) dice_score_arr = np.asarray(volume_dice_score_list) avg_dice_score = np.median(dice_score_arr) print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score)) @@ -520,7 +660,9 @@ def evaluate_dice_score_3view(model1_path, volume_prediction1 = torch.cat(volume_prediction1) volume_prediction2 = torch.cat(volume_prediction2) volume_prediction3 = torch.cat(volume_prediction3) - volume_prediction = 0.33 * F.softmax(volume_prediction1,dim=1) + 0.33 * F.softmax(volume_prediction2.permute(3, 1, 0, 2), dim=1) + 0.33 * F.softmax(volume_prediction3.permute(2, 1, 3, 0), dim=1) + volume_prediction = 0.33 * F.softmax(volume_prediction1, dim=1) + 0.33 * F.softmax( + volume_prediction2.permute(3, 1, 0, 2), dim=1) + 0.33 * F.softmax( + volume_prediction3.permute(2, 1, 3, 0), dim=1) _, batch_output = torch.max(volume_prediction, dim=1) volume_dice_score = dice_score_binary(batch_output, query_labelmap1.cuda(device), phase=mode) diff --git a/utils/evaluator_kshot.py b/utils/evaluator_kshot.py index 7e1866a..854702a 100644 --- a/utils/evaluator_kshot.py +++ b/utils/evaluator_kshot.py @@ -123,12 +123,21 @@ def evaluate_dice_score(model_path, support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) - # slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) + # # Save Input + nifti_img = nib.MGHImage(np.squeeze(support_volume[:, 0, :, :].cpu().numpy()), np.eye(4)) + nib.save(nifti_img, os.path.join(prediction_path, 'SupportInput_' + str('.mgz'))) - # support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] + nifti_img = nib.MGHImage(np.squeeze(support_volume[:, 1, :, :].cpu().numpy()), np.eye(4)) + nib.save(nifti_img, os.path.join(prediction_path, 'SupportGT_' + str('.mgz'))) - # if len(support_slice_indexes) < Num_support: - # support_slice_indexes.append(len(support_volume) - 1) + print("Saved") + + slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) + + support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] + + if len(support_slice_indexes) < Num_support: + support_slice_indexes.append(len(support_volume) - 1) for vol_idx, file_path in enumerate(query_file_paths): @@ -147,32 +156,50 @@ def evaluate_dice_score(model_path, dice_per_slice = [] vol_output = [] - - for i, query_slice in enumerate(query_volume): - query_batch_x = query_slice.unsqueeze(0) - max_dice = -1.0 - max_output = None - for j in range(0, len(support_volume), 5): - support_slice = support_volume[j] - - support_batch_x = support_slice.unsqueeze(0) + for support_slice_idx in support_slice_indexes: + batch_output = [] + for i in range(0, len(query_volume), batch_size): + query_batch_x = query_volume[i: i + batch_size] + support_batch_x = 
support_volume[support_slice_idx].repeat(query_batch_x.size()[0], 1, 1, 1) if cuda_available: query_batch_x = query_batch_x.cuda(device) support_batch_x = support_batch_x.cuda(device) - weights = model.conditioner(support_batch_x) out = model.segmentor(query_batch_x, weights) - _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) - slice_dice_score = dice_score_binary(batch_output, - query_labelmap[i].cuda(device), phase=mode) - if slice_dice_score.item() >= max_dice: - max_dice = slice_dice_score.item() - max_output = batch_output - # dice_per_slice.append(max_dice) - vol_output.append(max_output) - - vol_output = torch.cat(vol_output) + # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) + batch_output.append(out) + batch_output = torch.cat(batch_output) + vol_output.append(batch_output) + vol_output = torch.stack(vol_output) + vol_output = torch.mean(vol_output, dim=0) + _, vol_output = torch.max(F.softmax(vol_output, dim=1), dim=1) + + # for i, query_slice in enumerate(query_volume): + # query_batch_x = query_slice.unsqueeze(0) + # max_dice = -1.0 + # max_output = None + # for j in range(0, len(support_volume), 5): + # support_slice = support_volume[j] + # + # support_batch_x = support_slice.unsqueeze(0) + # if cuda_available: + # query_batch_x = query_batch_x.cuda(device) + # support_batch_x = support_batch_x.cuda(device) + # + # weights = model.conditioner(support_batch_x) + # out = model.segmentor(query_batch_x, weights) + # + # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) + # slice_dice_score = dice_score_binary(batch_output, + # query_labelmap[i].cuda(device), phase=mode) + # if slice_dice_score.item() >= max_dice: + # max_dice = slice_dice_score.item() + # max_output = batch_output + # # dice_per_slice.append(max_dice) + # vol_output.append(max_output) + # + # vol_output = torch.cat(vol_output) # volume_dice_score = np.mean(np.asarray(dice_per_slice)) volume_dice_score = dice_score_binary(vol_output, query_labelmap.cuda(device), phase=mode) volume_dice_score_list.append(volume_dice_score) diff --git a/utils/evaluator_multi_volume_support.py b/utils/evaluator_multi_volume_support.py index 53adc53..45e66fe 100644 --- a/utils/evaluator_multi_volume_support.py +++ b/utils/evaluator_multi_volume_support.py @@ -7,6 +7,7 @@ import utils.common_utils as common_utils import utils.data_utils as du import torch.nn.functional as F +import matplotlib.pyplot as plt import shot_batch_sampler as SB @@ -74,6 +75,73 @@ def binarize_label(volume, groud_truth, class_label): return condition_input, range_index.cpu().numpy() +def CV(samples_arr): + eps = 0.0001 + threshold = 0.0001 + sample_size = samples_arr[0].size() + total_pixels = sample_size[-1] * sample_size[-2] + + samples_arr = [sample.squeeze() for sample in samples_arr if + (sample.sum().item() / total_pixels) > threshold] + + if len(samples_arr) > 0: + samples_arr = torch.cat(samples_arr).float().squeeze() + std = torch.std(samples_arr).item() + mean = torch.mean(samples_arr).item() + eps + return std / mean + else: + return 1000000 + + +def IOU_Single(sample1, sample2): + eps = 0.0001 + sample1, sample2 = sample1.squeeze().byte(), sample2.squeeze().byte() + intersection = sample1 & sample2 + union = sample1 | sample2 + + numerator = intersection.sum().item() + denominator = union.sum().item() + eps + return numerator / denominator + + +def IoU(samples_arr, support, query, vol): + # Hardcoded for 3 elements as of now + eps = 0.0001 + threshold = 0.0001 + sample_size = samples_arr[0].size() + total_pixels = 
sample_size[-1] * sample_size[-2] + + samples_arr = [sample.squeeze().byte() for sample in samples_arr if + (sample.sum().item() / total_pixels) > threshold] + + # plt.imsave(str(vol) + '_' + str(query) + '_' + str(support) + '_' + 'fig1', sample1.cpu().numpy()) + # plt.imsave(str(vol) + '_' + str(query) + '_' + str(support) + '_' + 'fig2', sample2.cpu().numpy()) + # plt.imsave(str(vol) + '_' + str(query) + '_' + str(support) + '_' + 'fig3', sample3.cpu().numpy()) + if len(samples_arr) > 0: + intersection = torch.ones(samples_arr[0].size()).byte().cuda() + union = torch.zeros(samples_arr[0].size()).byte().cuda() + + for sample in samples_arr: + intersection = intersection & sample + union = union | sample + + numerator = intersection.sum().item() + denominator = union.sum().item() + eps + + return numerator / denominator + else: + return 0 + + # inter = torch.ones(samples_arr[0].size()).cuda() + # union = torch.zeros(samples_arr[0].size()).cuda() + # + # + # for sample in samples_arr: + # inter = torch.sum(torch.mul(inter, sample.type(torch.cuda.FloatTensor))) + # union = torch.sum(union) + torch.sum(sample.type(torch.cuda.FloatTensor)) - inter + # return -torch.div(inter, union + 0.0001).item() + + def evaluate_dice_score(model_path, num_classes, query_labels, @@ -86,7 +154,8 @@ def evaluate_dice_score(model_path, print("**Starting evaluation. Please check tensorboard for plots if a logWriter is provided in arguments**") print("Loading model => " + model_path) batch_size = 20 - Num_support = 10 + Num_support = 15 + MC_samples = 10 with open(query_txt_file) as file_handle: volumes_query = file_handle.read().splitlines() @@ -100,7 +169,6 @@ def evaluate_dice_score(model_path, model.cuda(device) model.eval() - common_utils.create_if_not(prediction_path) print("Evaluating now... 
" + fold) @@ -109,26 +177,33 @@ def evaluate_dice_score(model_path, with torch.no_grad(): all_query_dice_score_list = [] + for query_label in query_labels: volume_dice_score_list = [] - # Loading support - support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], - orientation=orientation, - remap_config=remap_config) - support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, - :] - support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), \ - torch.tensor(support_labelmap).type(torch.LongTensor) + support_slices = [] + + for i, file_path in enumerate(support_file_paths): + # Loading support + support_volume, support_labelmap, _, _ = du.load_and_preprocess(file_path, + orientation=orientation, + remap_config=remap_config) + + support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, + :] + support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), \ + torch.tensor(support_labelmap).type(torch.LongTensor) + + support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) + + slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) - support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) + support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] - # slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) - # - # support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] - # - # if len(support_slice_indexes) < Num_support: - # support_slice_indexes.append(len(support_volume) - 1) + if len(support_slice_indexes) < Num_support: + support_slice_indexes.append(len(support_volume) - 1) + + support_slices.extend([support_volume[idx] for idx in support_slice_indexes]) for vol_idx, file_path in enumerate(query_file_paths): @@ -145,42 +220,31 @@ def evaluate_dice_score(model_path, query_volume = query_volume[range_query[0]: range_query[1] + 1] query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1] - dice_per_slice = [] - vol_output = [] - - for i, query_slice in enumerate(query_volume): - query_batch_x = query_slice.unsqueeze(0) - max_dice = -1.0 - max_output = None - for j in range(0, len(support_volume), 10): - support_slice = support_volume[j] + slice_gap_query = int(np.ceil(len(query_volume) / Num_support)) + dice_per_batch = [] + batch_output_arr = [] + for support_slice_idx, i in enumerate(range(0, len(query_volume), slice_gap_query)): + query_batch_x = query_volume[i:i+slice_gap_query] + support_batch_x = support_volume[support_slice_idx].repeat(query_batch_x.size()[0], 1, 1, 1) - support_batch_x = support_slice.unsqueeze(0) - if cuda_available: - query_batch_x = query_batch_x.cuda(device) - support_batch_x = support_batch_x.cuda(device) + if cuda_available: + query_batch_x = query_batch_x.cuda(device) + support_batch_x = support_batch_x.cuda(device) weights = model.conditioner(support_batch_x) out = model.segmentor(query_batch_x, weights) _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) - slice_dice_score = dice_score_binary(batch_output, - query_labelmap[i].cuda(device), phase=mode) - dice_per_slice.append(slice_dice_score.item()) - if slice_dice_score.item() >= max_dice: - max_dice = slice_dice_score.item() - max_output = batch_output - # dice_per_slice.append(max_dice) - vol_output.append(max_output) - - vol_output = torch.cat(vol_output) - 
volume_dice_score = dice_score_binary(vol_output, query_labelmap.cuda(device), phase=mode) - volume_dice_score_list.append(volume_dice_score) + batch_output_arr.append(batch_output) - print(volume_dice_score) + volume_output = torch.cat(batch_output_arr) + volume_dice_score = dice_score_binary(volume_output, query_labelmap.cuda(device), phase=mode) + volume_dice_score_list.append(volume_dice_score.item()) + print(str(file_path), volume_dice_score) dice_score_arr = np.asarray(volume_dice_score_list) avg_dice_score = np.median(dice_score_arr) + print(volume_dice_score_list) print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score)) all_query_dice_score_list.append(avg_dice_score) diff --git a/utils/evaluator_slow.py b/utils/evaluator_slow.py index 47a5116..53adc53 100644 --- a/utils/evaluator_slow.py +++ b/utils/evaluator_slow.py @@ -111,17 +111,6 @@ def evaluate_dice_score(model_path, all_query_dice_score_list = [] for query_label in query_labels: volume_dice_score_list = [] - # - # support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], - # orientation=orientation, - # remap_config=remap_config) - # - # support_volume = support_volume if len(support_volume.shape) == 4 else support_volume[:, np.newaxis, :, :] - # - # support_volume, support_labelmap = torch.tensor(support_volume).type(torch.FloatTensor), torch.tensor( - # support_labelmap).type(torch.LongTensor) - # support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) - # support_volume = support_volume[range_index[0]: range_index[1]] # Loading support support_volume, support_labelmap, _, _ = du.load_and_preprocess(support_file_paths[0], @@ -134,12 +123,12 @@ def evaluate_dice_score(model_path, support_volume, range_index = binarize_label(support_volume, support_labelmap, query_label) - slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) - - support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] - - if len(support_slice_indexes) < Num_support: - support_slice_indexes.append(len(support_volume) - 1) + # slice_gap_support = int(np.ceil(len(support_volume) / Num_support)) + # + # support_slice_indexes = [i for i in range(0, len(support_volume), slice_gap_support)] + # + # if len(support_slice_indexes) < Num_support: + # support_slice_indexes.append(len(support_volume) - 1) for vol_idx, file_path in enumerate(query_file_paths): @@ -156,162 +145,36 @@ def evaluate_dice_score(model_path, query_volume = query_volume[range_query[0]: range_query[1] + 1] query_labelmap = query_labelmap[range_query[0]: range_query[1] + 1] - slice_gap_query = int(np.ceil((len(query_volume) / Num_support))) - - query_slice_indexes = [i for i in range(0, len(query_volume), slice_gap_query)] - if len(query_slice_indexes) < Num_support: - query_slice_indexes.append(len(query_volume) - 1) - - volume_prediction = [] - - # for i in range(0, len(query_volume), batch_size): - support_current_slice = 0 - query_current_slice = 0 - - for i, query_start_slice in enumerate(query_slice_indexes): - if query_start_slice == query_slice_indexes[-1]: - query_batch_x = query_volume[query_slice_indexes[i]:] - else: - query_batch_x = query_volume[query_slice_indexes[i]:query_slice_indexes[i + 1]] - - support_batch_x = support_volume[support_slice_indexes[i]] - - support_batch_x = support_batch_x.repeat(len(query_batch_x), 1, 1, 1) - if cuda_available: - query_batch_x = query_batch_x.cuda(device) - support_batch_x = support_batch_x.cuda(device) - - weights = 
model.conditioner(support_batch_x) - out = model.segmentor(query_batch_x, weights) - - _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) - volume_prediction.append(batch_output) - query_current_slice += slice_gap_query - support_current_slice += slice_gap_support - - # query_volume, query_labelmap, _, _ = du.load_and_preprocess(file_path, orientation=orientation, - # remap_config=remap_config) - # query_labelmap = query_labelmap == query_label - # range_query = get_range(query_labelmap) - # query_volume = query_volume[range_query[0]: range_query[1]] - # - # query_volume = query_volume if len(query_volume.shape) == 4 else query_volume[:, np.newaxis, :, :] - # query_volume, query_labelmap = torch.tensor(query_volume).type(torch.FloatTensor), torch.tensor( - # query_labelmap).type(torch.LongTensor) - # - # support_batch_x = [] - # - # volume_prediction = [] - # - # support_current_slice = 0 - # query_current_slice = 0 - # support_slice_left = support_volume[range_index[0]] - - # for i in range(0, range_index[0], batch_size): - # end_index_query = query_current_slice + batch_size - # end_index_query = end_index_query if end_index_query < range_index[0] else range_index[0] - # - # query_batch_x = query_volume[i: end_index_query] - # - # support_batch_x = support_slice_left.repeat(query_batch_x.size()[0], 1, 1, 1) - # - # if cuda_available: - # query_batch_x = query_batch_x.cuda(device) - # support_batch_x = support_batch_x.cuda(device) - # - # weights = model.conditioner(support_batch_x) - # out = model.segmentor(query_batch_x, weights) - # - # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) - # volume_prediction.append(batch_output) - # query_current_slice = end_index_query - # support_current_slice = query_current_slice - # - # for i in range(range_index[0], range_index[1] + 1, batch_size): - # end_index_query = query_current_slice + batch_size - # end_index_query = end_index_query if end_index_query < range_index[1] + 1 else range_index[1] + 1 - # - # query_batch_x = query_volume[i: end_index_query] - # - # # end_index_support = support_current_slice + batch_size - # # end_index_support = end_index_support if end_index_support < len(range_index[1] + 1) else len( - # # range_index[1] + 1) - # # print(len(support_volume)) - # # print(support_current_slice, end_index_query) - # support_batch_x = support_volume[support_current_slice: end_index_query] - # - # query_current_slice = end_index_query - # support_current_slice = query_current_slice - # - # support_batch_x = support_batch_x[0].repeat(query_batch_x.size()[0], 1, 1, 1) - # - # # k += 1 - # if cuda_available: - # query_batch_x = query_batch_x.cuda(device) - # support_batch_x = support_batch_x.cuda(device) - # - # weights = model.conditioner(support_batch_x) - # out = model.segmentor(query_batch_x, weights) - # - # _, batch_output = torch.max(F.softmax(out, dim=1), dim=1) - # volume_prediction.append(batch_output) - # - # support_slice_right = support_volume[range_index[1]] - # for i in range(range_index[1] + 1, len(support_volume), batch_size): - # end_index_query = query_current_slice + batch_size - # end_index_query = end_index_query if end_index_query < len(support_volume) else len(support_volume) - # - # query_batch_x = query_volume[i: end_index_query] - # - # support_batch_x = support_slice_right.repeat(query_batch_x.size()[0], 1, 1, 1) - # - # if cuda_available: - # query_batch_x = query_batch_x.cuda(device) - # support_batch_x = support_batch_x.cuda(device) - # - # weights = model.conditioner(support_batch_x) - # 
out = model.segmentor(query_batch_x, weights)
-        #
-        #     _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
-        #     volume_prediction.append(batch_output)
-        #     query_current_slice = end_index_query
-        #     support_current_slice = query_current_slice
-
-        volume_prediction = torch.cat(volume_prediction)
-
-        # batch, _, _ = query_labelmap.size()
-        # slice_with_class = torch.sum(query_labelmap.view(batch, -1), dim=1) > 10
-        # index = slice_with_class[:-1] - slice_with_class[1:] > 0
-        # seq = torch.Tensor(range(batch - 1))
-        # range_index_gt = seq[index].type(torch.LongTensor)
-
-        volume_dice_score = dice_score_binary(volume_prediction[:len(query_labelmap)], query_labelmap.cuda(device), phase=mode)
-
-        volume_prediction = (volume_prediction.cpu().numpy()).astype('float32')
-        nifti_img = nib.MGHImage(np.squeeze(volume_prediction), np.eye(4))
-        nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_' + fold + str('.mgz')))
-        #
-        # # # Save Input
-        # nifti_img = nib.MGHImage(np.squeeze(query_volume.cpu().numpy()), np.eye(4))
-        # nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_Input_' + str('.mgz')))
-
-        # # # Condition Input
-        nifti_img = nib.MGHImage(np.squeeze(support_volume.cpu().numpy()), np.eye(4))
-        nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInput_' + str('.mgz')))
-        # # Cond GT
-        nifti_img = nib.MGHImage(np.squeeze(support_labelmap.cpu().numpy()).astype('float32'), np.eye(4))
-        nib.save(nifti_img,
-                 os.path.join(prediction_path, volumes_query[vol_idx] + '_CondInputGT_' + str('.mgz')))
-
-        # # # Save Ground Truth
-        nifti_img = nib.MGHImage(np.squeeze(query_labelmap.cpu().numpy()), np.eye(4))
-        nib.save(nifti_img, os.path.join(prediction_path, volumes_query[vol_idx] + '_GT_' + fold
-                                         + str('.mgz')))
-
-        # if logWriter:
-        #     logWriter.plot_dice_score('val', 'eval_dice_score', volume_dice_score, volumes_to_use[vol_idx],
-        #                               vol_idx)
-        volume_dice_score = volume_dice_score.cpu().numpy()
+        dice_per_slice = []  # Dice of every (query slice, support slice) pairing tried
+        vol_output = []  # best prediction found for each query slice
+
+        for i, query_slice in enumerate(query_volume):
+            query_batch_x = query_slice.unsqueeze(0)
+            max_dice = -1.0
+            max_output = None
+            for j in range(0, len(support_volume), 10):  # try every 10th support slice as conditioning input
+                support_slice = support_volume[j]
+
+                support_batch_x = support_slice.unsqueeze(0)
+                if cuda_available:
+                    query_batch_x = query_batch_x.cuda(device)
+                    support_batch_x = support_batch_x.cuda(device)
+
+                weights = model.conditioner(support_batch_x)
+                out = model.segmentor(query_batch_x, weights)
+
+                _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
+                slice_dice_score = dice_score_binary(batch_output,
+                                                     query_labelmap[i].cuda(device), phase=mode)
+                dice_per_slice.append(slice_dice_score.item())  # per-pair Dice, kept for analysis
+                if slice_dice_score.item() >= max_dice:  # keep the highest-scoring prediction
+                    max_dice = slice_dice_score.item()
+                    max_output = batch_output
+            # dice_per_slice.append(max_dice)
+            vol_output.append(max_output)
+
+        vol_output = torch.cat(vol_output)  # stack per-slice winners; selection used query GT, so this is an upper bound
+        volume_dice_score = dice_score_binary(vol_output, query_labelmap.cuda(device), phase=mode)
         volume_dice_score_list.append(volume_dice_score)
         print(volume_dice_score)
@@ -320,10 +183,7 @@ def evaluate_dice_score(model_path,
         avg_dice_score = np.median(dice_score_arr)
         print('Query Label -> ' + str(query_label) + ' ' + str(avg_dice_score))
         all_query_dice_score_list.append(avg_dice_score)
-        # class_dist = [dice_score_arr[:, c] for c in range(num_classes)]
-        # if logWriter:
-        #     logWriter.plot_eval_box_plot('eval_dice_score_box_plot', class_dist, 'Box plot Dice Score')
     print("DONE")

     return np.mean(all_query_dice_score_list)
diff --git a/utils/log_utils.py b/utils/log_utils.py
index 7d112d4..bfd03c5 100644
--- a/utils/log_utils.py
+++ b/utils/log_utils.py
@@ -109,6 +109,7 @@ def dice_score_per_epoch(self, phase, output, correct_labels, epoch):
         # self.plot_dice_score(phase, 'dice_score_per_epoch', ds, 'Dice Score', epoch)
         self.log("DONE")
+        return ds
 
     def dice_score_per_epoch_segmentor(self, phase, output, correct_labels, epoch):
         self.log("Dice Score...")
diff --git a/utils/preprocessor.py b/utils/preprocessor.py
index 3687034..2042252 100644
--- a/utils/preprocessor.py
+++ b/utils/preprocessor.py
@@ -52,16 +52,29 @@ def remap_labels(labels, remap_config):
     elif remap_config == 'WholeBody':
         label_list = [1, 2, 7, 8, 9, 13, 14, 17, 18]
+
+    elif remap_config == 'brain_fewshot':
+        labels[(labels >= 100) & (labels % 2 == 0)] = 210  # merge even labels >= 100 into one aggregate class
+        labels[(labels >= 100) & (labels % 2 == 1)] = 211  # merge odd labels >= 100 into another
+        label_list = [[210, 211], [45, 44], [52, 51], [35], [39, 41, 40, 38], [36, 37, 57, 58, 60, 59, 56, 55]]
     else:
         raise ValueError("Invalid argument value for remap config, only valid options are FS and Neo")
 
     new_labels = np.zeros_like(labels)
 
-    for i, label in enumerate(label_list):
-        label_present = np.zeros_like(labels)
-        label_present[labels == label] = 1
-        new_labels = new_labels + (i + 1) * label_present
+    grouped = isinstance(label_list[0], list)  # entries are lists when several raw labels map to one class
+    if not grouped:
+        for i, label in enumerate(label_list):
+            label_present = np.zeros_like(labels)
+            label_present[labels == label] = 1
+            new_labels = new_labels + (i + 1) * label_present
+    else:
+        for i, label in enumerate(label_list):
+            label_present = np.zeros_like(labels)
+            for j in label:  # mark every raw label belonging to this group
+                label_present[labels == j] = 1
+            new_labels = new_labels + (i + 1) * label_present
 
     return new_labels
diff --git a/utils/shot_batch_sampler.py b/utils/shot_batch_sampler.py
index cb2e966..c3d9fcc 100644
--- a/utils/shot_batch_sampler.py
+++ b/utils/shot_batch_sampler.py
@@ -8,8 +8,14 @@
 lab_list_fold = {"fold1": {"train": [2, 6, 7, 8, 9], "val": [1]},
                  "fold2": {"train": [1, 6, 7, 8, 9], "val": [2]},
-                 "fold4": {"train": [1, 2, 8, 9], "val": [6, 7]},
-                 "fold5": {"train": [1, 2, 6, 7], "val": [8, 9]}}
+                 "fold3": {"train": [1, 2, 8, 9], "val": [6, 7]},
+                 "fold4": {"train": [1, 2, 6, 7], "val": [8, 9]}}
+
+# For brain
+# lab_list_fold = {"fold1": {"train": [1, 2, 3, 5, 6], "val": [4]},
+#                  "fold2": {"train": [1, 3, 4, 5, 6], "val": [2]},
+#                  "fold3": {"train": [1, 2, 3, 4, 6], "val": [5]},
+#                  "fold4": {"train": [1, 2, 3, 4, 5], "val": [6]}}
 
 # lab_list_fold = {"fold1": {"train": [2, 4, 6, 8], "val": [1]},
 #                  "fold2": {"train": [1, 4, 6, 8], "val": [2]},
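
Note on the evaluation change in utils/evaluator.py above: for each query slice, the patched loop tries every 10th support slice as the conditioning input and keeps the prediction with the highest Dice score. The following is a minimal standalone sketch of that selection rule, not the repository's code: predict_with_best_support is a hypothetical name, this dice_score_binary is a simplified stand-in for the project's version (no phase argument), and model is assumed to expose the patch's conditioner/segmentor interface. Because the best support slice is chosen against the query ground truth, the resulting volume Dice is an oracle upper bound, not a deployable inference procedure.

    import torch
    import torch.nn.functional as F

    def dice_score_binary(pred, target, eps=1e-7):
        # Simplified binary Dice between predicted and ground-truth masks.
        pred, target = pred.float(), target.float()
        intersection = (pred * target).sum()
        return (2.0 * intersection + eps) / (pred.sum() + target.sum() + eps)

    def predict_with_best_support(model, query_volume, query_labelmap,
                                  support_volume, stride=10):
        # For each query slice, condition on every `stride`-th support slice
        # and keep the prediction whose Dice against the query labels is best.
        vol_output = []
        for i, query_slice in enumerate(query_volume):
            query_batch_x = query_slice.unsqueeze(0)
            best_dice, best_output = -1.0, None
            for j in range(0, len(support_volume), stride):
                support_batch_x = support_volume[j].unsqueeze(0)
                weights = model.conditioner(support_batch_x)
                out = model.segmentor(query_batch_x, weights)
                _, batch_output = torch.max(F.softmax(out, dim=1), dim=1)
                dice = dice_score_binary(batch_output, query_labelmap[i]).item()
                if dice >= best_dice:
                    best_dice, best_output = dice, batch_output
            vol_output.append(best_output)
        return torch.cat(vol_output)  # (num_slices, H, W) label volume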
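
Similarly, the 'brain_fewshot' branch added to remap_labels() in utils/preprocessor.py first collapses all labels >= 100 into the two aggregate values 210/211 by parity, then maps each group of raw labels to a single class index. Below is a numpy-only sketch of the same mapping with the label values copied from the patch; remap_brain_fewshot is a hypothetical helper name, and np.isin replaces the patch's inner per-group loop with an equivalent mask.

    import numpy as np

    def remap_brain_fewshot(labels):
        # Collapse all labels >= 100 into two aggregate values by parity,
        # as in the patched branch.
        labels = labels.copy()
        labels[(labels >= 100) & (labels % 2 == 0)] = 210
        labels[(labels >= 100) & (labels % 2 == 1)] = 211
        # Each inner list of raw labels collapses to one class index (i + 1).
        label_list = [[210, 211], [45, 44], [52, 51], [35],
                      [39, 41, 40, 38], [36, 37, 57, 58, 60, 59, 56, 55]]
        new_labels = np.zeros_like(labels)
        for i, group in enumerate(label_list):
            new_labels[np.isin(labels, group)] = i + 1
        return new_labels

Since the groups are disjoint, assigning i + 1 through the np.isin mask produces the same result as the patch's additive (i + 1) * label_present accumulation.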