diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py index 1cac7577dde..4be66b42090 100644 --- a/mmdet/apis/inference.py +++ b/mmdet/apis/inference.py @@ -8,6 +8,7 @@ from mmcv.runner import load_checkpoint from mmdet.core import get_classes +from mmdet.datasets import replace_ImageToTensor from mmdet.datasets.pipelines import Compose from mmdet.models import build_detector @@ -103,9 +104,13 @@ def inference_detector(model, img): # add information into dict data = dict(img_info=dict(filename=img), img_prefix=None) # build the data pipeline + cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline) test_pipeline = Compose(cfg.data.test.pipeline) data = test_pipeline(data) data = collate([data], samples_per_gpu=1) + # just get the actual data from DataContainer + data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']] + data['img'] = [img.data[0] for img in data['img']] if next(model.parameters()).is_cuda: # scatter to specified GPU data = scatter(data, [device])[0] @@ -114,8 +119,6 @@ def inference_detector(model, img): assert not isinstance( m, RoIPool ), 'CPU inference with RoIPool is not supported currently.' - # just get the actual data from DataContainer - data['img_metas'] = data['img_metas'][0].data # forward the model with torch.no_grad():