Skip to content

Commit

Permalink
Working on COCO dataset, 2
Browse files Browse the repository at this point in the history
  • Loading branch information
osmr committed Feb 13, 2020
1 parent b7be479 commit 1fffda8
Show file tree
Hide file tree
Showing 5 changed files with 328 additions and 95 deletions.
5 changes: 3 additions & 2 deletions gluon/datasets/coco_hpe_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -445,7 +445,7 @@ def __init__(self):
self.num_training_samples = None
self.in_channels = 3
self.num_classes = CocoHpeDataset.classes
self.input_image_size = (256, 256)
self.input_image_size = (256, 192)
self.train_metric_capts = None
self.train_metric_names = None
self.train_metric_extra_kwargs = None
Expand All @@ -455,7 +455,8 @@ def __init__(self):
self.test_metric_names = ["CocoHpeOksApMetric"]
self.test_metric_extra_kwargs = [
{"name": "OksAp",
"coco": None}]
"coco": None,
"in_vis_thresh": 0.0}]
self.saver_acc_ind = 0
self.do_transform = True
self.val_transform = CocoHpeValTransform
Expand Down
134 changes: 132 additions & 2 deletions gluon/gluoncv2/models/others/oth_simple_pose_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,133 @@
'oth_simple_pose_resnet152_v1d', 'oth_resnet50_v1d', 'oth_resnet101_v1d',
'oth_resnet152_v1d']

import cv2
import numpy as np
import mxnet as mx
from mxnet.context import cpu
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet import initializer
import gluoncv as gcv
from gluoncv.data.transforms.pose import get_final_preds


def get_max_pred(batch_heatmaps):
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
width = batch_heatmaps.shape[3]
heatmaps_reshaped = batch_heatmaps.reshape((batch_size, num_joints, -1))
idx = mx.nd.argmax(heatmaps_reshaped, 2)
maxvals = mx.nd.max(heatmaps_reshaped, 2)

maxvals = maxvals.reshape((batch_size, num_joints, 1))
idx = idx.reshape((batch_size, num_joints, 1))

preds = mx.nd.tile(idx, (1, 1, 2)).astype(np.float32)

preds[:, :, 0] = (preds[:, :, 0]) % width
preds[:, :, 1] = mx.nd.floor((preds[:, :, 1]) / width)

pred_mask = mx.nd.tile(mx.nd.greater(maxvals, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)

preds *= pred_mask
return preds, maxvals


def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]


def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)

src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs

return src_result


def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale])

scale_tmp = scale
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]

rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)

src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

return trans


def transform_preds(coords, center, scale, output_size):
target_coords = mx.nd.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2].asnumpy(), trans)
return target_coords


def _get_final_preds(batch_heatmaps, center, scale):
center_ = center.asnumpy()
scale_ = scale.asnumpy()

coords, maxvals = get_max_pred(batch_heatmaps)

heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]

# post-processing
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(mx.nd.floor(coords[n][p][0] + 0.5).asscalar())
py = int(mx.nd.floor(coords[n][p][1] + 0.5).asscalar())
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
diff = mx.nd.concat(hm[py][px+1] - hm[py][px-1],
hm[py+1][px] - hm[py-1][px],
dim=0)
coords[n][p] += mx.nd.sign(diff) * .25

preds = mx.nd.zeros_like(coords)

# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(coords[i], center_[i], scale_[i],
[heatmap_width, heatmap_height])

return preds, maxvals


class SimplePoseResNet(HybridBlock):
Expand Down Expand Up @@ -119,11 +239,21 @@ def hybrid_forward(self, F, x, center=None, scale=None):
x = self.final_layer(x)

if center is not None:
y, maxvals = get_final_preds(x.as_in_context(mx.cpu()), center.asnumpy(), scale.asnumpy())
batch_heatmaps = x.as_in_context(mx.cpu())
center_ = center.as_in_context(mx.cpu())
scale_ = scale.as_in_context(mx.cpu())
y, maxvals = _get_final_preds(
batch_heatmaps=batch_heatmaps,
center=center_,
scale=scale_)
return y, maxvals
else:
return x

@staticmethod
def calc_pose(batch_heatmaps, center, scale):
return _get_final_preds(batch_heatmaps, center, scale)


def get_simple_pose_resnet(base_name,
pretrained=False,
Expand Down
155 changes: 149 additions & 6 deletions gluon/gluoncv2/models/simplepose_coco.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,129 @@
'simplepose_resneta152b_coco']

import os
import numpy as np
import mxnet as mx
import cv2
from mxnet import cpu
from mxnet.gluon import nn, HybridBlock
from .common import get_activation_layer, BatchNormExtra, conv1x1
from .resnet import resnet18, resnet50b, resnet101b, resnet152b
from .resneta import resneta50b, resneta101b, resneta152b


def calc_keypoints_with_max_scores(heatmap):
width = heatmap.shape[3]

heatmap_vector = heatmap.reshape((0, 0, -3))

indices = heatmap_vector.argmax(axis=2, keepdims=True)
scores = heatmap_vector.max(axis=2, keepdims=True)

keypoints = indices.tile((1, 1, 2))

keypoints[:, :, 0] = keypoints[:, :, 0] % width
keypoints[:, :, 1] = mx.nd.floor((keypoints[:, :, 1]) / width)

pred_mask = mx.nd.tile(mx.nd.greater(scores, 0.0), (1, 1, 2))
pred_mask = pred_mask.astype(np.float32)

keypoints *= pred_mask
return keypoints, scores


def affine_transform(pt, t):
new_pt = np.array([pt[0], pt[1], 1.]).T
new_pt = np.dot(t, new_pt)
return new_pt[:2]


def get_3rd_point(a, b):
direct = a - b
return b + np.array([-direct[1], direct[0]], dtype=np.float32)


def get_dir(src_point, rot_rad):
sn, cs = np.sin(rot_rad), np.cos(rot_rad)

src_result = [0, 0]
src_result[0] = src_point[0] * cs - src_point[1] * sn
src_result[1] = src_point[0] * sn + src_point[1] * cs

return src_result


def get_affine_transform(center,
scale,
rot,
output_size,
shift=np.array([0, 0], dtype=np.float32),
inv=0):
if not isinstance(scale, np.ndarray) and not isinstance(scale, list):
scale = np.array([scale, scale])

scale_tmp = scale
src_w = scale_tmp[0]
dst_w = output_size[0]
dst_h = output_size[1]

rot_rad = np.pi * rot / 180
src_dir = get_dir([0, src_w * -0.5], rot_rad)
dst_dir = np.array([0, dst_w * -0.5], np.float32)

src = np.zeros((3, 2), dtype=np.float32)
dst = np.zeros((3, 2), dtype=np.float32)
src[0, :] = center + scale_tmp * shift
src[1, :] = center + src_dir + scale_tmp * shift
dst[0, :] = [dst_w * 0.5, dst_h * 0.5]
dst[1, :] = np.array([dst_w * 0.5, dst_h * 0.5]) + dst_dir

src[2:, :] = get_3rd_point(src[0, :], src[1, :])
dst[2:, :] = get_3rd_point(dst[0, :], dst[1, :])

if inv:
trans = cv2.getAffineTransform(np.float32(dst), np.float32(src))
else:
trans = cv2.getAffineTransform(np.float32(src), np.float32(dst))

return trans


def transform_preds(coords, center, scale, output_size):
target_coords = mx.nd.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
target_coords[p, 0:2] = affine_transform(coords[p, 0:2].asnumpy(), trans)
return target_coords


def _calc_pose(heatmap, center, scale):
center_ = center.asnumpy()
scale_ = scale.asnumpy()

keypoints, scores = calc_keypoints_with_max_scores(heatmap)

heatmap_height = heatmap.shape[2]
heatmap_width = heatmap.shape[3]

# post-processing
for n in range(keypoints.shape[0]):
for p in range(keypoints.shape[1]):
hm = heatmap[n][p]
px = int(mx.nd.floor(keypoints[n][p][0] + 0.5).asscalar())
py = int(mx.nd.floor(keypoints[n][p][1] + 0.5).asscalar())
if 1 < px < heatmap_width - 1 and 1 < py < heatmap_height - 1:
diff = mx.nd.concat(hm[py][px + 1] - hm[py][px - 1], hm[py + 1][px] - hm[py - 1][px], dim=0)
keypoints[n][p] += mx.nd.sign(diff) * .25

preds = mx.nd.zeros_like(keypoints)

# Transform back
for i in range(keypoints.shape[0]):
preds[i] = transform_preds(keypoints[i], center_[i], scale_[i], [heatmap_width, heatmap_height])

return preds, scores


class DeconvBlock(HybridBlock):
"""
Deconvolution block with batch normalization and activation.
Expand Down Expand Up @@ -114,6 +230,8 @@ class SimplePose(HybridBlock):
bn_use_global_stats : bool, default False
Whether global moving statistics is used instead of local batch-norm for BatchNorm layers.
Useful for fine-tuning.
return_heatmap_only : bool, default False
Whether to return only heatmap.
bn_cudnn_off : bool, default True
Whether to disable CUDNN batch normalization operator.
in_channels : int, default 3
Expand All @@ -129,6 +247,7 @@ def __init__(self,
channels,
bn_use_global_stats=False,
bn_cudnn_off=True,
return_heatmap_only=False,
in_channels=3,
in_size=(256, 192),
keypoints=17,
Expand All @@ -137,6 +256,8 @@ def __init__(self,
assert (in_channels == 3)
self.in_size = in_size
self.keypoints = keypoints
self.return_heatmap_only = return_heatmap_only
self.out_size = (in_size[0] // 4, in_size[1] // 4)

with self.name_scope():
self.backbone = backbone
Expand All @@ -162,8 +283,24 @@ def __init__(self,
def hybrid_forward(self, F, x):
x = self.backbone(x)
x = self.decoder(x)
x = self.final_block(x)
return x
heatmap = self.final_block(x)
if self.return_heatmap_only:
return heatmap

return heatmap
# heatmap_vector = heatmap.reshape((0, 0, -3))
# indices = heatmap_vector.argmax(axis=2, keepdims=True)
# scores = heatmap_vector.max(axis=2, keepdims=True)
# keys_x = indices % self.out_size[1]
# keys_y = (indices / self.out_size[1]).floor()
# keypoints = F.concat(keys_x, keys_y, dim=2)
# keypoints = F.broadcast_mul(keypoints, scores.clip(0.0, 1.0e5))
#
# return heatmap, keypoints, scores

@staticmethod
def calc_pose(heatmap, center, scale):
return _calc_pose(heatmap, center, scale)


def get_simplepose(backbone,
Expand Down Expand Up @@ -421,7 +558,7 @@ def _test():
if not pretrained:
net.initialize(ctx=ctx)

# net.hybridize()
net.hybridize()
net_params = net.collect_params()
weight_count = 0
for param in net_params.values():
Expand All @@ -437,11 +574,17 @@ def _test():
assert (model != simplepose_resneta101b_coco or weight_count == 53011057)
assert (model != simplepose_resneta152b_coco or weight_count == 68654705)

x = mx.nd.zeros((1, 3, 256, 192), ctx=ctx)
y = net(x)
assert ((y.shape[0] == x.shape[0]) and (y.shape[1] == keypoints) and (y.shape[2] == x.shape[2] // 4) and
batch = 14
x = mx.nd.zeros((batch, 3, 256, 192), ctx=ctx)
y, _, _ = net(x)
assert ((y.shape[0] == batch) and (y.shape[1] == keypoints) and (y.shape[2] == x.shape[2] // 4) and
(y.shape[3] == x.shape[3] // 4))

center = mx.nd.zeros((batch, 2), ctx=ctx)
scale = mx.nd.ones((batch, 2), ctx=ctx)
z = net.calc_pose(y, center, scale)
assert (z.shape[0] == batch)


if __name__ == "__main__":
_test()
Loading

0 comments on commit 1fffda8

Please sign in to comment.