-
Notifications
You must be signed in to change notification settings - Fork 412
/
picture_demo.py
65 lines (52 loc) · 1.78 KB
/
picture_demo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os
import re
import sys
sys.path.append('.')
import cv2
import math
import time
import scipy
import argparse
import matplotlib
import numpy as np
import pylab as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from collections import OrderedDict
from scipy.ndimage import generate_binary_structure
from scipy.ndimage import gaussian_filter, maximum_filter
from lib.network.rtpose_vgg import get_model
from lib.network import im_transform
from evaluate.coco_eval import get_outputs, handle_paf_and_heat
from lib.utils.common import Human, BodyPart, CocoPart, CocoColors, CocoPairsRender, draw_humans
from lib.utils.paf_to_pose import paf_to_pose_cpp
from lib.config import cfg, update_config
# Command-line interface: the experiment config file, the checkpoint to load,
# and any trailing config overrides.
parser = argparse.ArgumentParser()
parser.add_argument(
    '--cfg',
    type=str,
    default='./experiments/vgg19_368x368_sgd.yaml',
    help='experiment configure file name',
)
parser.add_argument(
    '--weight',
    type=str,
    default='pose_model.pth',
)
# REMAINDER collects everything left on the command line into args.opts.
parser.add_argument(
    'opts',
    nargs=argparse.REMAINDER,
    default=None,
    help="Modify config options using the command-line",
)
args = parser.parse_args()

# Update the project config object (cfg) from the parsed arguments.
update_config(cfg, args)
# Instantiate the network through the project's get_model factory.
model = get_model('vgg19')
# Deserialize the checkpoint onto the CPU: without map_location, torch.load
# restores every tensor to the device it was saved from, which fails when that
# device (for example a second GPU) is not present on this machine. The
# parameters are moved to the GPU together with the model on the next line.
model.load_state_dict(torch.load(args.weight, map_location='cpu'))
# Wrap for multi-GPU execution and move the parameters to CUDA.
model = torch.nn.DataParallel(model).cuda()
# Cast floating-point parameters and buffers to float32.
model.float()
# Switch to evaluation mode (affects layers such as dropout / batch norm).
model.eval()
test_image = './readme/ski.jpg'
oriImg = cv2.imread(test_image)  # B,G,R order
# cv2.imread reports a missing or undecodable file by returning None instead
# of raising; stop here with a clear message rather than failing with an
# AttributeError on the .shape access below.
if oriImg is None:
    raise OSError('Could not read image: {}'.format(test_image))
# Smaller of the first two image dimensions; not used further in this script.
shape_dst = np.min(oriImg.shape[0:2])

# Get results of original image.
# Gradients are not needed for inference, so the forward pass runs under
# no_grad. NOTE(review): only the get_outputs call is placed inside the
# context here; the remaining steps use its results afterwards.
with torch.no_grad():
    paf, heatmap, im_scale = get_outputs(oriImg, model, 'rtpose')

print(im_scale)
# Turn the heatmap / PAF outputs into `humans` via the project's
# paf_to_pose_cpp helper, using the shared config.
humans = paf_to_pose_cpp(heatmap, paf, cfg)

# Render the result with draw_humans and write it to disk.
out = draw_humans(oriImg, humans)
cv2.imwrite('result.png', out)