Skip to content

Commit

Permalink
Enhance AssignResult and SamplingResult (#1995)
Browse files Browse the repository at this point in the history
* Enhance AssignResult and SamplingResult

Add runtime dependency on ubelt (pending approval)

Fix issue in SamplingResult.__init__

Add rng as attribute of RandomSampler

* fix linters

* remove ubelt

* Fix linters

* fix linters again
  • Loading branch information
Erotemic authored and hellock committed Jan 27, 2020
1 parent 78529ec commit 8457bba
Show file tree
Hide file tree
Showing 7 changed files with 415 additions and 29 deletions.
144 changes: 122 additions & 22 deletions mmdet/core/bbox/assigners/assign_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import torch

from mmdet.utils import util_mixins

class AssignResult(object):

class AssignResult(util_mixins.NiceRepr):
"""
Stores assignments between predicted and truth boxes.
Expand Down Expand Up @@ -44,20 +46,25 @@ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
self.max_overlaps = max_overlaps
self.labels = labels

def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])

# Was this a bug?
# self.max_overlaps = torch.cat(
# [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps])
# IIUC, It seems like the correct code should be:
self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
@property
def num_preds(self):
"""
Return the number of predictions in this assignment
"""
return len(self.gt_inds)

if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'num_gts': self.num_gts,
'num_preds': self.num_preds,
'gt_inds': self.gt_inds,
'max_overlaps': self.max_overlaps,
'labels': self.labels,
}

def __nice__(self):
"""
Expand All @@ -81,12 +88,105 @@ def __nice__(self):
parts.append('labels.shape={!r}'.format(tuple(self.labels.shape)))
return ', '.join(parts)

def __repr__(self):
nice = self.__nice__()
classname = self.__class__.__name__
return '<{}({}) at {}>'.format(classname, nice, hex(id(self)))
@classmethod
def random(cls, **kwargs):
"""
Create random AssignResult for tests or debugging.
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
rng (None | int | numpy.random.RandomState): seed or state
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
>>> self = AssignResult.random()
>>> print(self.info)
"""
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(kwargs.get('rng', None))

num_gts = kwargs.get('num_gts', None)
num_preds = kwargs.get('num_preds', None)
p_ignore = kwargs.get('p_ignore', 0.3)
p_assigned = kwargs.get('p_assigned', 0.7)
p_use_label = kwargs.get('p_use_label', 0.5)
num_classes = kwargs.get('p_use_label', 3)

if num_gts is None:
num_gts = rng.randint(0, 8)
if num_preds is None:
num_preds = rng.randint(0, 16)

def __str__(self):
classname = self.__class__.__name__
nice = self.__nice__()
return '<{}({})>'.format(classname, nice)
if num_gts == 0:
max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
gt_inds = torch.zeros(num_preds, dtype=torch.int64)
if p_use_label is True or p_use_label < rng.rand():
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = None
else:
import numpy as np
# Create an overlap for each predicted box
max_overlaps = torch.from_numpy(rng.rand(num_preds))

# Construct gt_inds for each predicted box
is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
# maximum number of assignments constraints
n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))

assigned_idxs = np.where(is_assigned)[0]
rng.shuffle(assigned_idxs)
assigned_idxs = assigned_idxs[0:n_assigned]
assigned_idxs.sort()

is_assigned[:] = 0
is_assigned[assigned_idxs] = True

is_ignore = torch.from_numpy(
rng.rand(num_preds) < p_ignore) & is_assigned

gt_inds = torch.zeros(num_preds, dtype=torch.int64)

true_idxs = np.arange(num_gts)
rng.shuffle(true_idxs)
true_idxs = torch.from_numpy(true_idxs)
gt_inds[is_assigned] = true_idxs[:n_assigned]

gt_inds = torch.from_numpy(
rng.randint(1, num_gts + 1, size=num_preds))
gt_inds[is_ignore] = -1
gt_inds[~is_assigned] = 0
max_overlaps[~is_assigned] = 0

if p_use_label is True or p_use_label < rng.rand():
if num_classes == 0:
labels = torch.zeros(num_preds, dtype=torch.int64)
else:
labels = torch.from_numpy(
rng.randint(1, num_classes + 1, size=num_preds))
labels[~is_assigned] = 0
else:
labels = None

self = cls(num_gts, gt_inds, max_overlaps, labels)
return self

def add_gt_(self, gt_labels):
self_inds = torch.arange(
1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
self.gt_inds = torch.cat([self_inds, self.gt_inds])

self.max_overlaps = torch.cat(
[self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])

if self.labels is not None:
self.labels = torch.cat([gt_labels, self.labels])
24 changes: 22 additions & 2 deletions mmdet/core/bbox/samplers/base_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,30 @@ def sample(self,
Returns:
:obj:`SamplingResult`: Sampling result.
Example:
>>> from mmdet.core.bbox import RandomSampler
>>> from mmdet.core.bbox import AssignResult
>>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
>>> rng = ensure_rng(None)
>>> assign_result = AssignResult.random(rng=rng)
>>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
>>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
>>> gt_labels = None
>>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
>>> add_gt_as_proposals=False)
>>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
"""
if len(bboxes.shape) < 2:
bboxes = bboxes[None, :]

bboxes = bboxes[:, :4]

gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
if self.add_gt_as_proposals and len(gt_bboxes) > 0:
if gt_labels is None:
raise ValueError(
'gt_labels must be given when add_gt_as_proposals is True')
bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
assign_result.add_gt_(gt_labels)
gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
Expand All @@ -74,5 +93,6 @@ def sample(self,
assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
neg_inds = neg_inds.unique()

return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
assign_result, gt_flags)
return sampling_result
7 changes: 4 additions & 3 deletions mmdet/core/bbox/samplers/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,12 @@ def __init__(self,
neg_pos_ub=-1,
add_gt_as_proposals=True,
**kwargs):
from mmdet.core.bbox import demodata
super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
add_gt_as_proposals)
self.rng = demodata.ensure_rng(kwargs.get('rng', None))

@staticmethod
def random_choice(gallery, num):
def random_choice(self, gallery, num):
"""Random select some elements from the gallery.
It seems that Pytorch's implementation is slower than numpy so we use
Expand All @@ -26,7 +27,7 @@ def random_choice(gallery, num):
if isinstance(gallery, list):
gallery = np.array(gallery)
cands = np.arange(len(gallery))
np.random.shuffle(cands)
self.rng.shuffle(cands)
rand_inds = cands[:num]
if not isinstance(gallery, np.ndarray):
rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device)
Expand Down
134 changes: 132 additions & 2 deletions mmdet/core/bbox/samplers/sampling_result.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,25 @@
import torch

from mmdet.utils import util_mixins

class SamplingResult(object):

class SamplingResult(util_mixins.NiceRepr):
"""
Example:
>>> # xdoctest: +IGNORE_WANT
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random(rng=10)
>>> print('self = {}'.format(self))
self = <SamplingResult({
'neg_bboxes': torch.Size([12, 4]),
'neg_inds': tensor([ 0, 1, 2, 4, 5, 6, 7, 8, 9, 10, 11, 12]),
'num_gts': 4,
'pos_assigned_gt_inds': tensor([], dtype=torch.int64),
'pos_bboxes': torch.Size([0, 4]),
'pos_inds': tensor([], dtype=torch.int64),
'pos_is_gt': tensor([], dtype=torch.uint8)
})>
"""

def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
gt_flags):
Expand All @@ -13,7 +31,17 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,

self.num_gts = gt_bboxes.shape[0]
self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

if gt_bboxes.numel() == 0:
# hack for index error case
assert self.pos_assigned_gt_inds.numel() == 0
self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
else:
if len(gt_bboxes.shape) < 2:
gt_bboxes = gt_bboxes.view(-1, 4)

self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]

if assign_result.labels is not None:
self.pos_gt_labels = assign_result.labels[pos_inds]
else:
Expand All @@ -22,3 +50,105 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
@property
def bboxes(self):
return torch.cat([self.pos_bboxes, self.neg_bboxes])

def to(self, device):
"""
Change the device of the data inplace.
Example:
>>> self = SamplingResult.random()
>>> print('self = {}'.format(self.to(None)))
>>> # xdoctest: +REQUIRES(--gpu)
>>> print('self = {}'.format(self.to(0)))
"""
_dict = self.__dict__
for key, value in _dict.items():
if isinstance(value, torch.Tensor):
_dict[key] = value.to(device)
return self

def __nice__(self):
data = self.info.copy()
data['pos_bboxes'] = data.pop('pos_bboxes').shape
data['neg_bboxes'] = data.pop('neg_bboxes').shape
parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())]
body = ' ' + ',\n '.join(parts)
return '{\n' + body + '\n}'

@property
def info(self):
"""
Returns a dictionary of info about the object
"""
return {
'pos_inds': self.pos_inds,
'neg_inds': self.neg_inds,
'pos_bboxes': self.pos_bboxes,
'neg_bboxes': self.neg_bboxes,
'pos_is_gt': self.pos_is_gt,
'num_gts': self.num_gts,
'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
}

@classmethod
def random(cls, rng=None, **kwargs):
"""
Args:
rng (None | int | numpy.random.RandomState): seed or state
Kwargs:
num_preds: number of predicted boxes
num_gts: number of true boxes
p_ignore (float): probability of a predicted box assinged to an
ignored truth
p_assigned (float): probability of a predicted box not being
assigned
p_use_label (float | bool): with labels or not
Returns:
AssignResult :
Example:
>>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
>>> self = SamplingResult.random()
>>> print(self.__dict__)
"""
from mmdet.core.bbox.samplers.random_sampler import RandomSampler
from mmdet.core.bbox.assigners.assign_result import AssignResult
from mmdet.core.bbox import demodata
rng = demodata.ensure_rng(rng)

# make probabalistic?
num = 32
pos_fraction = 0.5
neg_pos_ub = -1

assign_result = AssignResult.random(rng=rng, **kwargs)

# Note we could just compute an assignment
bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)

if rng.rand() > 0.2:
# sometimes algorithms squeeze their data, be robust to that
gt_bboxes = gt_bboxes.squeeze()
bboxes = bboxes.squeeze()

if assign_result.labels is None:
gt_labels = None
else:
gt_labels = None # todo

if gt_labels is None:
add_gt_as_proposals = False
else:
add_gt_as_proposals = True # make probabalistic?

sampler = RandomSampler(
num,
pos_fraction,
neg_pos_ubo=neg_pos_ub,
add_gt_as_proposals=add_gt_as_proposals,
rng=rng)
self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
return self
Loading

0 comments on commit 8457bba

Please sign in to comment.