Skip to content

Commit

Permalink
update libs
Browse files Browse the repository at this point in the history
  • Loading branch information
Nicholasli1995 committed Feb 23, 2022
1 parent fcc83ea commit a4ee11a
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 14 deletions.
10 changes: 8 additions & 2 deletions libs/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
from libs.visualization.debug import save_debug_images
from libs.common.utils import AverageMeter
from libs.metric.criterions import Evaluator
from libs.logger.logger import get_dirs

import torch.nn.functional as F
import torch
import cv2
import numpy as np
import time
import os
Expand Down Expand Up @@ -253,6 +252,13 @@ def train(train_dataset,
)
# back to training mode
model.train()
# save a snapshot
if epoch in cfgs['training_settings'].get('snapshot_epochs', []):
output_dir, _ = get_dirs(cfgs)
prefix = cfgs['exp_type']
model_state_file = os.path.join(output_dir, prefix + '_{:d}.pth'.format(epoch))
logger.info('=> Snapshot model to {}'.format(model_state_file))
torch.save(model.module.state_dict(), model_state_file)
logger.info('Training finished.')
return {'model':model, 'batch_idx':x_buffer, 'loss':y_buffer}

Expand Down
22 changes: 11 additions & 11 deletions libs/visualization/debug.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,26 +24,27 @@ def draw_circles(ndarr,
width,
height,
padding,
color=[255,0,0]
color=[255,0,0],
add_idx=True
):
k = 0
for y in range(ymaps):
for x in range(xmaps):
if k >= nmaps:
break
joints = batch_joints[k]
# joints_vis = batch_joints_vis[k]

# for joint, joint_vis in zip(joints, joints_vis):
# joint[0] = x * width + padding + joint[0]
# joint[1] = y * height + padding + joint[1]
# if joint_vis.item():
# cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, color, 2)
for joint in joints:
for idx, joint in enumerate(joints):
joint[0] = x * width + padding + joint[0]
joint[1] = y * height + padding + joint[1]
cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, color, 2)
k = k + 1
if add_idx:
cv2.putText(ndarr,
str(idx+1),
(int(joint[0]), int(joint[1])),
cv2.FONT_HERSHEY_SIMPLEX,
1, color, 1
)
k += 1
return ndarr

# functions used for debugging heatmap-based keypoint localization model #
Expand All @@ -57,7 +58,6 @@ def save_batch_image_with_joints(batch_image,
batch_image: [batch_size, channel, height, width]
batch_joints: [batch_size, num_joints, 3],
batch_joints_vis: [batch_size, num_joints, 1],
}
"""
grid = torchvision.utils.make_grid(batch_image[:, :3, :, :], nrow, padding, True)
ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy()
Expand Down
2 changes: 1 addition & 1 deletion tools/train_IGRs.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def train(model, model_settings, GPUs, cfgs, logger, final_output_dir):

final_model_state_file = os.path.join(final_output_dir, 'HC.pth')
logger.info('=> saving final model state to {}'.format(final_model_state_file))
torch.save(model.module.cpu().state_dict(), final_model_state_file)
torch.save(model.module.state_dict(), final_model_state_file)
return

def evaluate(model, model_settings, GPUs, cfgs, logger, final_output_dir, eval_train=False):
Expand Down

0 comments on commit a4ee11a

Please sign in to comment.