Skip to content

Commit

Permalink
init LOCE
Browse files Browse the repository at this point in the history
  • Loading branch information
fcjian committed Aug 18, 2021
1 parent 2930b4e commit c5c523a
Showing 1 changed file with 10 additions and 13 deletions.
23 changes: 10 additions & 13 deletions mmdet/models/roi_heads/loce_roi_head.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,11 +199,11 @@ def _get_feat_for_memory(self, x, gt_bboxes, gt_labels, img_metas):
x[:self.bbox_roi_extractor.num_inputs], rois)
bbox_feats = bbox_feats.detach() # 0

gt_labels, _, reg_targets, _ = self.bbox_head.get_targets_for_memory(proposal_list, neg_proposal_list,
bbox_labels, _, bbox_targets, _ = self.bbox_head.get_targets_for_memory(proposal_list, neg_proposal_list,
gt_bbox_list, gt_label_list,
self.train_cfg)

return bbox_feats, gt_labels, reg_targets
return bbox_feats, bbox_labels, bbox_targets

def _compute_batch_mean_score(self, cls_score, sampling_results, gt_labels, selected_labels):
# batch mean score for current sample (non queue sampling samples)
Expand All @@ -212,15 +212,12 @@ def _compute_batch_mean_score(self, cls_score, sampling_results, gt_labels, sele
batch_gt_labels = []
batch_mean_scores = []
for img_ind, sampling_results_img in enumerate(sampling_results):
try:
for gt_ind, gt_label in enumerate(gt_labels[img_ind]):
if (sampling_results_img.pos_assigned_gt_inds == gt_ind).sum() > 0:
score = scores[self.bbox_sampler.num * img_ind:self.bbox_sampler.num * img_ind + len(sampling_results_img.pos_assigned_gt_inds),
gt_label][sampling_results_img.pos_assigned_gt_inds == gt_ind]
batch_gt_labels.append(gt_label.unsqueeze(0))
batch_mean_scores.append(score.mean().unsqueeze(0))
except:
print(gt_labels)
for gt_ind, gt_label in enumerate(gt_labels[img_ind]):
if (sampling_results_img.pos_assigned_gt_inds == gt_ind).sum() > 0:
score = scores[self.bbox_sampler.num * img_ind:self.bbox_sampler.num * img_ind + len(sampling_results_img.pos_assigned_gt_inds),
gt_label][sampling_results_img.pos_assigned_gt_inds == gt_ind]
batch_gt_labels.append(gt_label.unsqueeze(0))
batch_mean_scores.append(score.mean().unsqueeze(0))

# batch mean score for selected queue samples
selected_length = len(selected_labels)
Expand Down Expand Up @@ -263,8 +260,8 @@ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
rank = -1

# for memory-augmented feature sampling
bbox_feats_for_memory, gt_labels_for_memory, reg_targets_for_memory = self._get_feat_for_memory(x, gt_bboxes, gt_labels, img_metas)
self.mfs.enqueue_dequeue(bbox_feats_for_memory, gt_labels_for_memory, reg_targets_for_memory)
bbox_feats, bbox_labels, bbox_targets = self._get_feat_for_memory(x, gt_bboxes, gt_labels, img_metas)
self.mfs.enqueue_dequeue(bbox_feats, bbox_labels, bbox_targets)
selectd_bbox_feat, selectd_labels, selectd_reg_targets, selectd_cls_weight, selectd_reg_weight = \
self.mfs.probabilistic_sampler(self.mean_score)

Expand Down

0 comments on commit c5c523a

Please sign in to comment.