
fixed instance discrimination loss
conradry committed Feb 4, 2021
1 parent 6e6c9e5 commit a02e412
Showing 1 changed file with 18 additions and 16 deletions.
max_deeplab/losses.py: 34 changes (18 additions & 16 deletions)
@@ -175,10 +175,10 @@ def forward(self, input_class, input_mask, target_class, target_mask, target_siz
#-----------------------

class InstanceDiscLoss(nn.Module):
- def __init__(self, temp=0.3):
+ def __init__(self, temp=0.3, eps=1e-5):
super(InstanceDiscLoss, self).__init__()
self.temp = temp
- self.xentropy = nn.CrossEntropyLoss()
+ self.eps = eps

def forward(self, mask_features, target_mask, target_sizes):
"""
@@ -189,46 +189,48 @@ def forward(self, mask_features, target_mask, target_sizes):
device = mask_features.device

#eqn 16
- #consider this like other contrastive algorithms (e.g. MoCo)
- query = mask_features #just for analogy
- key = torch.einsum('bdhw,bkhw->bkd', mask_features, target_mask)
- key = F.normalize(t, dim=-1) #(B, K, D)
+ t = torch.einsum('bdhw,bkhw->bkd', mask_features, target_mask)
+ t = F.normalize(t, dim=-1) #(B, K, D)

#get batch and mask indices from target_sizes
batch_indices = []
mask_indices = []
for bi, size in enumerate(target_sizes):
- mask_indices.append(torch.arange(0, size, dtype=torch.long, device=device))
- batch_indices.append(torch.full_like(mask_indices, bi))
+ mindices = torch.arange(0, size, dtype=torch.long, device=device)
+ mask_indices.append(mindices)
+ batch_indices.append(torch.full_like(mindices, bi))

batch_indices = torch.cat(batch_indices, dim=0) #shape: (torch.prod(target_sizes), )
mask_indices = torch.cat(mask_indices, dim=0)

#create logits and apply temperature
- logits = torch.einsum('bdhw,bkd->bkhw', query, key)
- logits /= self.temp
+ logits = torch.einsum('bdhw,bkd->bkhw', mask_features, t)

#select target_masks and logits
+ #TODO: use logsumexp instead of exp and log separately.
m = target_mask[batch_indices, mask_indices] #(torch.prod(target_sizes), H, W)
- logits = logits[batch_indices, mask_indices] #(torch.prod(target_sizes), H, W)
- logits *= m #masking out zeros in masks
+ logits /= self.temp
+ logits = torch.exp(logits[batch_indices, mask_indices]) #(torch.prod(target_sizes), H, W)

#flip so that there are HW examples for torch.prod(target_sizes) classes
logits = rearrange(logits, 'k h w -> (h w) k')
+ m = rearrange(m, 'k h w -> (h w) k')

- #positive class is also zero
- labels = torch.zeros(logits.shape[0], dtype=torch.long, device=device)

- return self.xentropy(logits, labels)
+ #eqn 17
+ numerator = (m * logits).sum(-1)
+ denominator = logits.sum(-1)
+ return -torch.log((numerator + self.eps) / (denominator + self.eps)).mean()
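
Read back from the added lines, the new return implements a per-pixel contrastive objective over the ground-truth masks. A hedged LaTeX reconstruction (the #eqn comments presumably point at the MaX-DeepLab paper; here g_p is the D-dimensional feature of pixel p, m_{k,p} the ground-truth membership of pixel p in mask k, and tau the temperature, leaving batch handling aside):

\[
t_k = \mathrm{normalize}\Big(\sum_{p} m_{k,p}\, g_p\Big),
\qquad
\mathcal{L}_{\mathrm{instdisc}} = -\frac{1}{HW} \sum_{p} \log
\frac{\sum_{k} m_{k,p}\, \exp(g_p^{\top} t_k / \tau) + \epsilon}
     {\sum_{k} \exp(g_p^{\top} t_k / \tau) + \epsilon}
\]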

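The new TODO suggests folding the separate exp and log into a single logsumexp. A minimal sketch of that variant, assuming the temperature-scaled logits before torch.exp, flattened to shape (HW, K), and a binary membership matrix m of the same shape; this is illustrative only and not part of the commit:

import torch

def instance_disc_loss_logsumexp(logits, m):
    #numerator in log space: logsumexp restricted to the positive masks
    #(entries outside mask k become -inf and contribute exp(-inf) = 0)
    masked_logits = logits.masked_fill(m == 0, float('-inf'))
    log_num = torch.logsumexp(masked_logits, dim=-1)  #(HW,)
    log_den = torch.logsumexp(logits, dim=-1)  #(HW,)
    #pixels covered by no mask would give -inf; drop them instead of adding eps
    valid = m.sum(-1) > 0
    return -(log_num - log_den)[valid].mean()

Compared with exponentiating first, this form avoids overflow when the similarity/temperature ratio is large.
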
class SemanticSegmentationLoss(nn.Module):
def __init__(self, method='cross_entropy'):
super(SemanticSegmentationLoss, self).__init__()
if method != 'cross_entropy':
raise NotImplementedError
else:
#they don't specify the loss function
#could be regular cross entropy or
#dice loss or focal loss etc.
#keep it simple for now
self.xentropy = nn.CrossEntropyLoss()

def forward(self, input_mask, target_mask):
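The remainder of the file is not shown in this hunk. For context, a hypothetical usage sketch of the updated InstanceDiscLoss (shapes inferred from the einsum patterns above; the tensor values, sizes, and the exact type expected for target_sizes are assumptions, not part of this commit):

import torch
from max_deeplab.losses import InstanceDiscLoss  #import path taken from the file shown above

B, D, K, H, W = 2, 128, 10, 64, 64  #illustrative sizes
mask_features = torch.randn(B, D, H, W)  #per-pixel embedding maps
target_mask = torch.randint(0, 2, (B, K, H, W)).float()  #padded ground-truth mask slots
target_sizes = [4, 7]  #number of real (non-padded) masks per image

criterion = InstanceDiscLoss(temp=0.3, eps=1e-5)
loss = criterion(mask_features, target_mask, target_sizes)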
