From 4b301b9fcd45984a9c472c314d60fc01d675e762 Mon Sep 17 00:00:00 2001 From: Jon Crall Date: Mon, 27 Jan 2020 01:54:18 -0500 Subject: [PATCH] Enhance AssignResult and SamplingResult (#1995) * Enhance AssignResult and SamplingResult Add runtime dependency on ubelt (pending approval) Fix issue in SamplingResult.__init__ Add rng as attribute of RandomSampler * fix linters * remove ubelt * Fix linters * fix linters again --- mmdet/core/bbox/assigners/assign_result.py | 144 +++++++++++++++++--- mmdet/core/bbox/samplers/base_sampler.py | 24 +++- mmdet/core/bbox/samplers/random_sampler.py | 7 +- mmdet/core/bbox/samplers/sampling_result.py | 134 +++++++++++++++++- mmdet/utils/util_mixins.py | 105 ++++++++++++++ tests/test_assigner.py | 16 +++ tests/test_sampler.py | 14 ++ 7 files changed, 415 insertions(+), 29 deletions(-) create mode 100644 mmdet/utils/util_mixins.py diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py index 38a24d7e60c..5e81c897820 100644 --- a/mmdet/core/bbox/assigners/assign_result.py +++ b/mmdet/core/bbox/assigners/assign_result.py @@ -1,7 +1,9 @@ import torch +from mmdet.utils import util_mixins -class AssignResult(object): + +class AssignResult(util_mixins.NiceRepr): """ Stores assignments between predicted and truth boxes. @@ -44,20 +46,25 @@ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None): self.max_overlaps = max_overlaps self.labels = labels - def add_gt_(self, gt_labels): - self_inds = torch.arange( - 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) - self.gt_inds = torch.cat([self_inds, self.gt_inds]) - - # Was this a bug? - # self.max_overlaps = torch.cat( - # [self.max_overlaps.new_ones(self.num_gts), self.max_overlaps]) - # IIUC, It seems like the correct code should be: - self.max_overlaps = torch.cat( - [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + @property + def num_preds(self): + """ + Return the number of predictions in this assignment + """ + return len(self.gt_inds) - if self.labels is not None: - self.labels = torch.cat([gt_labels, self.labels]) + @property + def info(self): + """ + Returns a dictionary of info about the object + """ + return { + 'num_gts': self.num_gts, + 'num_preds': self.num_preds, + 'gt_inds': self.gt_inds, + 'max_overlaps': self.max_overlaps, + 'labels': self.labels, + } def __nice__(self): """ @@ -81,12 +88,105 @@ def __nice__(self): parts.append('labels.shape={!r}'.format(tuple(self.labels.shape))) return ', '.join(parts) - def __repr__(self): - nice = self.__nice__() - classname = self.__class__.__name__ - return '<{}({}) at {}>'.format(classname, nice, hex(id(self))) + @classmethod + def random(cls, **kwargs): + """ + Create random AssignResult for tests or debugging. + + Kwargs: + num_preds: number of predicted boxes + num_gts: number of true boxes + p_ignore (float): probability of a predicted box assinged to an + ignored truth + p_assigned (float): probability of a predicted box not being + assigned + p_use_label (float | bool): with labels or not + rng (None | int | numpy.random.RandomState): seed or state + + Returns: + AssignResult : + + Example: + >>> from mmdet.core.bbox.assigners.assign_result import * # NOQA + >>> self = AssignResult.random() + >>> print(self.info) + """ + from mmdet.core.bbox import demodata + rng = demodata.ensure_rng(kwargs.get('rng', None)) + + num_gts = kwargs.get('num_gts', None) + num_preds = kwargs.get('num_preds', None) + p_ignore = kwargs.get('p_ignore', 0.3) + p_assigned = kwargs.get('p_assigned', 0.7) + p_use_label = kwargs.get('p_use_label', 0.5) + num_classes = kwargs.get('p_use_label', 3) + + if num_gts is None: + num_gts = rng.randint(0, 8) + if num_preds is None: + num_preds = rng.randint(0, 16) - def __str__(self): - classname = self.__class__.__name__ - nice = self.__nice__() - return '<{}({})>'.format(classname, nice) + if num_gts == 0: + max_overlaps = torch.zeros(num_preds, dtype=torch.float32) + gt_inds = torch.zeros(num_preds, dtype=torch.int64) + if p_use_label is True or p_use_label < rng.rand(): + labels = torch.zeros(num_preds, dtype=torch.int64) + else: + labels = None + else: + import numpy as np + # Create an overlap for each predicted box + max_overlaps = torch.from_numpy(rng.rand(num_preds)) + + # Construct gt_inds for each predicted box + is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned) + # maximum number of assignments constraints + n_assigned = min(num_preds, min(num_gts, is_assigned.sum())) + + assigned_idxs = np.where(is_assigned)[0] + rng.shuffle(assigned_idxs) + assigned_idxs = assigned_idxs[0:n_assigned] + assigned_idxs.sort() + + is_assigned[:] = 0 + is_assigned[assigned_idxs] = True + + is_ignore = torch.from_numpy( + rng.rand(num_preds) < p_ignore) & is_assigned + + gt_inds = torch.zeros(num_preds, dtype=torch.int64) + + true_idxs = np.arange(num_gts) + rng.shuffle(true_idxs) + true_idxs = torch.from_numpy(true_idxs) + gt_inds[is_assigned] = true_idxs[:n_assigned] + + gt_inds = torch.from_numpy( + rng.randint(1, num_gts + 1, size=num_preds)) + gt_inds[is_ignore] = -1 + gt_inds[~is_assigned] = 0 + max_overlaps[~is_assigned] = 0 + + if p_use_label is True or p_use_label < rng.rand(): + if num_classes == 0: + labels = torch.zeros(num_preds, dtype=torch.int64) + else: + labels = torch.from_numpy( + rng.randint(1, num_classes + 1, size=num_preds)) + labels[~is_assigned] = 0 + else: + labels = None + + self = cls(num_gts, gt_inds, max_overlaps, labels) + return self + + def add_gt_(self, gt_labels): + self_inds = torch.arange( + 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device) + self.gt_inds = torch.cat([self_inds, self.gt_inds]) + + self.max_overlaps = torch.cat( + [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps]) + + if self.labels is not None: + self.labels = torch.cat([gt_labels, self.labels]) diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py index a396a8d8a92..f437195f6b7 100644 --- a/mmdet/core/bbox/samplers/base_sampler.py +++ b/mmdet/core/bbox/samplers/base_sampler.py @@ -47,11 +47,30 @@ def sample(self, Returns: :obj:`SamplingResult`: Sampling result. + + Example: + >>> from mmdet.core.bbox import RandomSampler + >>> from mmdet.core.bbox import AssignResult + >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes + >>> rng = ensure_rng(None) + >>> assign_result = AssignResult.random(rng=rng) + >>> bboxes = random_boxes(assign_result.num_preds, rng=rng) + >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng) + >>> gt_labels = None + >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1, + >>> add_gt_as_proposals=False) + >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels) """ + if len(bboxes.shape) < 2: + bboxes = bboxes[None, :] + bboxes = bboxes[:, :4] gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8) if self.add_gt_as_proposals and len(gt_bboxes) > 0: + if gt_labels is None: + raise ValueError( + 'gt_labels must be given when add_gt_as_proposals is True') bboxes = torch.cat([gt_bboxes, bboxes], dim=0) assign_result.add_gt_(gt_labels) gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8) @@ -74,5 +93,6 @@ def sample(self, assign_result, num_expected_neg, bboxes=bboxes, **kwargs) neg_inds = neg_inds.unique() - return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, - assign_result, gt_flags) + sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes, + assign_result, gt_flags) + return sampling_result diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py index 0d02b2747fd..3db00bab0eb 100644 --- a/mmdet/core/bbox/samplers/random_sampler.py +++ b/mmdet/core/bbox/samplers/random_sampler.py @@ -12,11 +12,12 @@ def __init__(self, neg_pos_ub=-1, add_gt_as_proposals=True, **kwargs): + from mmdet.core.bbox import demodata super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals) + self.rng = demodata.ensure_rng(kwargs.get('rng', None)) - @staticmethod - def random_choice(gallery, num): + def random_choice(self, gallery, num): """Random select some elements from the gallery. It seems that Pytorch's implementation is slower than numpy so we use @@ -26,7 +27,7 @@ def random_choice(gallery, num): if isinstance(gallery, list): gallery = np.array(gallery) cands = np.arange(len(gallery)) - np.random.shuffle(cands) + self.rng.shuffle(cands) rand_inds = cands[:num] if not isinstance(gallery, np.ndarray): rand_inds = torch.from_numpy(rand_inds).long().to(gallery.device) diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py index 696e6509710..dcf25eecd67 100644 --- a/mmdet/core/bbox/samplers/sampling_result.py +++ b/mmdet/core/bbox/samplers/sampling_result.py @@ -1,7 +1,25 @@ import torch +from mmdet.utils import util_mixins -class SamplingResult(object): + +class SamplingResult(util_mixins.NiceRepr): + """ + Example: + >>> # xdoctest: +IGNORE_WANT + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random(rng=10) + >>> print('self = {}'.format(self)) + self = + """ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, gt_flags): @@ -13,7 +31,17 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, self.num_gts = gt_bboxes.shape[0] self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1 - self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] + + if gt_bboxes.numel() == 0: + # hack for index error case + assert self.pos_assigned_gt_inds.numel() == 0 + self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4) + else: + if len(gt_bboxes.shape) < 2: + gt_bboxes = gt_bboxes.view(-1, 4) + + self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :] + if assign_result.labels is not None: self.pos_gt_labels = assign_result.labels[pos_inds] else: @@ -22,3 +50,105 @@ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result, @property def bboxes(self): return torch.cat([self.pos_bboxes, self.neg_bboxes]) + + def to(self, device): + """ + Change the device of the data inplace. + + Example: + >>> self = SamplingResult.random() + >>> print('self = {}'.format(self.to(None))) + >>> # xdoctest: +REQUIRES(--gpu) + >>> print('self = {}'.format(self.to(0))) + """ + _dict = self.__dict__ + for key, value in _dict.items(): + if isinstance(value, torch.Tensor): + _dict[key] = value.to(device) + return self + + def __nice__(self): + data = self.info.copy() + data['pos_bboxes'] = data.pop('pos_bboxes').shape + data['neg_bboxes'] = data.pop('neg_bboxes').shape + parts = ['\'{}\': {!r}'.format(k, v) for k, v in sorted(data.items())] + body = ' ' + ',\n '.join(parts) + return '{\n' + body + '\n}' + + @property + def info(self): + """ + Returns a dictionary of info about the object + """ + return { + 'pos_inds': self.pos_inds, + 'neg_inds': self.neg_inds, + 'pos_bboxes': self.pos_bboxes, + 'neg_bboxes': self.neg_bboxes, + 'pos_is_gt': self.pos_is_gt, + 'num_gts': self.num_gts, + 'pos_assigned_gt_inds': self.pos_assigned_gt_inds, + } + + @classmethod + def random(cls, rng=None, **kwargs): + """ + Args: + rng (None | int | numpy.random.RandomState): seed or state + + Kwargs: + num_preds: number of predicted boxes + num_gts: number of true boxes + p_ignore (float): probability of a predicted box assinged to an + ignored truth + p_assigned (float): probability of a predicted box not being + assigned + p_use_label (float | bool): with labels or not + + Returns: + AssignResult : + + Example: + >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA + >>> self = SamplingResult.random() + >>> print(self.__dict__) + """ + from mmdet.core.bbox.samplers.random_sampler import RandomSampler + from mmdet.core.bbox.assigners.assign_result import AssignResult + from mmdet.core.bbox import demodata + rng = demodata.ensure_rng(rng) + + # make probabalistic? + num = 32 + pos_fraction = 0.5 + neg_pos_ub = -1 + + assign_result = AssignResult.random(rng=rng, **kwargs) + + # Note we could just compute an assignment + bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng) + gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng) + + if rng.rand() > 0.2: + # sometimes algorithms squeeze their data, be robust to that + gt_bboxes = gt_bboxes.squeeze() + bboxes = bboxes.squeeze() + + if assign_result.labels is None: + gt_labels = None + else: + gt_labels = None # todo + + if gt_labels is None: + add_gt_as_proposals = False + else: + add_gt_as_proposals = True # make probabalistic? + + sampler = RandomSampler( + num, + pos_fraction, + neg_pos_ubo=neg_pos_ub, + add_gt_as_proposals=add_gt_as_proposals, + rng=rng) + self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels) + return self diff --git a/mmdet/utils/util_mixins.py b/mmdet/utils/util_mixins.py new file mode 100644 index 00000000000..5585ac65273 --- /dev/null +++ b/mmdet/utils/util_mixins.py @@ -0,0 +1,105 @@ +# -*- coding: utf-8 -*- +""" +This module defines the :class:`NiceRepr` mixin class, which defines a +``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__`` +method, which you must define. This means you only have to overload one +function instead of two. Furthermore, if the object defines a ``__len__`` +method, then the ``__nice__`` method defaults to something sensible, otherwise +it is treated as abstract and raises ``NotImplementedError``. + +To use simply have your object inherit from :class:`NiceRepr` +(multi-inheritance should be ok). + +This code was copied from the ubelt library: https://github.com/Erotemic/ubelt + +Example: + >>> # Objects that define __nice__ have a default __str__ and __repr__ + >>> class Student(NiceRepr): + ... def __init__(self, name): + ... self.name = name + ... def __nice__(self): + ... return self.name + >>> s1 = Student('Alice') + >>> s2 = Student('Bob') + >>> print('s1 = {}'.format(s1)) + >>> print('s2 = {}'.format(s2)) + s1 = + s2 = + +Example: + >>> # Objects that define __len__ have a default __nice__ + >>> class Group(NiceRepr): + ... def __init__(self, data): + ... self.data = data + ... def __len__(self): + ... return len(self.data) + >>> g = Group([1, 2, 3]) + >>> print('g = {}'.format(g)) + g = + +""" +import warnings + + +class NiceRepr(object): + """ + Inherit from this class and define ``__nice__`` to "nicely" print your + objects. + + Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function + Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``. + If the inheriting class has a ``__len__``, method then the default + ``__nice__`` method will return its length. + + Example: + >>> class Foo(NiceRepr): + ... def __nice__(self): + ... return 'info' + >>> foo = Foo() + >>> assert str(foo) == '' + >>> assert repr(foo).startswith('>> class Bar(NiceRepr): + ... pass + >>> bar = Bar() + >>> import pytest + >>> with pytest.warns(None) as record: + >>> assert 'object at' in str(bar) + >>> assert 'object at' in repr(bar) + + Example: + >>> class Baz(NiceRepr): + ... def __len__(self): + ... return 5 + >>> baz = Baz() + >>> assert str(baz) == '' + """ + + def __nice__(self): + if hasattr(self, '__len__'): + # It is a common pattern for objects to use __len__ in __nice__ + # As a convenience we define a default __nice__ for these objects + return str(len(self)) + else: + # In all other cases force the subclass to overload __nice__ + raise NotImplementedError( + 'Define the __nice__ method for {!r}'.format(self.__class__)) + + def __repr__(self): + try: + nice = self.__nice__() + classname = self.__class__.__name__ + return '<{0}({1}) at {2}>'.format(classname, nice, hex(id(self))) + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) + + def __str__(self): + try: + classname = self.__class__.__name__ + nice = self.__nice__() + return '<{0}({1})>'.format(classname, nice) + except NotImplementedError as ex: + warnings.warn(str(ex), category=RuntimeWarning) + return object.__repr__(self) diff --git a/tests/test_assigner.py b/tests/test_assigner.py index 50cf7d530ee..5348eaba3a3 100644 --- a/tests/test_assigner.py +++ b/tests/test_assigner.py @@ -259,3 +259,19 @@ def test_approx_iou_assigner_with_empty_boxes_and_gt(): assign_result = self.assign(approxs, squares, approxs_per_octave, gt_bboxes) assert len(assign_result.gt_inds) == 0 + + +def test_random_assign_result(): + """ + Test random instantiation of assign result to catch corner cases + """ + from mmdet.core.bbox.assigners.assign_result import AssignResult + AssignResult.random() + + AssignResult.random(num_gts=0, num_preds=0) + AssignResult.random(num_gts=0, num_preds=3) + AssignResult.random(num_gts=3, num_preds=3) + AssignResult.random(num_gts=0, num_preds=3) + AssignResult.random(num_gts=7, num_preds=7) + AssignResult.random(num_gts=7, num_preds=64) + AssignResult.random(num_gts=24, num_preds=3) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index c375d6e6f9a..c75360268e6 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -233,3 +233,17 @@ def test_ohem_sampler_empty_pred(): assert len(sample_result.pos_bboxes) == len(sample_result.pos_inds) assert len(sample_result.neg_bboxes) == len(sample_result.neg_inds) + + +def test_random_sample_result(): + from mmdet.core.bbox.samplers.sampling_result import SamplingResult + SamplingResult.random(num_gts=0, num_preds=0) + SamplingResult.random(num_gts=0, num_preds=3) + SamplingResult.random(num_gts=3, num_preds=3) + SamplingResult.random(num_gts=0, num_preds=3) + SamplingResult.random(num_gts=7, num_preds=7) + SamplingResult.random(num_gts=7, num_preds=64) + SamplingResult.random(num_gts=24, num_preds=3) + + for i in range(3): + SamplingResult.random(rng=i)