Skip to content

Commit

Permalink
Update pad in data preprocessor
Browse files Browse the repository at this point in the history
  • Loading branch information
chhluo authored and ZwwWayne committed Jul 19, 2022
1 parent 5f99381 commit 97b6e89
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 23 deletions.
23 changes: 10 additions & 13 deletions mmdet/models/data_preprocessors/data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from torch import Tensor

from mmdet.core.data_structures import DetDataSample
from mmdet.core.mask import BitmapMasks
from mmdet.registry import MODELS


Expand Down Expand Up @@ -149,14 +150,12 @@ def pad_gt_masks(self,
"""Pad gt_masks to shape of batch_input_shape."""
if 'masks' in batch_data_samples[0].gt_instances:
for data_samples in batch_data_samples:
# BitmapMasks
masks = data_samples.gt_instances.masks
h, w = masks.shape[-2:]
pad_h, pad_w = data_samples.batch_input_shape
data_samples.gt_instances.masks = F.pad(
masks,
pad=(0, pad_w - w, 0, pad_h - h),
mode='constant',
value=self.mask_pad_value)
assert isinstance(masks, BitmapMasks)
data_samples.gt_instances.masks = masks.pad(
data_samples.batch_input_shape,
pad_val=self.mask_pad_value)

def pad_gt_sem_seg(self,
batch_data_samples: Sequence[DetDataSample]) -> None:
Expand Down Expand Up @@ -320,13 +319,11 @@ def forward(

if self.pad_mask:
for data_samples in batch_data_samples:
# BitmapMasks
masks = data_samples.gt_instances.masks
h, w = masks.shape[-2:]
data_samples.gt_instances.masks = F.pad(
masks,
pad=(0, dst_w - w, 0, dst_h - h),
mode='constant',
value=self.mask_pad_value)
assert isinstance(masks, BitmapMasks)
data_samples.gt_instances.masks = masks.pad(
(dst_h, dst_w), pad_val=self.mask_pad_value)

if self.pad_seg:
for data_samples in batch_data_samples:
Expand Down
3 changes: 2 additions & 1 deletion mmdet/testing/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,13 +58,14 @@ def _rand_bboxes(rng, num_boxes, w, h):


def _rand_masks(rng, num_boxes, bboxes, img_w, img_h):
    """Generate random binary instance masks confined to the given bboxes.

    Args:
        rng: NumPy random state (or Generator) providing ``rand``; used so
            callers get reproducible masks.
        num_boxes (int): Number of masks to generate.
        bboxes: Array of shape (num_boxes, 4) with ``[x1, y1, x2, y2]``
            boxes; values are cast to int32 before slicing.
        img_w (int): Full image width of each mask canvas.
        img_h (int): Full image height of each mask canvas.

    Returns:
        BitmapMasks: ``num_boxes`` masks of size (img_h, img_w); pixels
        inside each bbox are 1 with ~70% probability, everything else 0.
    """
    # Imported lazily to keep the testing helpers importable without
    # pulling in the mask subpackage at module load time.
    from mmdet.core.mask import BitmapMasks
    masks = np.zeros((num_boxes, img_h, img_w))
    for i, bbox in enumerate(bboxes):
        bbox = bbox.astype(np.int32)
        # Random foreground with ~0.7 density, sized to the bbox region.
        # NOTE: np.int (deprecated alias, removed in NumPy 1.24) replaced
        # with np.int64 to keep this working on modern NumPy.
        mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) >
                0.3).astype(np.int64)
        masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask
    return BitmapMasks(masks, height=img_h, width=img_w)


def demo_mm_inputs(batch_size=2,
Expand Down
21 changes: 12 additions & 9 deletions tests/test_models/test_preprocessors/test_data_preprocessor.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ def test_forward(self):
packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint(
0, 256, (1, 9, 24))
mask_pad_sums = [
x['data_sample'].gt_instances.masks.sum() for x in packed_inputs
x['data_sample'].gt_instances.masks.masks.sum()
for x in packed_inputs
]
seg_pad_sums = [
x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs
Expand All @@ -100,11 +101,11 @@ def test_forward(self):
for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip(
batch_data_samples, [(10, 24), (10, 24)], mask_pad_sums,
seg_pad_sums):
self.assertEqual(data_samples.gt_instances.masks.shape[-2:],
self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_instances.masks.sum(),
self.assertEqual(data_samples.gt_instances.masks.masks.sum(),
mask_pad_sum)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(),
seg_pad_sum)
Expand Down Expand Up @@ -159,7 +160,8 @@ def test_batch_fixed_size_pad(self):
packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint(
0, 256, (1, 9, 24))
mask_pad_sums = [
x['data_sample'].gt_instances.masks.sum() for x in packed_inputs
x['data_sample'].gt_instances.masks.masks.sum()
for x in packed_inputs
]
seg_pad_sums = [
x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs
Expand All @@ -170,11 +172,11 @@ def test_batch_fixed_size_pad(self):
for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip(
batch_data_samples, [(32, 32), (32, 32)], mask_pad_sums,
seg_pad_sums):
self.assertEqual(data_samples.gt_instances.masks.shape[-2:],
self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_instances.masks.sum(),
self.assertEqual(data_samples.gt_instances.masks.masks.sum(),
mask_pad_sum)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(),
seg_pad_sum)
Expand Down Expand Up @@ -204,7 +206,8 @@ def test_batch_fixed_size_pad(self):
packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint(
0, 256, (1, 9, 24))
mask_pad_sums = [
x['data_sample'].gt_instances.masks.sum() for x in packed_inputs
x['data_sample'].gt_instances.masks.masks.sum()
for x in packed_inputs
]
seg_pad_sums = [
x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs
Expand All @@ -215,11 +218,11 @@ def test_batch_fixed_size_pad(self):
for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip(
batch_data_samples, [(32, 32), (32, 32)], mask_pad_sums,
seg_pad_sums):
self.assertEqual(data_samples.gt_instances.masks.shape[-2:],
self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:],
expected_shape)
self.assertEqual(data_samples.gt_instances.masks.sum(),
self.assertEqual(data_samples.gt_instances.masks.masks.sum(),
mask_pad_sum)
self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(),
seg_pad_sum)

0 comments on commit 97b6e89

Please sign in to comment.