Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deprecate ImageToTensor in image_demo #4400

Merged
merged 15 commits into from
Jan 13, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Support DefaultFormatBundle in image_demo
  • Loading branch information
hhaAndroid committed Jan 6, 2021
commit dfd0a92b27062622215c33043a1908e6e9facfa5
7 changes: 5 additions & 2 deletions mmdet/apis/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from mmcv.runner import load_checkpoint

from mmdet.core import get_classes
from mmdet.datasets import replace_ImageToTensor
from mmdet.datasets.pipelines import Compose
from mmdet.models import build_detector

Expand Down Expand Up @@ -104,9 +105,13 @@ def inference_detector(model, img):
# add information into dict
data = dict(img_info=dict(filename=img), img_prefix=None)
# build the data pipeline
cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
test_pipeline = Compose(cfg.data.test.pipeline)
data = test_pipeline(data)
data = collate([data], samples_per_gpu=1)
# just get the actual data from DataContainer
data['img_metas'] = data['img_metas'][0].data
data['img'] = data['img'][0].data
if next(model.parameters()).is_cuda:
# scatter to specified GPU
data = scatter(data, [device])[0]
Expand All @@ -115,8 +120,6 @@ def inference_detector(model, img):
assert not isinstance(
m, RoIPool
), 'CPU inference with RoIPool is not supported currently.'
# just get the actual data from DataContainer
data['img_metas'] = data['img_metas'][0].data

# forward the model
with torch.no_grad():
Expand Down