diff --git a/mmdet/models/data_preprocessors/data_preprocessor.py b/mmdet/models/data_preprocessors/data_preprocessor.py index e848a0c61dc..21986fd4989 100644 --- a/mmdet/models/data_preprocessors/data_preprocessor.py +++ b/mmdet/models/data_preprocessors/data_preprocessor.py @@ -13,6 +13,7 @@ from torch import Tensor from mmdet.core.data_structures import DetDataSample +from mmdet.core.mask import BitmapMasks from mmdet.registry import MODELS @@ -149,14 +150,12 @@ def pad_gt_masks(self, """Pad gt_masks to shape of batch_input_shape.""" if 'masks' in batch_data_samples[0].gt_instances: for data_samples in batch_data_samples: + # BitmapMasks masks = data_samples.gt_instances.masks - h, w = masks.shape[-2:] - pad_h, pad_w = data_samples.batch_input_shape - data_samples.gt_instances.masks = F.pad( - masks, - pad=(0, pad_w - w, 0, pad_h - h), - mode='constant', - value=self.mask_pad_value) + assert isinstance(masks, BitmapMasks) + data_samples.gt_instances.masks = masks.pad( + data_samples.batch_input_shape, + pad_val=self.mask_pad_value) def pad_gt_sem_seg(self, batch_data_samples: Sequence[DetDataSample]) -> None: @@ -320,13 +319,11 @@ def forward( if self.pad_mask: for data_samples in batch_data_samples: + # BitmapMasks masks = data_samples.gt_instances.masks - h, w = masks.shape[-2:] - data_samples.gt_instances.masks = F.pad( - masks, - pad=(0, dst_w - w, 0, dst_h - h), - mode='constant', - value=self.mask_pad_value) + assert isinstance(masks, BitmapMasks) + data_samples.gt_instances.masks = masks.pad( + (dst_h, dst_w), pad_val=self.mask_pad_value) if self.pad_seg: for data_samples in batch_data_samples: diff --git a/mmdet/testing/_utils.py b/mmdet/testing/_utils.py index e8de58c4def..aedd762968e 100644 --- a/mmdet/testing/_utils.py +++ b/mmdet/testing/_utils.py @@ -58,13 +58,14 @@ def _rand_bboxes(rng, num_boxes, w, h): def _rand_masks(rng, num_boxes, bboxes, img_w, img_h): + from mmdet.core.mask import BitmapMasks masks = np.zeros((num_boxes, img_h, img_w)) for i, bbox in enumerate(bboxes): bbox = bbox.astype(np.int32) mask = (rng.rand(1, bbox[3] - bbox[1], bbox[2] - bbox[0]) > 0.3).astype(np.int) masks[i:i + 1, bbox[1]:bbox[3], bbox[0]:bbox[2]] = mask - return torch.from_numpy(masks) + return BitmapMasks(masks, height=img_h, width=img_w) def demo_mm_inputs(batch_size=2, diff --git a/tests/test_models/test_preprocessors/test_data_preprocessor.py b/tests/test_models/test_preprocessors/test_data_preprocessor.py index 1010dd39725..9a54284e9bd 100644 --- a/tests/test_models/test_preprocessors/test_data_preprocessor.py +++ b/tests/test_models/test_preprocessors/test_data_preprocessor.py @@ -91,7 +91,8 @@ def test_forward(self): packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint( 0, 256, (1, 9, 24)) mask_pad_sums = [ - x['data_sample'].gt_instances.masks.sum() for x in packed_inputs + x['data_sample'].gt_instances.masks.masks.sum() + for x in packed_inputs ] seg_pad_sums = [ x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs @@ -100,11 +101,11 @@ def test_forward(self): for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip( batch_data_samples, [(10, 24), (10, 24)], mask_pad_sums, seg_pad_sums): - self.assertEqual(data_samples.gt_instances.masks.shape[-2:], + self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:], expected_shape) self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:], expected_shape) - self.assertEqual(data_samples.gt_instances.masks.sum(), + self.assertEqual(data_samples.gt_instances.masks.masks.sum(), mask_pad_sum) self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(), seg_pad_sum) @@ -159,7 +160,8 @@ def test_batch_fixed_size_pad(self): packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint( 0, 256, (1, 9, 24)) mask_pad_sums = [ - x['data_sample'].gt_instances.masks.sum() for x in packed_inputs + x['data_sample'].gt_instances.masks.masks.sum() + for x in packed_inputs ] seg_pad_sums = [ x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs @@ -170,11 +172,11 @@ def test_batch_fixed_size_pad(self): for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip( batch_data_samples, [(32, 32), (32, 32)], mask_pad_sums, seg_pad_sums): - self.assertEqual(data_samples.gt_instances.masks.shape[-2:], + self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:], expected_shape) self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:], expected_shape) - self.assertEqual(data_samples.gt_instances.masks.sum(), + self.assertEqual(data_samples.gt_instances.masks.masks.sum(), mask_pad_sum) self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(), seg_pad_sum) @@ -204,7 +206,8 @@ def test_batch_fixed_size_pad(self): packed_inputs[1]['data_sample'].gt_sem_seg.sem_seg = torch.randint( 0, 256, (1, 9, 24)) mask_pad_sums = [ - x['data_sample'].gt_instances.masks.sum() for x in packed_inputs + x['data_sample'].gt_instances.masks.masks.sum() + for x in packed_inputs ] seg_pad_sums = [ x['data_sample'].gt_sem_seg.sem_seg.sum() for x in packed_inputs @@ -215,11 +218,11 @@ def test_batch_fixed_size_pad(self): for data_samples, expected_shape, mask_pad_sum, seg_pad_sum in zip( batch_data_samples, [(32, 32), (32, 32)], mask_pad_sums, seg_pad_sums): - self.assertEqual(data_samples.gt_instances.masks.shape[-2:], + self.assertEqual(data_samples.gt_instances.masks.masks.shape[-2:], expected_shape) self.assertEqual(data_samples.gt_sem_seg.sem_seg.shape[-2:], expected_shape) - self.assertEqual(data_samples.gt_instances.masks.sum(), + self.assertEqual(data_samples.gt_instances.masks.masks.sum(), mask_pad_sum) self.assertEqual(data_samples.gt_sem_seg.sem_seg.sum(), seg_pad_sum)