From 74463759f1f95449d9002fc592075434c6af30e3 Mon Sep 17 00:00:00 2001 From: Jerry Jiarui XU Date: Fri, 10 Apr 2020 19:08:47 +0800 Subject: [PATCH] Add sampler and assigner registry (#2419) * add sampler and assigner registry * rename with bbox prefix * restore __init__ * roll back atss_head * change import level * import from sampler/assigner --- mmdet/core/bbox/__init__.py | 2 +- mmdet/core/bbox/assign_sampling.py | 33 ------------------- .../bbox/assigners/approx_max_iou_assigner.py | 2 ++ mmdet/core/bbox/assigners/atss_assigner.py | 2 ++ mmdet/core/bbox/assigners/max_iou_assigner.py | 2 ++ mmdet/core/bbox/assigners/point_assigner.py | 2 ++ mmdet/core/bbox/builder.py | 27 +++++++++++++++ mmdet/core/bbox/registry.py | 4 +++ mmdet/core/bbox/samplers/combined_sampler.py | 4 ++- .../samplers/instance_balanced_pos_sampler.py | 2 ++ .../bbox/samplers/iou_balanced_neg_sampler.py | 2 ++ mmdet/core/bbox/samplers/ohem_sampler.py | 2 ++ mmdet/core/bbox/samplers/pseudo_sampler.py | 2 ++ mmdet/core/bbox/samplers/random_sampler.py | 2 ++ tests/test_assigner.py | 4 +-- tests/test_sampler.py | 2 +- 16 files changed, 56 insertions(+), 38 deletions(-) delete mode 100644 mmdet/core/bbox/assign_sampling.py create mode 100644 mmdet/core/bbox/builder.py create mode 100644 mmdet/core/bbox/registry.py diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py index a0de91724e8..4716509811f 100644 --- a/mmdet/core/bbox/__init__.py +++ b/mmdet/core/bbox/__init__.py @@ -8,7 +8,7 @@ bbox_mapping, bbox_mapping_back, delta2bbox, distance2bbox, roi2bbox) -from .assign_sampling import ( # isort:skip, avoid recursive imports +from .builder import ( # isort:skip, avoid recursive imports assign_and_sample, build_assigner, build_sampler) __all__ = [ diff --git a/mmdet/core/bbox/assign_sampling.py b/mmdet/core/bbox/assign_sampling.py deleted file mode 100644 index 4267174bbec..00000000000 --- a/mmdet/core/bbox/assign_sampling.py +++ /dev/null @@ -1,33 +0,0 @@ -import mmcv - -from . import assigners, samplers - - -def build_assigner(cfg, **kwargs): - if isinstance(cfg, assigners.BaseAssigner): - return cfg - elif isinstance(cfg, dict): - return mmcv.runner.obj_from_dict(cfg, assigners, default_args=kwargs) - else: - raise TypeError('Invalid type {} for building a sampler'.format( - type(cfg))) - - -def build_sampler(cfg, **kwargs): - if isinstance(cfg, samplers.BaseSampler): - return cfg - elif isinstance(cfg, dict): - return mmcv.runner.obj_from_dict(cfg, samplers, default_args=kwargs) - else: - raise TypeError('Invalid type {} for building a sampler'.format( - type(cfg))) - - -def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): - bbox_assigner = build_assigner(cfg.assigner) - bbox_sampler = build_sampler(cfg.sampler) - assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore, - gt_labels) - sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes, - gt_labels) - return assign_result, sampling_result diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py index 6bf23a1858f..df847174232 100644 --- a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py @@ -1,9 +1,11 @@ import torch from ..geometry import bbox_overlaps +from ..registry import BBOX_ASSIGNERS from .max_iou_assigner import MaxIoUAssigner +@BBOX_ASSIGNERS.register_module class ApproxMaxIoUAssigner(MaxIoUAssigner): """Assign a corresponding gt bbox or background to each bbox. diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py index e442ac709d8..c5300d896cc 100644 --- a/mmdet/core/bbox/assigners/atss_assigner.py +++ b/mmdet/core/bbox/assigners/atss_assigner.py @@ -1,10 +1,12 @@ import torch from ..geometry import bbox_overlaps +from ..registry import BBOX_ASSIGNERS from .assign_result import AssignResult from .base_assigner import BaseAssigner +@BBOX_ASSIGNERS.register_module class ATSSAssigner(BaseAssigner): """Assign a corresponding gt bbox or background to each bbox. diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py index b51ad385e22..5505e6ea001 100644 --- a/mmdet/core/bbox/assigners/max_iou_assigner.py +++ b/mmdet/core/bbox/assigners/max_iou_assigner.py @@ -1,10 +1,12 @@ import torch from ..geometry import bbox_overlaps +from ..registry import BBOX_ASSIGNERS from .assign_result import AssignResult from .base_assigner import BaseAssigner +@BBOX_ASSIGNERS.register_module class MaxIoUAssigner(BaseAssigner): """Assign a corresponding gt bbox or background to each bbox. diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py index 263b3096c77..7e5f30f5057 100644 --- a/mmdet/core/bbox/assigners/point_assigner.py +++ b/mmdet/core/bbox/assigners/point_assigner.py @@ -1,9 +1,11 @@ import torch +from ..registry import BBOX_ASSIGNERS from .assign_result import AssignResult from .base_assigner import BaseAssigner +@BBOX_ASSIGNERS.register_module class PointAssigner(BaseAssigner): """Assign a corresponding gt bbox or background to each point. diff --git a/mmdet/core/bbox/builder.py b/mmdet/core/bbox/builder.py new file mode 100644 index 00000000000..40d7d951679 --- /dev/null +++ b/mmdet/core/bbox/builder.py @@ -0,0 +1,27 @@ +from mmdet.utils import build_from_cfg +from .assigners import BaseAssigner +from .registry import BBOX_ASSIGNERS, BBOX_SAMPLERS +from .samplers import BaseSampler + + +def build_assigner(cfg, **default_args): + if isinstance(cfg, BaseAssigner): + return cfg + return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args) + + +def build_sampler(cfg, **default_args): + if isinstance(cfg, BaseSampler): + return cfg + return build_from_cfg(cfg, BBOX_SAMPLERS, default_args) + + +# TODO remove this function in anchor_target in the future +def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg): + bbox_assigner = build_assigner(cfg.assigner) + bbox_sampler = build_sampler(cfg.sampler) + assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore, + gt_labels) + sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes, + gt_labels) + return assign_result, sampling_result diff --git a/mmdet/core/bbox/registry.py b/mmdet/core/bbox/registry.py new file mode 100644 index 00000000000..80a37c9c31a --- /dev/null +++ b/mmdet/core/bbox/registry.py @@ -0,0 +1,4 @@ +from mmdet.utils import Registry + +BBOX_ASSIGNERS = Registry('bbox_assigner') +BBOX_SAMPLERS = Registry('bbox_sampler') diff --git a/mmdet/core/bbox/samplers/combined_sampler.py b/mmdet/core/bbox/samplers/combined_sampler.py index 351a097f671..3fe709a6325 100644 --- a/mmdet/core/bbox/samplers/combined_sampler.py +++ b/mmdet/core/bbox/samplers/combined_sampler.py @@ -1,7 +1,9 @@ -from ..assign_sampling import build_sampler +from ..builder import build_sampler +from ..registry import BBOX_SAMPLERS from .base_sampler import BaseSampler +@BBOX_SAMPLERS.register_module class CombinedSampler(BaseSampler): def __init__(self, pos_sampler, neg_sampler, **kwargs): diff --git a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py index bc829a236c8..6b06291e275 100644 --- a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py +++ b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py @@ -1,9 +1,11 @@ import numpy as np import torch +from ..registry import BBOX_SAMPLERS from .random_sampler import RandomSampler +@BBOX_SAMPLERS.register_module class InstanceBalancedPosSampler(RandomSampler): def _sample_pos(self, assign_result, num_expected, **kwargs): diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py index d9239e0708d..b4df283bfb4 100644 --- a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py +++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py @@ -1,9 +1,11 @@ import numpy as np import torch +from ..registry import BBOX_SAMPLERS from .random_sampler import RandomSampler +@BBOX_SAMPLERS.register_module class IoUBalancedNegSampler(RandomSampler): """IoU Balanced Sampling diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py index 3d315c083e7..0ddd4988e79 100644 --- a/mmdet/core/bbox/samplers/ohem_sampler.py +++ b/mmdet/core/bbox/samplers/ohem_sampler.py @@ -1,9 +1,11 @@ import torch +from ..registry import BBOX_SAMPLERS from ..transforms import bbox2roi from .base_sampler import BaseSampler +@BBOX_SAMPLERS.register_module class OHEMSampler(BaseSampler): """ Online Hard Example Mining Sampler described in [1]_. diff --git a/mmdet/core/bbox/samplers/pseudo_sampler.py b/mmdet/core/bbox/samplers/pseudo_sampler.py index b4c2ea09b0f..7f2d6d5e02e 100644 --- a/mmdet/core/bbox/samplers/pseudo_sampler.py +++ b/mmdet/core/bbox/samplers/pseudo_sampler.py @@ -1,9 +1,11 @@ import torch +from ..registry import BBOX_SAMPLERS from .base_sampler import BaseSampler from .sampling_result import SamplingResult +@BBOX_SAMPLERS.register_module class PseudoSampler(BaseSampler): def __init__(self, **kwargs): diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py index 261ca9c62fa..e1fe59ce1b0 100644 --- a/mmdet/core/bbox/samplers/random_sampler.py +++ b/mmdet/core/bbox/samplers/random_sampler.py @@ -1,8 +1,10 @@ import torch +from ..registry import BBOX_SAMPLERS from .base_sampler import BaseSampler +@BBOX_SAMPLERS.register_module class RandomSampler(BaseSampler): def __init__(self, diff --git a/tests/test_assigner.py b/tests/test_assigner.py index a3904b38e10..6eb30f5b574 100644 --- a/tests/test_assigner.py +++ b/tests/test_assigner.py @@ -10,8 +10,8 @@ """ import torch -from mmdet.core import MaxIoUAssigner -from mmdet.core.bbox.assigners import ApproxMaxIoUAssigner, PointAssigner +from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, MaxIoUAssigner, + PointAssigner) def test_max_iou_assigner(): diff --git a/tests/test_sampler.py b/tests/test_sampler.py index 5afa16a8409..8ada5ed9c3e 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -1,6 +1,6 @@ import torch -from mmdet.core import MaxIoUAssigner +from mmdet.core.bbox.assigners import MaxIoUAssigner from mmdet.core.bbox.samplers import OHEMSampler, RandomSampler