Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
144 changes: 122 additions & 22 deletions mmdet/core/bbox/assigners/assign_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch

from mmdet.utils import util_mixins

class AssignResult(object):

class AssignResult(util_mixins.NiceRepr):
"""
Stores assignments between predicted and truth boxes.

Expand Down Expand Up @@ -44,20 +46,25 @@ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
self.max_overlaps = max_overlaps
self.labels = labels

def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])

# Was this a bug?
# self.max_overlaps = torch.cat(
# [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
# IIUC, It seems like the correct code should be:
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
@property
def num_preds(self):
"""
Return the number of predictions in this assignment
"""
return len(self.gt_inds)

if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}

def __nice__(self):
"""
Expand All @@ -81,12 +88,105 @@ def __nice__(self):
parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
return ', '.join(parts)

def __repr__(self):
nice = self.__nice__()
classname = self.__class__.__name__
return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))
@classmethod
def random(cls, **kwargs):
"""
Create random AssignResult for tests or debugging.

Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
rng (None | int | numpy.random.RandomState): seed or state

Returns:
AssignResult :

Example:
>>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
>>> self = AssignResult.random()
>>> print(self.info)
"""
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(kwargs.get('rng', None))

num_gts = kwargs.get('num_gts', None)
num_preds = kwargs.get('num_preds', None)
p_ignore = kwargs.get('p_ignore', 0.3)
p_assigned = kwargs.get('p_assigned', 0.7)
p_use_label = kwargs.get('p_use_label', 0.5)
num_classes = kwargs.get('p_use_label', 3)

if num_gts is None:
num_gts = rng.randint(0, 8)
if num_preds is None:
num_preds = rng.randint(0, 16)

def __str__(self):
classname = self.__class__.__name__
nice = self.__nice__()
return '<{}({})>'.format(classname, nice)
if num_gts == 0:
max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
if p_use_label is True or p_use_label < rng.rand():
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = None
else:
import numpy as np
# Create an overlap for each predicted box
max_overlaps = torch.from_numpy(rng.rand(num_preds))

# Construct gt_inds for each predicted box
is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
# maximum number of assignments constraints
n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))

assigned_idxs = np.where(is_assigned)[0]
rng.shuffle(assigned_idxs)
assigned_idxs = assigned_idxs[0:n_assigned]
assigned_idxs.sort()

is_assigned[:] = 0
is_assigned[assigned_idxs] = True

is_ignore = torch.from_numpy(
rng.rand(num_preds) < p_ignore) & is_assigned

gt_inds = torch.zeros(num_preds, dtype=torch.int64)

true_idxs = np.arange(num_gts)
rng.shuffle(true_idxs)
true_idxs = torch.from_numpy(true_idxs)
gt_inds[is_assigned] = true_idxs[:n_assigned]

gt_inds = torch.from_numpy(
rng.randint(1, num_gts + 1, size=num_preds))
gt_inds[is_ignore] = -1
gt_inds[~is_assigned] = 0
max_overlaps[~is_assigned] = 0

if p_use_label is True or p_use_label < rng.rand():
if num_classes == 0:
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = torch.from_numpy(
rng.randint(1, num_classes + 1, size=num_preds))
labels[~is_assigned] = 0
else:
labels = None

self = cls(num_gts, gt_inds, max_overlaps, labels)
return self

def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])

self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
24 changes: 22 additions & 2 deletions mmdet/core/bbox/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,30 @@ def sample(self,

Returns:
:obj:`SamplingResult`: Sampling result.

Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
"""
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]

bboxes = bboxes[:, :4]

gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError(
'gt_labels must be given when add_gt_as_proposals is True')
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
Expand All @@ -74,5 +93,6 @@ def sample(self,
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()

return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
7 changes: 4 additions & 3 deletions mmdet/core/bbox/samplers/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def __init__(self,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
from mmdet.core.bbox import demodata
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
self.rng = demodata.ensure_rng(kwargs.get('rng', None))

@staticmethod
def random_choice(gallery, num):
def random_choice(self, gallery, num):
"""Random select some elements from the gallery.

It seems that Pytorch's implementation is slower than numpy so we use
Expand All @@ -26,7 +27,7 @@ def random_choice(gallery, num):
if isinstance(gallery, list):
gallery = np.array(gallery)
cands = np.arange(len(gallery))
np.random.shuffle(cands)
self.rng.shuffle(cands)
rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
Expand Down
134 changes: 132 additions & 2 deletions mmdet/core/bbox/samplers/sampling_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import torch

from mmdet.utils import util_mixins

class SamplingResult(object):

class SamplingResult(util_mixins.NiceRepr):
"""
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print('self = {}'.format(self))
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""

def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
Expand All @@ -13,7 +31,17 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,

self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)

self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
Expand All @@ -22,3 +50,105 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
@property
def bboxes(self):
return torch.cat([self.pos_bboxes, self.neg_bboxes])

def to(self, device):
"""
Change the device of the data inplace.

Example:
>>> self = SamplingResult.random()
>>> print('self = {}'.format(self.to(None)))
>>> # xdoctest: +REQUIRES(--gpu)
>>> print('self = {}'.format(self.to(0)))
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self

def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'

@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}

@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state

Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not

Returns:
AssignResult :

Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(rng)

# make probabalistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1

assign_result = AssignResult.random(rng=rng, **kwargs)

# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()

if assign_result.labels is None:
gt_labels = None
else:
gt_labels = None # todo

if gt_labels is None:
add_gt_as_proposals = False
else:
add_gt_as_proposals = True # make probabalistic?

sampler = RandomSampler(
num,
pos_fraction,
neg_pos_ubo=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
rng=rng)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self
Loading