From 45f2c1fe559a834e1b466dd193011d3876c07b77 Mon Sep 17 00:00:00 2001 From: superfast852 <72710980+superfast852@users.noreply.github.com> Date: Tue, 27 Dec 2022 23:40:57 -0400 Subject: [PATCH] fix indices device bug (#1311) --- utils/loss.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/utils/loss.py b/utils/loss.py index 6eb70a2fa7..5fc73eaf0e 100644 --- a/utils/loss.py +++ b/utils/loss.py @@ -642,7 +642,7 @@ def build_targets(self, p, targets, imgs): #indices, anch = self.find_4_positive(p, targets) #indices, anch = self.find_5_positive(p, targets) #indices, anch = self.find_9_positive(p, targets) - + device = torch.device(targets.device) matching_bs = [[] for pp in p] matching_as = [[] for pp in p] matching_gjs = [[] for pp in p] @@ -682,7 +682,7 @@ def build_targets(self, p, targets, imgs): all_gj.append(gj) all_gi.append(gi) all_anch.append(anch[i][idx]) - from_which_layer.append(torch.ones(size=(len(b),)) * i) + from_which_layer.append((torch.ones(size=(len(b),)) * i).to(device) fg_pred = pi[b, a, gj, gi] p_obj.append(fg_pred[:, 4:5]) @@ -753,7 +753,7 @@ def build_targets(self, p, targets, imgs): _, cost_argmin = torch.min(cost[:, anchor_matching_gt > 1], dim=0) matching_matrix[:, anchor_matching_gt > 1] *= 0.0 matching_matrix[cost_argmin, anchor_matching_gt > 1] = 1.0 - fg_mask_inboxes = matching_matrix.sum(0) > 0.0 + fg_mask_inboxes = (matching_matrix.sum(0) > 0.0).to(device) matched_gt_inds = matching_matrix[:, fg_mask_inboxes].argmax(0) from_which_layer = from_which_layer[fg_mask_inboxes]