Skip to content

Commit fd9125c

Browse files
committed
Reshape displacements and offsets array for cleaner access, few lines less code, sligh perf increase
1 parent 8ede8bf commit fd9125c

File tree

2 files changed

+16
-28
lines changed

2 files changed

+16
-28
lines changed

posenet/decode.py

+7-26
Original file line numberDiff line numberDiff line change
@@ -3,47 +3,31 @@
33
from posenet.constants import *
44

55

6-
def get_offset_point(coord, keypoint_id, offsets):
7-
return np.array((
8-
offsets[coord[0], coord[1], keypoint_id],
9-
offsets[coord[0], coord[1], keypoint_id + NUM_KEYPOINTS])).astype(np.int32)
10-
11-
12-
def get_image_coords(heatmap_coord, keypoint_id, output_stride, offsets):
13-
return heatmap_coord * output_stride + get_offset_point(heatmap_coord, keypoint_id, offsets)
14-
15-
166
def traverse_to_targ_keypoint(
177
edge_id, source_keypoint, target_keypoint_id, scores, offsets, output_stride, displacements
188
):
199
height = scores.shape[0]
2010
width = scores.shape[1]
21-
num_edges = displacements.shape[2] // 2
2211

2312
source_keypoint_indices = np.clip(
2413
np.round(source_keypoint / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32)
2514

26-
displacement = np.array((
27-
displacements[source_keypoint_indices[0], source_keypoint_indices[1], edge_id],
28-
displacements[source_keypoint_indices[0], source_keypoint_indices[1], edge_id + num_edges]
29-
))
30-
31-
displaced_point = source_keypoint + displacement
15+
displaced_point = source_keypoint + displacements[
16+
source_keypoint_indices[0], source_keypoint_indices[1], edge_id]
3217

3318
displaced_point_indices = np.clip(
3419
np.round(displaced_point / output_stride), a_min=0, a_max=[height - 1, width - 1]).astype(np.int32)
3520

36-
offset_point = get_offset_point(displaced_point_indices, target_keypoint_id, offsets)
37-
3821
score = scores[displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id]
3922

40-
position = displaced_point_indices * output_stride + offset_point
23+
image_coord = displaced_point_indices * output_stride + offsets[
24+
displaced_point_indices[0], displaced_point_indices[1], target_keypoint_id]
4125

42-
return score, position
26+
return score, image_coord
4327

4428

4529
def decode_pose(
46-
root_score, root_id, root_coord,
30+
root_score, root_id, root_image_coord,
4731
scores,
4832
offsets,
4933
output_stride,
@@ -55,11 +39,8 @@ def decode_pose(
5539

5640
instance_keypoint_scores = np.zeros(num_parts)
5741
instance_keypoint_coords = np.zeros((num_parts, 2))
58-
59-
root_point = get_image_coords(root_coord, root_id, output_stride, offsets)
60-
6142
instance_keypoint_scores[root_id] = root_score
62-
instance_keypoint_coords[root_id] = root_point
43+
instance_keypoint_coords[root_id] = root_image_coord
6344

6445
# FIXME can we vectorize these loops cleanly?
6546
for edge in reversed(range(num_edges)):

posenet/decode_multi.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -126,15 +126,22 @@ def decode_multiple_poses(
126126
scored_parts = build_part_with_score_fast(score_threshold, LOCAL_MAXIMUM_RADIUS, scores)
127127
scored_parts = sorted(scored_parts, key=lambda x: x[0], reverse=True)
128128

129+
# change dimensions from (h, w, x) to (h, w, x//2, 2) to allow return of complete coord array
130+
height = scores.shape[0]
131+
width = scores.shape[1]
132+
offsets = offsets.reshape(height, width, 2, -1).swapaxes(2, 3)
133+
displacements_fwd = displacements_fwd.reshape(height, width, 2, -1).swapaxes(2, 3)
134+
displacements_bwd = displacements_bwd.reshape(height, width, 2, -1).swapaxes(2, 3)
135+
129136
for root_score, root_id, root_coord in scored_parts:
130-
root_image_coords = get_image_coords(root_coord, root_id, output_stride, offsets)
137+
root_image_coords = root_coord * output_stride + offsets[root_coord[0], root_coord[1], root_id]
131138

132139
if within_nms_radius_fast(
133140
pose_keypoint_coords[:pose_count, root_id, :], squared_nms_radius, root_image_coords):
134141
continue
135142

136143
keypoint_scores, keypoint_coords = decode_pose(
137-
root_score, root_id, root_coord,
144+
root_score, root_id, root_image_coords,
138145
scores, offsets, output_stride,
139146
displacements_fwd, displacements_bwd)
140147

0 commit comments

Comments
 (0)