Skip to content

Commit 06430d5

Browse files
committed
Merge branch 'master' of https://github.com/open-mmlab/mmdetection into dev/optional_albumentations
2 parents 8bd97a2 + e032ebb commit 06430d5

File tree

6 files changed

+280
-9
lines changed

6 files changed

+280
-9
lines changed

configs/free_anchor/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,3 +21,4 @@
2121

2222
**Notes:**
2323
- We use 8 GPUs with 2 images/GPU.
24+
- For more settings and models, please refer to the [official repo](https://github.com/zhangxiaosong18/FreeAnchor).

configs/libra_rcnn/libra_retinanet_r50_fpn_1x.py

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
in_channels=[256, 512, 1024, 2048],
1616
out_channels=256,
1717
start_level=1,
18-
extra_convs_on_inputs=True,
1918
add_extra_convs=True,
2019
num_outs=5),
2120
dict(
@@ -57,9 +56,6 @@
5756
neg_iou_thr=0.4,
5857
min_pos_iou=0,
5958
ignore_iof_thr=-1),
60-
smoothl1_beta=0.11,
61-
gamma=2.0,
62-
alpha=0.25,
6359
allowed_border=-1,
6460
pos_weight=-1,
6561
debug=False)

mmdet/core/bbox/demodata.py

Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
import numpy as np
2+
import torch
3+
4+
5+
def ensure_rng(rng=None):
    """
    Coerce ``rng`` into a numpy random number generator.

    Simple version of the ``kwarray.ensure_rng``

    Args:
        rng (int | numpy.integer | numpy.random.RandomState | None):
            if None, then defaults to the global rng. If an integer
            (builtin or numpy integer scalar), it is used as the seed of a
            newly created RandomState. Any other object is returned
            unchanged.

    Returns:
        (numpy.random.RandomState) : rng -
            a numpy random number generator

    References:
        https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270
    """
    if rng is None:
        # Fall back to numpy's global RandomState instance.
        return np.random.mtrand._rand
    if isinstance(rng, (int, np.integer)):
        # numpy integer scalars (e.g. np.int64) are not instances of the
        # builtin ``int``, so they must be checked for explicitly; otherwise
        # a seed would be handed back to the caller as if it were a generator.
        return np.random.RandomState(rng)
    # Anything else is assumed to already be a generator.
    return rng
28+
29+
30+
def random_boxes(num=1, scale=1, rng=None):
    """
    Generate random axis-aligned boxes.

    Simple version of ``kwimage.Boxes.random``

    Args:
        num (int): number of boxes to generate.
        scale (float): factor multiplied onto every coordinate; the raw
            coordinates are drawn from ``rng.rand``.
        rng (int | numpy.random.RandomState | None): seed or generator,
            resolved with ``ensure_rng``.

    Returns:
        Tensor: shape (n, 4) in x1, y1, x2, y2 format.

    References:
        https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390

    Example:
        >>> num = 3
        >>> scale = 512
        >>> rng = 0
        >>> boxes = random_boxes(num, scale, rng)
        >>> print(boxes)
        tensor([[280.9925, 278.9802, 308.6148, 366.1769],
                [216.9113, 330.6978, 224.0446, 456.5878],
                [405.3632, 196.3221, 493.3953, 270.7942]])
    """
    generator = ensure_rng(rng)

    # Two random corner points per box, laid out as (xa, ya, xb, yb).
    coords = generator.rand(num, 4).astype(np.float32)
    corner_a, corner_b = coords[:, :2], coords[:, 2:]

    # Element-wise min / max orders the corners so that x1 <= x2 and
    # y1 <= y2. Both results are fresh arrays, so ``coords`` can safely be
    # overwritten column by column below.
    lower = np.minimum(corner_a, corner_b)
    upper = np.maximum(corner_a, corner_b)

    columns = (lower[:, 0], lower[:, 1], upper[:, 0], upper[:, 1])
    for index, column in enumerate(columns):
        coords[:, index] = column * scale

    return torch.from_numpy(coords)

mmdet/models/bbox_heads/bbox_head.py

Lines changed: 41 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,8 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
178178
179179
Args:
180180
rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
181-
and bs is the sampled RoIs per image.
181+
and bs is the sampled RoIs per image. The first column is
182+
the image id and the next 4 columns are x1, y1, x2, y2.
182183
labels (Tensor): Shape (n*bs, ).
183184
bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
184185
pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
@@ -187,13 +188,48 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
187188
188189
Returns:
189190
list[Tensor]: Refined bboxes of each image in a mini-batch.
191+
192+
Example:
193+
>>> # xdoctest: +REQUIRES(module:kwarray)
194+
>>> import kwarray
195+
>>> import numpy as np
196+
>>> from mmdet.core.bbox.demodata import random_boxes
197+
>>> self = BBoxHead(reg_class_agnostic=True)
198+
>>> n_roi = 2
199+
>>> n_img = 4
200+
>>> scale = 512
201+
>>> rng = np.random.RandomState(0)
202+
>>> img_metas = [{'img_shape': (scale, scale)}
203+
... for _ in range(n_img)]
204+
>>> # Create rois in the expected format
205+
>>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
206+
>>> img_ids = torch.randint(0, n_img, (n_roi,))
207+
>>> img_ids = img_ids.float()
208+
>>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
209+
>>> # Create other args
210+
>>> labels = torch.randint(0, 2, (n_roi,)).long()
211+
>>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
212+
>>> # For each image, pretend random positive boxes are gts
213+
>>> is_label_pos = (labels.numpy() > 0).astype(np.int)
214+
>>> lbl_per_img = kwarray.group_items(is_label_pos,
215+
... img_ids.numpy())
216+
>>> pos_per_img = [sum(lbl_per_img.get(gid, []))
217+
... for gid in range(n_img)]
218+
>>> pos_is_gts = [
219+
>>> torch.randint(0, 2, (npos,)).byte().sort(
220+
>>> descending=True)[0]
221+
>>> for npos in pos_per_img
222+
>>> ]
223+
>>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
224+
>>> pos_is_gts, img_metas)
225+
>>> print(bboxes_list)
190226
"""
191227
img_ids = rois[:, 0].long().unique(sorted=True)
192-
assert img_ids.numel() == len(img_metas)
228+
assert img_ids.numel() <= len(img_metas)
193229

194230
bboxes_list = []
195231
for i in range(len(img_metas)):
196-
inds = torch.nonzero(rois[:, 0] == i).squeeze()
232+
inds = torch.nonzero(rois[:, 0] == i).squeeze(dim=1)
197233
num_rois = inds.numel()
198234

199235
bboxes_ = rois[inds, 1:]
@@ -204,6 +240,7 @@ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
204240

205241
bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
206242
img_meta_)
243+
207244
# filter gt bboxes
208245
pos_keep = 1 - pos_is_gts_
209246
keep_inds = pos_is_gts_.new_ones(num_rois)
@@ -226,7 +263,7 @@ def regress_by_class(self, rois, label, bbox_pred, img_meta):
226263
Returns:
227264
Tensor: Regressed bboxes, the same shape as input rois.
228265
"""
229-
assert rois.size(1) == 4 or rois.size(1) == 5
266+
assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)
230267

231268
if not self.reg_class_agnostic:
232269
label = label * 4

requirements/tests.txt

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,5 +4,8 @@ flake8
44
isort
55
pytest
66
pytest-cov
7-
xdoctest >= 0.10.0
7+
xdoctest>=0.10.0
88
yapf
9+
10+
# Note: kwarray is needed for kwarray.group_items; this may be ported to mmcv in the future.
11+
kwarray

tests/test_heads.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -169,3 +169,172 @@ def _dummy_bbox_sampling(proposal_list, gt_bboxes, gt_labels):
169169
bbox_targets, bbox_weights)
170170
assert losses.get('loss_cls', 0) > 0, 'cls-loss should be non-zero'
171171
assert losses.get('loss_bbox', 0) > 0, 'box-loss should be non-zero'
172+
173+
174+
def test_refine_boxes():
    """
    Exercise ``BBoxHead.refine_bboxes`` over a table of n_roi / n_img values.

    Mirrors the doctest in
    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` but checks
    several combinations, including empty and degenerate ones.
    """
    head = BBoxHead(reg_class_agnostic=True)

    # Each row is (n_roi, n_img, rng seed).
    case_table = [
        # Corner case: less rois than images
        (2, 4, 34285940),

        # Corner case: no images
        (0, 0, 52925222),

        # Corner cases: few images / rois
        (1, 1, 1200281),
        (2, 1, 1200282),
        (2, 2, 1200283),
        (1, 2, 1200284),

        # Corner case: no rois few images
        (0, 1, 23955860),
        (0, 2, 25830516),

        # Corner case: no rois many images
        (0, 10, 671346),
        (0, 20, 699807),

        # Corner case: similar num rois and images
        (20, 20, 1200238),
        (10, 20, 1200238),
        (5, 5, 1200238),

        # ----------------------------------
        # Common case: more rois than images
        (100, 1, 337156),
        (150, 2, 275898),
        (500, 5, 4903221),
    ]
    test_settings = [
        {'n_roi': num_rois, 'n_img': num_imgs, 'rng': seed}
        for num_rois, num_imgs, seed in case_table
    ]

    for demokw in test_settings:
        try:
            n_roi, n_img, rng = demokw['n_roi'], demokw['n_img'], demokw['rng']

            print('Test refine_boxes case: {!r}'.format(demokw))
            rois, labels, bbox_preds, pos_is_gts, img_metas = \
                _demodata_refine_boxes(n_roi, n_img, rng=rng)
            refined = head.refine_bboxes(rois, labels, bbox_preds,
                                         pos_is_gts, img_metas)

            # One entry per image, even for images without any roi.
            assert len(refined) == n_img
            # Refinement may drop rois but never adds any.
            num_kept = sum(len(boxes) for boxes in refined)
            assert num_kept <= n_roi
            # Every surviving box has four coordinates.
            for boxes in refined:
                assert boxes.shape[1] == 4
        except Exception:
            # Report the failing configuration, then let the error propagate.
            print('Test failed with demokw={!r}'.format(demokw))
            raise
297+
298+
299+
def _demodata_refine_boxes(n_roi, n_img, rng=0):
    """
    Create random test data for the
    ``mmdet.models.bbox_heads.bbox_head.BBoxHead.refine_bboxes`` method

    Args:
        n_roi (int): total number of rois to generate across all images.
        n_img (int): number of images; must be accompanied by ``n_roi == 0``
            when it is 0.
        rng (int | numpy.random.RandomState | None): seed or generator,
            resolved with ``ensure_rng``.

    Returns:
        tuple: ``(rois, labels, bbox_preds, pos_is_gts, img_metas)`` where

            - rois (Tensor): shape (n_roi, 5); the first column is the image
              id and the remaining 4 columns are the box.
            - labels (Tensor): shape (n_roi, ), long, values in {0, 1}.
            - bbox_preds (Tensor): shape (n_roi, 4).
            - pos_is_gts (list[Tensor]): one uint8 tensor per image, sorted
              in descending order, with one entry per positive label
              assigned to that image.
            - img_metas (list[dict]): one dict per image holding
              ``'img_shape'``.
    """
    import numpy as np
    from mmdet.core.bbox.demodata import random_boxes
    from mmdet.core.bbox.demodata import ensure_rng
    try:
        import kwarray
    except ImportError:
        import pytest
        pytest.skip('kwarray is required for this test')
    scale = 512
    rng = ensure_rng(rng)
    img_metas = [{'img_shape': (scale, scale)} for _ in range(n_img)]
    # Create rois in the expected format
    roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
    if n_img == 0:
        assert n_roi == 0, 'cannot have any rois if there are no images'
        img_ids = torch.empty((0, ), dtype=torch.long)
        roi_boxes = torch.empty((0, 4), dtype=torch.float32)
    else:
        img_ids = rng.randint(0, n_img, (n_roi, ))
        img_ids = torch.from_numpy(img_ids)
    rois = torch.cat([img_ids[:, None].float(), roi_boxes], dim=1)
    # Create other args
    labels = rng.randint(0, 2, (n_roi, ))
    labels = torch.from_numpy(labels).long()
    bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
    # For each image, pretend random positive boxes are gts.
    # NOTE: the builtin ``int`` is used because the ``np.int`` alias was
    # deprecated in NumPy 1.20 and removed in NumPy 1.24.
    is_label_pos = (labels.numpy() > 0).astype(int)
    lbl_per_img = kwarray.group_items(is_label_pos, img_ids.numpy())
    # Images that received no roi are absent from the grouping, hence the
    # empty-list default (which sums to 0).
    pos_per_img = [sum(lbl_per_img.get(gid, [])) for gid in range(n_img)]
    # randomly generate with numpy then sort with torch
    _pos_is_gts = [
        rng.randint(0, 2, (npos, )).astype(np.uint8) for npos in pos_per_img
    ]
    pos_is_gts = [
        torch.from_numpy(p).sort(descending=True)[0] for p in _pos_is_gts
    ]
    return rois, labels, bbox_preds, pos_is_gts, img_metas

0 commit comments

Comments
 (0)