Skip to content

Commit

Permalink
bug fix for retinanet with 2 classes (fg/bg)
Browse files Browse the repository at this point in the history
  • Loading branch information
hellock committed Feb 9, 2019
1 parent ba73bcc commit d1cf5e5
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
3 changes: 1 addition & 2 deletions mmdet/core/anchor/anchor_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,7 @@ def anchor_target_single(flat_anchors,


def expand_binary_labels(labels, label_weights, label_channels):
bin_labels = labels.new_full(
(labels.size(0), label_channels), 0, dtype=torch.float32)
bin_labels = labels.new_full((labels.size(0), label_channels), 0)
inds = torch.nonzero(labels >= 1).squeeze()
if inds.numel() > 0:
bin_labels[inds, labels[inds] - 1] = 1
Expand Down
4 changes: 2 additions & 2 deletions mmdet/core/loss/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ def weighted_nll_loss(pred, label, weight, avg_factor=None):
return torch.sum(raw * weight)[None] / avg_factor


def weighted_cross_entropy(pred, label, weight, avg_factor=None,
reduce=True):
def weighted_cross_entropy(pred, label, weight, avg_factor=None, reduce=True):
if avg_factor is None:
avg_factor = max(torch.sum(weight > 0).float().item(), 1.)
raw = F.cross_entropy(pred, label, reduction='none')
Expand All @@ -36,6 +35,7 @@ def sigmoid_focal_loss(pred,
alpha=0.25,
reduction='mean'):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
weight = (alpha * target + (1 - alpha) * (1 - target)) * weight
weight = weight * pt.pow(gamma)
Expand Down

0 comments on commit d1cf5e5

Please sign in to comment.