forked from open-mmlab/mmsegmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
…1527) * refactor init_model and unit test * add topdown inference and minor modification to data pipelines * add topdown image demo * change bbox format from xywh to xyxy * resolve comments
- Loading branch information
Showing
129 changed files
with
1,761 additions
and
1,195 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
codec = dict( | ||
type='MSRAHeatmap', input_size=(192, 256), heatmap_size=(48, 64), sigma=2) | ||
|
||
# model settings | ||
model = dict( | ||
type='TopdownPoseEstimator', | ||
data_preprocessor=dict( | ||
type='PoseDataPreprocessor', | ||
mean=[123.675, 116.28, 103.53], | ||
std=[58.395, 57.12, 57.375], | ||
bgr_to_rgb=True), | ||
backbone=dict( | ||
type='HRNet', | ||
in_channels=3, | ||
extra=dict( | ||
stage1=dict( | ||
num_modules=1, | ||
num_branches=1, | ||
block='BOTTLENECK', | ||
num_blocks=(4, ), | ||
num_channels=(64, )), | ||
stage2=dict( | ||
num_modules=1, | ||
num_branches=2, | ||
block='BASIC', | ||
num_blocks=(4, 4), | ||
num_channels=(32, 64)), | ||
stage3=dict( | ||
num_modules=4, | ||
num_branches=3, | ||
block='BASIC', | ||
num_blocks=(4, 4, 4), | ||
num_channels=(32, 64, 128)), | ||
stage4=dict( | ||
num_modules=3, | ||
num_branches=4, | ||
block='BASIC', | ||
num_blocks=(4, 4, 4, 4), | ||
num_channels=(32, 64, 128, 256))), | ||
), | ||
head=dict( | ||
type='HeatmapHead', | ||
in_channels=32, | ||
out_channels=17, | ||
deconv_out_channels=None, | ||
loss=dict(type='KeypointMSELoss', use_target_weight=True), | ||
decoder=codec)) | ||
|
||
# dataset settings | ||
dataset_type = 'CocoDataset' | ||
data_mode = 'topdown' | ||
data_root = 'data/coco/' | ||
|
||
file_client_args = dict(backend='disk') | ||
|
||
# pipelines | ||
train_pipeline = [ | ||
dict(type='LoadImage', file_client_args=file_client_args), | ||
dict(type='GetBboxCenterScale'), | ||
dict(type='RandomBboxTransform'), | ||
dict(type='RandomFlip', direction='horizontal'), | ||
dict(type='RandomHalfBody'), | ||
dict(type='TopdownAffine', input_size=codec['input_size']), | ||
dict(type='TopdownGenerateHeatmap', encoder=codec), | ||
dict(type='PackPoseInputs') | ||
] | ||
test_pipeline = [ | ||
dict(type='LoadImage', file_client_args=file_client_args), | ||
dict(type='GetBboxCenterScale'), | ||
dict(type='TopdownAffine', input_size=codec['input_size']), | ||
dict(type='PackPoseInputs') | ||
] | ||
|
||
train_dataloader = dict( | ||
batch_size=64, | ||
num_workers=2, | ||
persistent_workers=True, | ||
sampler=dict(type='DefaultSampler', shuffle=True), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='annotations/person_keypoints_train2017.json', | ||
data_prefix=dict(img='train2017/'), | ||
pipeline=train_pipeline, | ||
)) | ||
val_dataloader = dict( | ||
batch_size=32, | ||
num_workers=2, | ||
persistent_workers=True, | ||
drop_last=False, | ||
sampler=dict(type='DefaultSampler', shuffle=False), | ||
dataset=dict( | ||
type=dataset_type, | ||
data_root=data_root, | ||
ann_file='annotations/person_keypoints_val2017.json', | ||
data_prefix=dict(img='val2017/'), | ||
test_mode=True, | ||
pipeline=test_pipeline, | ||
)) | ||
test_dataloader = val_dataloader | ||
|
||
val_evaluator = dict( | ||
type='CocoMetric', | ||
ann_file=data_root + 'annotations/person_keypoints_val2017.json') | ||
test_evaluator = val_evaluator | ||
|
||
vis_backends = [dict(type='LocalVisBackend')] | ||
visualizer = dict( | ||
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer') |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from argparse import ArgumentParser | ||
|
||
from mmcv.image import imread | ||
|
||
from mmpose.apis import inference_topdown, init_model | ||
from mmpose.core.data_structures import PoseDataSample | ||
from mmpose.registry import VISUALIZERS | ||
from mmpose.utils import register_all_modules | ||
|
||
|
||
def parse_args(): | ||
parser = ArgumentParser() | ||
parser.add_argument('img', help='Image file') | ||
parser.add_argument('config', help='Config file') | ||
parser.add_argument('checkpoint', help='Checkpoint file') | ||
parser.add_argument('--out-file', default=None, help='Path to output file') | ||
parser.add_argument( | ||
'--device', default='cuda:0', help='Device used for inference') | ||
parser.add_argument( | ||
'--draw_heatmap', | ||
action='store_true', | ||
help='Visualize the predicted heatmap') | ||
args = parser.parse_args() | ||
return args | ||
|
||
|
||
def main(args): | ||
# register all modules in mmpose into the registries | ||
register_all_modules() | ||
|
||
# build the model from a config file and a checkpoint file | ||
if args.draw_heatmap: | ||
cfg_options = dict(model=dict(test_cfg=dict(output_heatmap=True))) | ||
else: | ||
cfg_options = None | ||
|
||
model = init_model( | ||
args.config, | ||
args.checkpoint, | ||
device=args.device, | ||
cfg_options=cfg_options) | ||
|
||
# init visualizer | ||
visualizer = VISUALIZERS.build(model.cfg.visualizer) | ||
visualizer.set_dataset_meta(model.dataset_meta) | ||
|
||
# inference a single image | ||
results = inference_topdown(model, args.img) | ||
results = PoseDataSample.merge(results) | ||
|
||
# show the results | ||
img = imread(args.img, channel_order='rgb') | ||
visualizer.add_datasample( | ||
'result', | ||
img, | ||
data_sample=results, | ||
draw_gt=False, | ||
draw_bbox=True, | ||
draw_heatmap=args.draw_heatmap, | ||
show=True, | ||
out_file=args.out_file) | ||
|
||
|
||
if __name__ == '__main__': | ||
args = parse_args() | ||
main(args) |
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,20 +1,4 @@ | ||
# Copyright (c) OpenMMLab. All rights reserved. | ||
from .inference import (collect_multi_frames, inference_bottom_up_pose_model, | ||
inference_top_down_pose_model, init_pose_model, | ||
process_mmdet_results, vis_pose_result) | ||
from .inference_3d import (extract_pose_sequence, inference_interhand_3d_model, | ||
inference_mesh_model, inference_pose_lifter_model, | ||
vis_3d_mesh_result, vis_3d_pose_result) | ||
from .inference_tracking import get_track_id, vis_pose_tracking_result | ||
from .test import multi_gpu_test, single_gpu_test | ||
from .train import init_random_seed, train_model | ||
from .inference import inference_topdown, init_model | ||
|
||
__all__ = [ | ||
'train_model', 'init_pose_model', 'inference_top_down_pose_model', | ||
'inference_bottom_up_pose_model', 'multi_gpu_test', 'single_gpu_test', | ||
'vis_pose_result', 'get_track_id', 'vis_pose_tracking_result', | ||
'inference_pose_lifter_model', 'vis_3d_pose_result', | ||
'inference_interhand_3d_model', 'extract_pose_sequence', | ||
'inference_mesh_model', 'vis_3d_mesh_result', 'process_mmdet_results', | ||
'init_random_seed', 'collect_multi_frames' | ||
] | ||
__all__ = ['init_model', 'inference_topdown'] |
Oops, something went wrong.