Skip to content

Commit

Permalink
Add sampler and assigner registry (open-mmlab#2419)
Browse files Browse the repository at this point in the history
* add sampler and assigner registry

* rename with bbox prefix

* restore __init__

* roll back atss_head

* change import level

* import from sampler/assigner
  • Loading branch information
xvjiarui authored Apr 10, 2020
1 parent 3b5983e commit 7446375
Show file tree
Hide file tree
Showing 16 changed files with 56 additions and 38 deletions.
2 changes: 1 addition & 1 deletion mmdet/core/bbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
bbox_mapping, bbox_mapping_back, delta2bbox,
distance2bbox, roi2bbox)

from .assign_sampling import ( # isort:skip, avoid recursive imports
from .builder import ( # isort:skip, avoid recursive imports
assign_and_sample, build_assigner, build_sampler)

__all__ = [
Expand Down
33 changes: 0 additions & 33 deletions mmdet/core/bbox/assign_sampling.py

This file was deleted.

2 changes: 2 additions & 0 deletions mmdet/core/bbox/assigners/approx_max_iou_assigner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from ..geometry import bbox_overlaps
from ..registry import BBOX_ASSIGNERS
from .max_iou_assigner import MaxIoUAssigner


@BBOX_ASSIGNERS.register_module
class ApproxMaxIoUAssigner(MaxIoUAssigner):
"""Assign a corresponding gt bbox or background to each bbox.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/assigners/atss_assigner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch

from ..geometry import bbox_overlaps
from ..registry import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module
class ATSSAssigner(BaseAssigner):
"""Assign a corresponding gt bbox or background to each bbox.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/assigners/max_iou_assigner.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import torch

from ..geometry import bbox_overlaps
from ..registry import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module
class MaxIoUAssigner(BaseAssigner):
"""Assign a corresponding gt bbox or background to each bbox.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/assigners/point_assigner.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from ..registry import BBOX_ASSIGNERS
from .assign_result import AssignResult
from .base_assigner import BaseAssigner


@BBOX_ASSIGNERS.register_module
class PointAssigner(BaseAssigner):
"""Assign a corresponding gt bbox or background to each point.
Expand Down
27 changes: 27 additions & 0 deletions mmdet/core/bbox/builder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
from mmdet.utils import build_from_cfg
from .assigners import BaseAssigner
from .registry import BBOX_ASSIGNERS, BBOX_SAMPLERS
from .samplers import BaseSampler


def build_assigner(cfg, **default_args):
if isinstance(cfg, BaseAssigner):
return cfg
return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)


def build_sampler(cfg, **default_args):
if isinstance(cfg, BaseSampler):
return cfg
return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)


# TODO remove this function in anchor_target in the future
def assign_and_sample(bboxes, gt_bboxes, gt_bboxes_ignore, gt_labels, cfg):
bbox_assigner = build_assigner(cfg.assigner)
bbox_sampler = build_sampler(cfg.sampler)
assign_result = bbox_assigner.assign(bboxes, gt_bboxes, gt_bboxes_ignore,
gt_labels)
sampling_result = bbox_sampler.sample(assign_result, bboxes, gt_bboxes,
gt_labels)
return assign_result, sampling_result
4 changes: 4 additions & 0 deletions mmdet/core/bbox/registry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
from mmdet.utils import Registry

BBOX_ASSIGNERS = Registry('bbox_assigner')
BBOX_SAMPLERS = Registry('bbox_sampler')
4 changes: 3 additions & 1 deletion mmdet/core/bbox/samplers/combined_sampler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from ..assign_sampling import build_sampler
from ..builder import build_sampler
from ..registry import BBOX_SAMPLERS
from .base_sampler import BaseSampler


@BBOX_SAMPLERS.register_module
class CombinedSampler(BaseSampler):

def __init__(self, pos_sampler, neg_sampler, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import torch

from ..registry import BBOX_SAMPLERS
from .random_sampler import RandomSampler


@BBOX_SAMPLERS.register_module
class InstanceBalancedPosSampler(RandomSampler):

def _sample_pos(self, assign_result, num_expected, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import numpy as np
import torch

from ..registry import BBOX_SAMPLERS
from .random_sampler import RandomSampler


@BBOX_SAMPLERS.register_module
class IoUBalancedNegSampler(RandomSampler):
"""IoU Balanced Sampling
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/samplers/ohem_sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from ..registry import BBOX_SAMPLERS
from ..transforms import bbox2roi
from .base_sampler import BaseSampler


@BBOX_SAMPLERS.register_module
class OHEMSampler(BaseSampler):
"""
Online Hard Example Mining Sampler described in [1]_.
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/samplers/pseudo_sampler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import torch

from ..registry import BBOX_SAMPLERS
from .base_sampler import BaseSampler
from .sampling_result import SamplingResult


@BBOX_SAMPLERS.register_module
class PseudoSampler(BaseSampler):

def __init__(self, **kwargs):
Expand Down
2 changes: 2 additions & 0 deletions mmdet/core/bbox/samplers/random_sampler.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import torch

from ..registry import BBOX_SAMPLERS
from .base_sampler import BaseSampler


@BBOX_SAMPLERS.register_module
class RandomSampler(BaseSampler):

def __init__(self,
Expand Down
4 changes: 2 additions & 2 deletions tests/test_assigner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
"""
import torch

from mmdet.core import MaxIoUAssigner
from mmdet.core.bbox.assigners import ApproxMaxIoUAssigner, PointAssigner
from mmdet.core.bbox.assigners import (ApproxMaxIoUAssigner, MaxIoUAssigner,
PointAssigner)


def test_max_iou_assigner():
Expand Down
2 changes: 1 addition & 1 deletion tests/test_sampler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import torch

from mmdet.core import MaxIoUAssigner
from mmdet.core.bbox.assigners import MaxIoUAssigner
from mmdet.core.bbox.samplers import OHEMSampler, RandomSampler


Expand Down

0 comments on commit 7446375

Please sign in to comment.