Skip to content

Commit

Permalink
fix indices device bug (WongKinYiu#1311)
Browse files Browse the repository at this point in the history
  • Loading branch information
superfast852 authored Dec 28, 2022
1 parent 557e383 commit 45f2c1f
Showing 1 changed file with 3 additions and 3 deletions.
6 changes: 3 additions & 3 deletions utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,7 +642,7 @@ def build_targets(self, p, targets, imgs):
#indices, anch = self.find_4_positive(p, targets)
#indices, anch = self.find_5_positive(p, targets)
#indices, anch = self.find_9_positive(p, targets)

device = torch.device(targets.device)
matching_bs = [[] for pp in p]
matching_as = [[] for pp in p]
matching_gjs = [[] for pp in p]
Expand Down Expand Up @@ -682,7 +682,7 @@ def build_targets(self, p, targets, imgs):
all_gj.append(gj)
all_gi.append(gi)
all_anch.append(anch[i][idx])
from_which_layer.append(torch.ones(size=(len(b),)) * i)
from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device)

fg_pred = pi[b, a, gj, gi]
p_obj.append(fg_pred[:, 4:5])
Expand Down Expand Up @@ -753,7 +753,7 @@ def build_targets(self, p, targets, imgs):
_, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0)
matching_matrix[:, anchor_matching_gt > 1] *= 0.0
matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0
fg_mask_inboxes = matching_matrix.sum(0) > 0.0
fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device)
matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0)

from_which_layer = from_which_layer[fg_mask_inboxes]
Expand Down

0 comments on commit 45f2c1f

Please sign in to comment.