Skip to content

Commit

Permalink
[Enhance] support inference with padding (#1607)
Browse files Browse the repository at this point in the history
* [Enhance] support inference with padding

* require Pad to run after flip during inference

* add test code
  • Loading branch information
FreyWang authored Jun 15, 2022
1 parent 2dede04 commit 6f43f4d
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 0 deletions.
8 changes: 8 additions & 0 deletions mmseg/datasets/pipelines/test_time_aug.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,14 @@ def __init__(self,
img_ratios=None,
flip=False,
flip_direction='horizontal'):
if flip:
trans_index = {
key['type']: index
for index, key in enumerate(transforms)
}
if 'RandomFlip' in trans_index and 'Pad' in trans_index:
assert trans_index['RandomFlip'] < trans_index['Pad'], \
'Pad must be executed after RandomFlip when flip is True'
self.transforms = Compose(transforms)
if img_ratios is not None:
img_ratios = img_ratios if isinstance(img_ratios,
Expand Down
6 changes: 6 additions & 0 deletions mmseg/models/segmentors/encoder_decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def slide_inference(self, img, img_meta, rescale):
count_mat.cpu().detach().numpy()).to(device=img.device)
preds = preds / count_mat
if rescale:
# remove padding area
resize_shape = img_meta[0]['img_shape'][:2]
preds = preds[:, :, :resize_shape[0], :resize_shape[1]]
preds = resize(
preds,
size=img_meta[0]['ori_shape'][:2],
Expand All @@ -206,6 +209,9 @@ def whole_inference(self, img, img_meta, rescale):
if torch.onnx.is_in_onnx_export():
size = img.shape[2:]
else:
# remove padding area
resize_shape = img_meta[0]['img_shape'][:2]
seg_logit = seg_logit[:, :, :resize_shape[0], :resize_shape[1]]
size = img_meta[0]['ori_shape'][:2]
seg_logit = resize(
seg_logit,
Expand Down
38 changes: 38 additions & 0 deletions tests/test_data/test_tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,3 +149,41 @@ def test_multi_scale_flip_aug():
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]

# test assertion if flip is True and Pad executed before RandomFlip
with pytest.raises(AssertionError):
tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=True,
transforms=[
dict(type='Resize', keep_ratio=False),
dict(type='Pad', size_divisor=32),
dict(type='RandomFlip'),
])
tta_module = build_from_cfg(tta_transform, PIPELINES)

tta_transform = dict(
type='MultiScaleFlipAug',
img_scale=[(256, 256), (512, 512), (1024, 1024)],
img_ratios=None,
flip=True,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Pad', size_divisor=32),
])
tta_module = build_from_cfg(tta_transform, PIPELINES)
tta_results = tta_module(results.copy())
assert tta_results['scale'] == [(256, 256), (256, 256), (512, 512),
(512, 512), (1024, 1024), (1024, 1024)]
assert tta_results['flip'] == [False, True, False, True, False, True]
assert tta_results['img_shape'] == [(144, 256, 3), (144, 256, 3),
(288, 512, 3), (288, 512, 3),
(576, 1024, 3), (576, 1024, 3)]
assert tta_results['pad_shape'] == [(160, 256, 3), (160, 256, 3),
(288, 512, 3), (288, 512, 3),
(576, 1024, 3), (576, 1024, 3)]
for i in range(len(tta_results['img'])):
assert tta_results['img'][i].shape == tta_results['pad_shape'][i]

0 comments on commit 6f43f4d

Please sign in to comment.