Commit: code cleaning
Nicholasli1995 committed Jul 4, 2021
1 parent aeb6a60 commit 604280d
Showing 4 changed files with 124 additions and 91 deletions.
2 changes: 1 addition & 1 deletion libs/arguments/parse.py
@@ -44,4 +44,4 @@ def parse_args():
configs['config_path'] = args.cfg
configs['visualize'] = args.visualize
configs['batch_to_show'] = args.batch_to_show
return configs
return configs
205 changes: 120 additions & 85 deletions libs/common/img_proc.py
@@ -9,12 +9,14 @@
import numpy as np
import torch
import torch.nn.functional as F
import math
import os

SIZE = 200.0

def transform_preds(coords, center, scale, output_size):
"""
Transform local coordinates within a patch to screen coordinates.
"""
target_coords = np.zeros(coords.shape)
trans = get_affine_transform(center, scale, 0, output_size, inv=1)
for p in range(coords.shape[0]):
@@ -28,7 +30,10 @@ def get_affine_transform(center,
shift=np.array([0, 0], dtype=np.float32),
inv=0
):
# TODO: speed up this function
"""
Estimate an affine transformation given crop parameters (center, scale and
rotation) and output resolution.
"""
if isinstance(scale, list):
scale = np.array(scale)
if isinstance(center, list):
@@ -65,6 +70,9 @@ def affine_transform(pt, t):
return new_pt[:2]
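
# A minimal round-trip sketch, assuming the SimpleBaseline-style center/scale
# convention suggested above (the numeric values are illustrative, not from
# this diff):
_center = np.array([320.0, 240.0])   # hypothetical crop center on the screen
_scale = np.array([1.0, 1.0])        # one scale unit covers SIZE = 200 pixels
_t_inv = get_affine_transform(_center, _scale, 0, (256, 256), inv=1)
print(affine_transform(np.array([128.0, 128.0]), _t_inv))
# ~[320., 240.]: the patch center maps back to the crop center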

def affine_transform_modified(pts, t):
"""
Apply affine transformation with homogeneous coordinates.
"""
# pts of shape [n, 2]
new_pts = np.hstack([pts, np.ones((len(pts), 1))]).T
new_pts = t @ new_pts
@@ -84,6 +92,9 @@ def get_dir(src_point, rot_rad):
return src_result

def crop(img, center, scale, output_size, rot=0):
"""
A cropping function implemented as warping.
"""
trans = get_affine_transform(center, scale, rot, output_size)

dst_img = cv2.warpAffine(img,
@@ -95,9 +106,9 @@ def crop(img, center, scale, output_size, rot=0):
return dst_img
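
# Illustrative call on a synthetic image (the shape and crop values are
# assumptions):
_img = np.zeros((480, 640, 3), dtype=np.uint8)
_patch = crop(_img, np.array([320.0, 240.0]), np.array([1.0, 1.0]), (256, 256))
print(_patch.shape)  # (256, 256, 3)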

def simple_crop(input_image, center, crop_size):
'''
Simple cropping without warping
'''
"""
A simple cropping function without warping.
"""
assert len(input_image.shape) == 3, 'Unsupported image format.'
channel = input_image.shape[2]
# crop a rectangular region around the center in the image
@@ -125,13 +136,19 @@ def simple_crop(input_image, center, crop_size):
return cropped

def np_random():
# return a randomly number sampled uniformly from [-1, 1]
"""
Return a random number sampled uniformly from [-1, 1].
"""
return np.random.rand()*2 - 1

def jitter_bbox_with_kpts(old_bbox, joints, parameters):
# bbox: [x1, y1, x2, y2]
# joints: [N, 3]
# randomly shifting and resizeing a bounding box and mask out occluded joints
"""
Randomly shift and resize a bounding box, and mask out occluded joints.
Used as data augmentation to improve robustness to detector noise.
bbox: [x1, y1, x2, y2]
joints: [N, 3]
"""
new_joints = joints.copy()
width, height = old_bbox[2] - old_bbox[0], old_bbox[3] - old_bbox[1]
old_center = [0.5*(old_bbox[0] + old_bbox[2]),
@@ -156,7 +173,9 @@ def jitter_bbox_with_kpts(old_bbox, joints, parameters):
return new_bbox, new_joints

def jitter_bbox_with_kpts_no_occlu(old_bbox, joints, parameters):
# similar to the function above, but does not produce occluded joints
"""
Similar to the function above, but does not produce occluded joints.
"""
width, height = old_bbox[2] - old_bbox[0], old_bbox[3] - old_bbox[1]
old_center = [0.5*(old_bbox[0] + old_bbox[2]),
0.5*(old_bbox[1] + old_bbox[3])]
@@ -173,10 +192,14 @@ def jitter_bbox_with_kpts_no_occlu(old_bbox, joints, parameters):
return new_bbox, joints

def generate_xy_map(bbox, resolution, global_size):
# generate the normalized coordinates as 2D maps
# bbox: [x1, y1, x2, y2] the local region
# resolution (height, width): target resolution
# global_size (height, width): the size of original image
"""
Generate the normalized coordinates as 2D maps that encode location
information.
bbox: [x1, y1, x2, y2] the local region
resolution (height, width): target resolution
global_size (height, width): the size of original image
"""
map_height, map_width = resolution
g_height, g_width = global_size
x_start, x_end = 2*bbox[0]/g_width - 1, 2*bbox[2]/g_width - 1
@@ -189,6 +212,9 @@ def generate_xy_map(bbox, resolution, global_size):
return np.concatenate([x_map, y_map], axis=2)
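
# Illustrative call: a 64x64 coordinate map for a local box inside a 480x640
# image (the box and sizes are assumptions):
_xy = generate_xy_map([100, 80, 300, 240], (64, 64), (480, 640))
print(_xy.shape)  # (64, 64, 2); entries lie in [-1, 1] and encode global position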

def crop_single_instance(data_numpy, bbox, joints, parameters, pth_trans=None):
"""
Crop an instance from an image given the bounding box and part coordinates.
"""
reso = parameters['input_size']
transformed_joints = joints.copy()
if parameters['jitter_bbox']:
@@ -237,8 +263,9 @@ def get_tensor_from_img(path,
max_cnt=None
):
"""
read image and apply data augmentation to obtain a tensor.
Read image and apply data augmentation to obtain a tensor.
Keypoints are also transformed if given.
path: image path
c: cropping center
s: cropping scale
@@ -260,7 +287,6 @@
raise ValueError('Fail to read {}'.format(path))
if rgb:
data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
# TODO: clean here to the list(zip(*annot_list)) format
all_inputs = []
all_target = []
all_centers = []
@@ -320,11 +346,14 @@ def get_tensor_from_img(path,
return inputs, targets, target_weights, meta

def generate_target(joints, joints_vis, parameters):
'''
:param joints: [num_joints, 3]
:param joints_vis: [num_joints]
:return: target, target_weight(1: visible, 0: invisible)
'''
"""
Generate heatmap targets by drawing Gaussian dots.
joints: [num_joints, 3]
joints_vis: [num_joints]
return: target, target_weight (1: visible, 0: invisible)
"""
num_joints = parameters['num_joints']
target_type = parameters['target_type']
input_size = parameters['input_size']
@@ -381,7 +410,9 @@ def generate_target(joints, joints_vis, parameters):
return target, target_weight
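
# A standalone sketch of the Gaussian-dot drawing described above; the heatmap
# size and sigma here are assumptions, not the repo's exact settings:
def _gaussian_dot(xy, heatmap_size=(64, 64), sigma=2.0):
    # unnormalized Gaussian centered at xy = (x, y) on an (H, W) grid
    h, w = heatmap_size
    xs = np.arange(w)[None, :]
    ys = np.arange(h)[:, None]
    return np.exp(-((xs - xy[0])**2 + (ys - xy[1])**2) / (2 * sigma**2))

print(_gaussian_dot((20.0, 32.0)).shape)  # (64, 64), peak of 1.0 at the joint pixel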

def resize_bbox(left, top, right, bottom, target_ar=1.):
# resize bounding box to pre-defined aspect ratio
"""
Resize a bounding box to pre-defined aspect ratio.
"""
width = right - left
height = bottom - top
aspect_ratio = height/width
@@ -405,56 +436,74 @@ def resize_bbox(left, top, right, bottom, target_ar=1.):
}

def enlarge_bbox(left, top, right, bottom, enlarge):
"""
Enlarge a bounding box.
"""
width = right - left
height = bottom - top
new_width = width*enlarge[0]
new_height = height*enlarge[1]
center_x = (left + right)/2
center_y = (top + bottom)/2
new_left = center_x - 0.5*new_width
new_right = center_x + 0.5*new_width
new_top = center_y - 0.5*new_height
new_bottom = center_y + 0.5*new_height
new_width = width * enlarge[0]
new_height = height * enlarge[1]
center_x = (left + right) / 2
center_y = (top + bottom) / 2
new_left = center_x - 0.5 * new_width
new_right = center_x + 0.5 * new_width
new_top = center_y - 0.5 * new_height
new_bottom = center_y + 0.5 * new_height
return [new_left, new_top, new_right, new_bottom]
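
# Illustrative call: enlarging by 10% per axis keeps the center fixed
# (the box values are assumptions):
print(enlarge_bbox(100, 80, 300, 240, [1.1, 1.1]))  # [90.0, 72.0, 310.0, 248.0]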

def modify_bbox(bbox, target_ar, enlarge=1.1):
"""
Enlarge a bounding box so that occluded parts may be included.
Modify a bounding box by enlarging/resizing.
"""
lbbox = enlarge_bbox(bbox[0], bbox[1], bbox[2], bbox[3], [enlarge, enlarge])
ret = resize_bbox(lbbox[0], lbbox[1], lbbox[2], lbbox[3], target_ar=target_ar)
return ret

def resize_crop(crop_size, target_ar=None):
"""
Resize a crop size to a pre-defined aspect ratio.
"""
if target_ar is None:
return crop_size
width = crop_size[0]
height = crop_size[1]
aspect_ratio = height/width
aspect_ratio = height / width
if aspect_ratio > target_ar:
new_width = height*(1/target_ar)
new_width = height * (1 / target_ar)
new_height = height
else:
new_height = width*target_ar
new_width = width
return [new_width, new_height]
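
# Illustrative call: pad a 100x300 (width x height) crop to height/width == 2
# (the sizes are assumptions):
print(resize_crop((100, 300), target_ar=2.0))  # [150.0, 300]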

def bbox2cs(bbox):
"""
Convert bounding box annotation to center and scale.
"""
# bbox: [x1, y1, x2, y2]; center is the midpoint, scale is size in units of SIZE
return [(bbox[0] + bbox[2]) / 2, (bbox[1] + bbox[3]) / 2], \
[(bbox[2] - bbox[0]) / SIZE, (bbox[3] - bbox[1]) / SIZE]

def cs2bbox(center, size):
"""
Convert center/scale to a bounding box annotation.
"""
x1 = center[0] - size[0]
y1 = center[1] - size[1]
x2 = center[0] + size[0]
y2 = center[1] + size[1]
return [x1, y1, x2, y2]
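
# Illustrative call; note that `size` acts as a half-extent around the center:
print(cs2bbox([200.0, 160.0], [110.0, 88.0]))  # [90.0, 72.0, 310.0, 248.0]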

def kpts2cs(keypoints, enlarge=1.1, method='boundary', target_ar=None, use_visibility=True):
'''
convert instance keypoint locations to cropping center and size
keypoints of shape (n_joints, 2 or 3)
'''
def kpts2cs(keypoints,
enlarge=1.1,
method='boundary',
target_ar=None,
use_visibility=True
):
"""
Convert instance keypoint screen coordinates to cropping center and size.
keypoints: [n_joints, 2 or 3]
"""
assert keypoints.shape[1] in [2, 3], 'Unsupported input.'
if keypoints.shape[1] == 2:
visible_keypoints = keypoints
@@ -492,6 +541,9 @@ def kpts2cs(keypoints, enlarge=1.1, method='boundary', target_ar=None, use_visib
return center, crop_size, new_keypoints, vis_rate

def draw_bboxes(img_path, bboxes_dict, save_path=None):
"""
Draw bounding boxes with OpenCV.
"""
data_numpy = cv2.imread(img_path, 1 | 128)
for name, (color, bboxes) in bboxes_dict.items():
for bbox in bboxes:
@@ -503,15 +555,23 @@ def draw_bboxes(img_path, bboxes_dict, save_path=None):
return data_numpy

def imread_rgb(img_path):
"""
Read image with OpenCV.
"""
data_numpy = cv2.imread(img_path, 1 | 128)
data_numpy = cv2.cvtColor(data_numpy, cv2.COLOR_BGR2RGB)
return data_numpy

def save_cropped_patches(img_path, keypoints, save_dir="./", threshold=0.25,
enlarge=1.4, target_ar=None):
'''
crop instances from a image given key-points and save them.
'''
def save_cropped_patches(img_path,
keypoints,
save_dir="./",
threshold=0.25,
enlarge=1.4,
target_ar=None
):
"""
Crop instances from an image given part screen coordinates and save them.
"""
# data_numpy = cv2.imread(
# img_path, cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION
# )
@@ -547,10 +607,11 @@ def save_cropped_patches(img_path, keypoints, save_dir="./", threshold=0.25,
return new_paths, np.concatenate(all_new_keypoints, axis=0), all_bbox

def get_max_preds(batch_heatmaps):
'''
get predictions from score maps
heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
'''
"""
Get predictions from heatmaps with hard arg-max.
batch_heatmaps: numpy.ndarray([batch_size, num_joints, height, width])
"""
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
@@ -576,41 +637,10 @@ def get_max_preds(batch_heatmaps):
preds *= pred_mask
return preds, maxvals
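
# Illustrative call with random scores; the output shapes follow the usual
# SimpleBaseline implementation (an assumption here, since part of the body
# is collapsed):
_hm = np.random.rand(2, 17, 64, 48).astype(np.float32)
_coords, _maxvals = get_max_preds(_hm)
print(_coords.shape, _maxvals.shape)  # (2, 17, 2) (2, 17, 1)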


def get_final_preds(config, batch_heatmaps, center, scale):
coords, maxvals = get_max_preds(batch_heatmaps)

heatmap_height = batch_heatmaps.shape[2]
heatmap_width = batch_heatmaps.shape[3]

# post-processing
if config.TEST.POST_PROCESS:
for n in range(coords.shape[0]):
for p in range(coords.shape[1]):
hm = batch_heatmaps[n][p]
px = int(math.floor(coords[n][p][0] + 0.5))
py = int(math.floor(coords[n][p][1] + 0.5))
if 1 < px < heatmap_width-1 and 1 < py < heatmap_height-1:
diff = np.array(
[
hm[py][px+1] - hm[py][px-1],
hm[py+1][px]-hm[py-1][px]
]
)
coords[n][p] += np.sign(diff) * .25

preds = coords.copy()

# Transform back
for i in range(coords.shape[0]):
preds[i] = transform_preds(
coords[i], center[i], scale[i], [heatmap_width, heatmap_height]
)

return preds, maxvals

# soft-argmax instead of hard-argmax considering quantization errors
def soft_arg_max_np(batch_heatmaps):
"""
Soft-argmax instead of hard-argmax considering quantization errors.
"""
assert isinstance(batch_heatmaps, np.ndarray), \
'batch_heatmaps should be numpy.ndarray'
assert batch_heatmaps.ndim == 4, 'batch_images should be 4-ndim'
@@ -647,7 +677,9 @@ def soft_arg_max_np(batch_heatmaps):
return preds, maxvals

def soft_arg_max(batch_heatmaps):
# a pytorch version of soft-argmax
"""
A PyTorch version of soft-argmax.
"""
assert len(batch_heatmaps.shape) == 4, 'batch_images should be 4-ndim'
batch_size = batch_heatmaps.shape[0]
num_joints = batch_heatmaps.shape[1]
@@ -676,9 +708,12 @@ def soft_arg_max(batch_heatmaps):
return preds, maxvals
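
# A compact standalone soft-argmax for reference: take the softmax expectation
# of the grid coordinates. A sketch, not the repo's exact implementation:
def _soft_argmax_2d(heatmaps):
    # heatmaps: [N, J, H, W] -> expected (x, y) per joint, shape [N, J, 2]
    n, j, h, w = heatmaps.shape
    probs = F.softmax(heatmaps.reshape(n, j, -1), dim=2).reshape(n, j, h, w)
    xs = probs.sum(dim=2) @ torch.arange(w, dtype=probs.dtype)  # x marginal
    ys = probs.sum(dim=3) @ torch.arange(h, dtype=probs.dtype)  # y marginal
    return torch.stack([xs, ys], dim=2)

print(_soft_argmax_2d(torch.randn(2, 17, 64, 48)).shape)  # torch.Size([2, 17, 2])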

def appro_cr(coordinates):
# approximate the square of cross-ratio along four ordered 2D points using
# inner-product
# coordinates: PyTorch tensor of shape [4, 2]
"""
Approximate the square of the cross-ratio of four ordered 2D points using
inner products.
coordinates: PyTorch tensor of shape [4, 2]
"""
AC = coordinates[2] - coordinates[0]
BD = coordinates[3] - coordinates[1]
BC = coordinates[2] - coordinates[1]
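
# For reference, the textbook cross-ratio of ordered collinear points A, B, C, D
# is (AC/BC) / (AD/BD); a quick numeric check with evenly spaced points, using
# the classical definition (appro_cr approximates the square of this quantity):
_A, _B, _C, _D = (np.array([t, 0.0]) for t in (0.0, 1.0, 2.0, 3.0))
_cr = (np.linalg.norm(_C - _A) / np.linalg.norm(_C - _B)) \
    / (np.linalg.norm(_D - _A) / np.linalg.norm(_D - _B))
print(_cr)  # 4/3 for evenly spaced points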
