two of three auxiliary losses
conradry committed Feb 3, 2021
1 parent 8108e99 commit 6e6c9e5
Showing 1 changed file with 70 additions and 1 deletion.
71 changes: 70 additions & 1 deletion max_deeplab/losses.py
@@ -51,6 +51,7 @@ def forward(self, input_class_prob, input_mask, target_class, target_mask, targe
        input_mask: (B, N, H, W) #probabilities [0, 1]
        target_class: (B, K) #long indices
        target_mask: (B, K, H, W) #bool
        target_sizes: (B,) #number of masks that are padding (i.e. no class)
        """
        device = input_class_prob.device
        B, N = input_class_prob.size()[:2]
@@ -153,7 +154,7 @@ def forward(self, input_class, input_mask, target_class, target_mask, target_siz

        #eqn 10
        #NOTE: some people find negative losses irritating,
-       #-dice could be swapped for 2 - dice without harm
+       #-dice could be swapped for 1-dice without harm
        l_pos = (class_weight * (-dice) + dice_weight * cross_entropy).mean()

        if self.no_class_index == -1:
@@ -168,3 +169,71 @@ def forward(self, input_class, input_mask, target_class, target_mask, target_siz

        #eqn 12
        return self.alpha * l_pos + (1 - self.alpha) * l_neg

#-----------------------
### Auxiliary Losses ###
#-----------------------

class InstanceDiscLoss(nn.Module):
    def __init__(self, temp=0.3):
        super(InstanceDiscLoss, self).__init__()
        self.temp = temp
        self.xentropy = nn.CrossEntropyLoss()
    def forward(self, mask_features, target_mask, target_sizes):
        """
        mask_features: (B, D, H, W) #g
        target_mask: (B, K, H, W) #m
        target_sizes: (B,) #number of masks to use per image
        """

        device = mask_features.device

        #eqn 16
        #consider this like other contrastive algorithms (e.g. MoCo)
        query = mask_features #just for analogy
        key = torch.einsum('bdhw,bkhw->bkd', mask_features, target_mask)
        key = F.normalize(key, dim=-1) #(B, K, D)

        #get batch and mask indices from target_sizes
        batch_indices = []
        mask_indices = []
        for bi, size in enumerate(target_sizes):
            mask_indices.append(torch.arange(0, size, dtype=torch.long, device=device))
            batch_indices.append(torch.full_like(mask_indices[-1], bi))

        batch_indices = torch.cat(batch_indices, dim=0) #shape: (sum(target_sizes), )
        mask_indices = torch.cat(mask_indices, dim=0)

        #create logits and apply temperature
        logits = torch.einsum('bdhw,bkd->bkhw', query, key)
        logits /= self.temp

        #select target_masks and logits
        m = target_mask[batch_indices, mask_indices] #(sum(target_sizes), H, W)
        logits = logits[batch_indices, mask_indices] #(sum(target_sizes), H, W)
        logits *= m #zero out logits at pixels outside the ground truth mask

        #flip so that there are HW examples for sum(target_sizes) classes
        logits = rearrange(logits, 'k h w -> (h w) k')

        #positive class is also zero
        labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

        return self.xentropy(logits, labels)
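A minimal usage sketch for InstanceDiscLoss; the shapes, the random inputs, and passing target_sizes as a plain list are illustrative assumptions, not part of the file:

import torch

B, D, K, H, W = 2, 128, 5, 16, 16
mask_features = torch.randn(B, D, H, W)               #decoder feature map g
target_mask = (torch.rand(B, K, H, W) > 0.5).float()  #binary ground truth masks m
target_sizes = [3, 5]                                  #masks to use per image

criterion = InstanceDiscLoss(temp=0.3)
loss = criterion(mask_features, target_mask, target_sizes)
print(loss.item())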

class SemanticSegmentationLoss(nn.Module):
    def __init__(self, method='cross_entropy'):
        super(SemanticSegmentationLoss, self).__init__()
        if method != 'cross_entropy':
            raise NotImplementedError
        else:
            #they don't specify the loss function
            #could be regular cross entropy or
            #dice loss or focal loss etc.
            self.xentropy = nn.CrossEntropyLoss()

    def forward(self, input_mask, target_mask):
        """
        input_mask: (B, NUM_CLASSES, H, W) #logits
        target_mask: (B, H, W) #long indices
        """
        return self.xentropy(input_mask, target_mask)
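Similarly, a sketch for SemanticSegmentationLoss, assuming the semantic head outputs per-class logits at the label resolution; the class count and spatial size below are made up:

import torch

B, NUM_CLASSES, H, W = 2, 19, 64, 64
sem_logits = torch.randn(B, NUM_CLASSES, H, W)          #per-class logits
sem_target = torch.randint(0, NUM_CLASSES, (B, H, W))   #long class indices

criterion = SemanticSegmentationLoss()
loss = criterion(sem_logits, sem_target)
print(loss.item())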
