Skip to content

Commit cefe8b7

Browse files
committed
Fix bug: non-ignored area may be empty in OHEM loss
1 parent f35fb55 commit cefe8b7

File tree

1 file changed

+1
-0
lines changed

1 file changed

+1
-0
lines changed

lib/loss/loss_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def forward(self, predict, target, **kwargs):
138138
tmp_target[tmp_target == self.ignore_label] = 0
139139
prob = prob_out.gather(1, tmp_target.unsqueeze(1))
140140
mask = target.contiguous().view(-1,) != self.ignore_label
141+
mask[0] = 1 # Avoid `mask` being empty
141142
sort_prob, sort_indices = prob.contiguous().view(-1,)[mask].contiguous().sort()
142143
min_threshold = sort_prob[min(self.min_kept, sort_prob.numel() - 1)]
143144
threshold = max(min_threshold, self.thresh)

0 commit comments

Comments
 (0)