From a4ee11ab26203ec8f816335725f5d2576540140b Mon Sep 17 00:00:00 2001 From: Shichao Li Date: Wed, 23 Feb 2022 16:33:11 +0800 Subject: [PATCH] update libs --- libs/trainer/trainer.py | 10 ++++++++-- libs/visualization/debug.py | 22 +++++++++++----------- tools/train_IGRs.py | 2 +- 3 files changed, 20 insertions(+), 14 deletions(-) diff --git a/libs/trainer/trainer.py b/libs/trainer/trainer.py index 798f333..a5edc33 100644 --- a/libs/trainer/trainer.py +++ b/libs/trainer/trainer.py @@ -14,10 +14,9 @@ from libs.visualization.debug import save_debug_images from libs.common.utils import AverageMeter from libs.metric.criterions import Evaluator +from libs.logger.logger import get_dirs -import torch.nn.functional as F import torch -import cv2 import numpy as np import time import os @@ -253,6 +252,13 @@ def train(train_dataset, ) # back to training mode model.train() + # save a snapshot + if epoch in cfgs['training_settings'].get('snapshot_epochs', []): + output_dir, _ = get_dirs(cfgs) + prefix = cfgs['exp_type'] + model_state_file = os.path.join(output_dir, prefix + '_{:d}.pth'.format(epoch)) + logger.info('=> Snapshot model to {}'.format(model_state_file)) + torch.save(model.module.state_dict(), model_state_file) logger.info('Training finished.') return {'model':model, 'batch_idx':x_buffer, 'loss':y_buffer} diff --git a/libs/visualization/debug.py b/libs/visualization/debug.py index a62ce7b..3f91df4 100644 --- a/libs/visualization/debug.py +++ b/libs/visualization/debug.py @@ -24,7 +24,8 @@ def draw_circles(ndarr, width, height, padding, - color=[255,0,0] + color=[255,0,0], + add_idx=True ): k = 0 for y in range(ymaps): @@ -32,18 +33,18 @@ def draw_circles(ndarr, if k >= nmaps: break joints = batch_joints[k] - # joints_vis = batch_joints_vis[k] - - # for joint, joint_vis in zip(joints, joints_vis): - # joint[0] = x * width + padding + joint[0] - # joint[1] = y * height + padding + joint[1] - # if joint_vis.item(): - # cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, color, 2) - for joint in joints: + for idx, joint in enumerate(joints): joint[0] = x * width + padding + joint[0] joint[1] = y * height + padding + joint[1] cv2.circle(ndarr, (int(joint[0]), int(joint[1])), 2, color, 2) - k = k + 1 + if add_idx: + cv2.putText(ndarr, + str(idx+1), + (int(joint[0]), int(joint[1])), + cv2.FONT_HERSHEY_SIMPLEX, + 1, color, 1 + ) + k += 1 return ndarr # functions used for debugging heatmap-based keypoint localization model # @@ -57,7 +58,6 @@ def save_batch_image_with_joints(batch_image, batch_image: [batch_size, channel, height, width] batch_joints: [batch_size, num_joints, 3], batch_joints_vis: [batch_size, num_joints, 1], - } """ grid = torchvision.utils.make_grid(batch_image[:, :3, :, :], nrow, padding, True) ndarr = grid.mul(255).clamp(0, 255).byte().permute(1, 2, 0).cpu().numpy() diff --git a/tools/train_IGRs.py b/tools/train_IGRs.py index c52d5e6..9706618 100644 --- a/tools/train_IGRs.py +++ b/tools/train_IGRs.py @@ -102,7 +102,7 @@ def train(model, model_settings, GPUs, cfgs, logger, final_output_dir): final_model_state_file = os.path.join(final_output_dir, 'HC.pth') logger.info('=> saving final model state to {}'.format(final_model_state_file)) - torch.save(model.module.cpu().state_dict(), final_model_state_file) + torch.save(model.module.state_dict(), final_model_state_file) return def evaluate(model, model_settings, GPUs, cfgs, logger, final_output_dir, eval_train=False):