Skip to content

Commit

Permalink
update libs
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholasli1995 committed Mar 4, 2022
1 parent 2dcb28e commit 0a6191e
Show file tree
Hide file tree
Showing 4 changed files with 27 additions and 12 deletions.
2 changes: 1 addition & 1 deletion libs/common/img_proc.py
Original file line number Diff line number Diff line change
Expand Up @@ -199,7 +199,7 @@ def generate_xy_map(bbox, resolution, global_size):
resolution (height, width): target resolution
global_size (height, width): the size of original image
"""
map_height, map_width = resolution
map_width, map_height = resolution
g_height, g_width = global_size
x_start, x_end = 2*bbox[0]/g_width - 1, 2*bbox[2]/g_width - 1
y_start, y_end = 2*bbox[1]/g_height - 1, 2*bbox[3]/g_height - 1
Expand Down
12 changes: 7 additions & 5 deletions libs/model/egonet.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,12 @@ def crop_single_instance(self,
Crop a single instance given an image and bounding box.
"""
bbox = to_npy(bbox)
target_ar = resolution[0] / resolution[1]
width, height = resolution
target_ar = height / width
ret = modify_bbox(bbox, target_ar)
c, s, r = ret['c'], ret['s'], 0.
# xy_dict: parameters for adding xy coordinate maps
trans = get_affine_transform(c, s, r, resolution)
trans = get_affine_transform(c, s, r, (height, width))
instance = cv2.warpAffine(img,
trans,
(int(resolution[0]), int(resolution[1])),
Expand Down Expand Up @@ -114,7 +115,7 @@ def crop_instances(self,
all_instances = []
# each record stores attributes of one instance
all_records = []
target_ar = resolution[0] / resolution[1]
target_ar = resolution[1] / resolution[0]
for idx, path in enumerate(annot_dict['path']):
data_numpy = self.load_cv2(path)
boxes = annot_dict['boxes'][idx]
Expand Down Expand Up @@ -432,8 +433,9 @@ def get_keypoints(self,
instances = instances.cuda()
output = self.HC(instances)
# local part coordinates
width, height = self.resolution
local_coord = output[1].data.cpu().numpy()
local_coord *= self.resolution[0]
local_coord *= np.array(self.resolution).reshape(1, 1, 2)
# transform local part coordinates to screen coordinates
centers = [records[i]['center'] for i in range(len(records))]
scales = [records[i]['scale'] for i in range(len(records))]
Expand All @@ -442,7 +444,7 @@ def get_keypoints(self,
trans_inv = get_affine_transform(centers[instance_idx],
scales[instance_idx],
rots[instance_idx],
self.resolution,
(height, width),
inv=1
)
screen_coord = affine_transform_modified(local_coord[instance_idx],
Expand Down
23 changes: 18 additions & 5 deletions tools/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,14 @@ def filter_conf(record, thres=0.0):
}
return True, filterd_record

def gather_dict(request, references, filter_c=True, larger=True, thres=0.):
def gather_dict(request,
references,
filter_c=True,
larger=True,
thres=0.,
target_ar=1.,
enlarge=1.2
):
"""
Gather an annotation dictionary from the prepared detections as requested.
"""
Expand All @@ -104,8 +111,8 @@ def gather_dict(request, references, filter_c=True, larger=True, thres=0.):
# enlarge the input bounding box if needed
for instance_id in range(len(bbox)):
bbox[instance_id] = np.array(modify_bbox(bbox[instance_id],
target_ar=1,
enlarge=1.2
target_ar=target_ar,
enlarge=enlarge
)['bbox']
)
ret['boxes'].append(bbox)
Expand Down Expand Up @@ -158,8 +165,14 @@ def inference(testset, model, results, cfgs):
merge(all_records, record)
if cfgs['use_pred_box']:
# use detected bounding box from any 2D/3D detector
thres = cfgs['conf_thres'] if 'conf_thres' in cfgs else 0.
annot_dict = gather_dict(meta, results['pred'], thres=thres)
thres = cfgs.get('conf_thres', 0.)
width, height = cfgs['heatmapModel']['input_size']
enlarge = cfgs['dataset'].get('enlarge_factor', 1.2)
annot_dict = gather_dict(meta, results['pred'],
thres=thres,
target_ar=height/width,
enlarge=enlarge
)
if len(annot_dict['path']) != 0:
record2 = model(annot_dict)
# update drawings
Expand Down
2 changes: 1 addition & 1 deletion tools/train_lifting.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def main():
os.mkdir(save_path)
# save the model and the normalization statistics
torch.save(cascade[0].cpu().state_dict(),
os.path.join(save_path, 'L.th')
os.path.join(save_path, 'L.pth')
)
np.save(os.path.join(save_path, 'LS.npy'), train_dataset.statistics)
logger.info('=> saving final model state to {}'.format(save_path))
Expand Down

0 comments on commit 0a6191e

Please sign in to comment.