|
| 1 | +# coding=utf-8 |
| 2 | +import torch |
| 3 | +import torch.nn as nn |
| 4 | +import torch.nn.functional as F |
| 5 | +from torch.autograd import Variable |
| 6 | +from utils.box_utils import match,refine_match, log_sum_exp,decode |
| 7 | +GPU = False |
| 8 | +if torch.cuda.is_available(): |
| 9 | + GPU = True |
| 10 | + torch.set_default_tensor_type('torch.cuda.FloatTensor') |
| 11 | + |
| 12 | + |
| 13 | +class RefineMultiBoxLoss(nn.Module): |
| 14 | + """SSD Weighted Loss Function |
| 15 | + Compute Targets: |
| 16 | + 1) Produce Confidence Target Indices by matching ground truth boxes |
| 17 | + with (default) 'priorboxes' that have jaccard index > threshold parameter |
| 18 | + (default threshold: 0.5). |
| 19 | + 2) Produce localization target by 'encoding' variance into offsets of ground |
| 20 | + truth boxes and their matched 'priorboxes'. |
| 21 | + 3) Hard negative mining to filter the excessive number of negative examples |
| 22 | + that comes with using a large number of default bounding boxes. |
| 23 | + (default negative:positive ratio 3:1) |
| 24 | + Objective Loss: |
| 25 | + L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N |
| 26 | + Where, Lconf is the CrossEntropy Loss and Lloc is the SmoothL1 Loss |
| 27 | + weighted by α which is set to 1 by cross val. |
| 28 | + Args: |
| 29 | + c: class confidences, |
| 30 | + l: predicted boxes, |
| 31 | + g: ground truth boxes |
| 32 | + N: number of matched default boxes |
| 33 | + See: https://arxiv.org/pdf/1512.02325.pdf for more details. |
| 34 | + """ |
| 35 | + |
| 36 | + |
| 37 | + def __init__(self, num_classes,overlap_thresh,prior_for_matching,bkg_label,neg_mining,neg_pos,neg_overlap,encode_target,object_score = 0): |
| 38 | + super(RefineMultiBoxLoss, self).__init__() |
| 39 | + self.num_classes = num_classes |
| 40 | + self.threshold = overlap_thresh |
| 41 | + self.background_label = bkg_label |
| 42 | + self.encode_target = encode_target |
| 43 | + self.use_prior_for_matching = prior_for_matching |
| 44 | + self.do_neg_mining = neg_mining |
| 45 | + self.negpos_ratio = neg_pos |
| 46 | + self.neg_overlap = neg_overlap |
| 47 | + self.object_score = object_score |
| 48 | + self.variance = [0.1,0.2] |
| 49 | + |
| 50 | + def forward(self, odm_data,priors, targets,arm_data = None,filter_object = False): |
| 51 | + """Multibox Loss |
| 52 | + Args: |
| 53 | + predictions (tuple): A tuple containing loc preds, conf preds, |
| 54 | + and prior boxes from SSD net. |
| 55 | + conf shape: torch.size(batch_size,num_priors,num_classes) |
| 56 | + loc shape: torch.size(batch_size,num_priors,4) |
| 57 | + priors shape: torch.size(num_priors,4) |
| 58 | +
|
| 59 | + ground_truth (tensor): Ground truth boxes and labels for a batch, |
| 60 | + shape: [batch_size,num_objs,5] (last idx is the label). |
| 61 | + arm_data (tuple): arm branch containg arm_loc and arm_conf |
| 62 | + filter_object: whether filter out the prediction according to the arm conf score |
| 63 | + """ |
| 64 | + |
| 65 | + loc_data,conf_data = odm_data |
| 66 | + if arm_data: |
| 67 | + arm_loc,arm_conf = arm_data |
| 68 | + priors = priors.data |
| 69 | + num = loc_data.size(0) |
| 70 | + num_priors = (priors.size(0)) |
| 71 | + |
| 72 | + # match priors (default boxes) and ground truth boxes |
| 73 | + loc_t = torch.Tensor(num, num_priors, 4) |
| 74 | + conf_t = torch.LongTensor(num, num_priors) |
| 75 | + for idx in range(num): |
| 76 | + truths = targets[idx][:,:-1].data |
| 77 | + labels = targets[idx][:,-1].data |
| 78 | + #for object detection |
| 79 | + if self.num_classes == 2: |
| 80 | + labels = labels > 0 |
| 81 | + if arm_data: |
| 82 | + refine_match(self.threshold,truths,priors,self.variance,labels,loc_t,conf_t,idx,arm_loc[idx].data) |
| 83 | + else: |
| 84 | + match(self.threshold,truths,priors,self.variance,labels,loc_t,conf_t,idx) |
| 85 | + if GPU: |
| 86 | + loc_t = loc_t.cuda() |
| 87 | + conf_t = conf_t.cuda() |
| 88 | + # wrap targets |
| 89 | + loc_t = Variable(loc_t, requires_grad=False) |
| 90 | + conf_t = Variable(conf_t,requires_grad=False) |
| 91 | + if arm_data and filter_object: |
| 92 | + arm_conf_data = arm_conf.data[:,:,1] |
| 93 | + pos = conf_t > 0 |
| 94 | + object_score_index = arm_conf_data <= self.object_score |
| 95 | + pos[object_score_index] = 0 |
| 96 | + |
| 97 | + else: |
| 98 | + pos = conf_t > 0 |
| 99 | + |
| 100 | + # Localization Loss (Smooth L1) |
| 101 | + # Shape: [batch,num_priors,4] |
| 102 | + pos_idx = pos.unsqueeze(pos.dim()).expand_as(loc_data) |
| 103 | + loc_p = loc_data[pos_idx].view(-1,4) |
| 104 | + loc_t = loc_t[pos_idx].view(-1,4) |
| 105 | + loss_l = F.smooth_l1_loss(loc_p, loc_t, size_average=False) |
| 106 | + |
| 107 | + # Compute max conf across batch for hard negative mining |
| 108 | + batch_conf = conf_data.view(-1,self.num_classes) |
| 109 | + loss_c = log_sum_exp(batch_conf) - batch_conf.gather(1, conf_t.view(-1,1)) |
| 110 | + |
| 111 | + # Hard Negative Mining |
| 112 | + loss_c[pos] = 0 # filter out pos boxes for now |
| 113 | + loss_c = loss_c.view(num, -1) |
| 114 | + _,loss_idx = loss_c.sort(1, descending=True) |
| 115 | + _,idx_rank = loss_idx.sort(1) |
| 116 | + num_pos = pos.long().sum(1,keepdim=True) |
| 117 | + num_neg = torch.clamp(self.negpos_ratio*num_pos, max=pos.size(1)-1) |
| 118 | + neg = idx_rank < num_neg.expand_as(idx_rank) |
| 119 | + |
| 120 | + # Confidence Loss Including Positive and Negative Examples |
| 121 | + pos_idx = pos.unsqueeze(2).expand_as(conf_data) |
| 122 | + neg_idx = neg.unsqueeze(2).expand_as(conf_data) |
| 123 | + conf_p = conf_data[(pos_idx+neg_idx).gt(0)].view(-1,self.num_classes) |
| 124 | + targets_weighted = conf_t[(pos+neg).gt(0)] |
| 125 | + loss_c = F.cross_entropy(conf_p, targets_weighted, size_average=False) |
| 126 | + |
| 127 | + # Sum of losses: L(x,c,l,g) = (Lconf(x, c) + αLloc(x,l,g)) / N |
| 128 | + N = num_pos.data.sum() |
| 129 | + loss_l/=N |
| 130 | + loss_c/=N |
| 131 | + return loss_l,loss_c |
0 commit comments