Skip to content

Commit

Permalink
Update test_Fusion.py
Browse files Browse the repository at this point in the history
  • Loading branch information
Linfeng-Tang authored Jun 1, 2023
1 parent 96d1e70 commit ecdc3dd
Showing 1 changed file with 1 addition and 14 deletions.
15 changes: 1 addition & 14 deletions test_Fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,17 +31,10 @@ def main():

def multi_task_tester(test_loader, multi_task_model, device, opts):
multi_task_model.eval()
is_rgb = False ## 用来标记重建的可见光图像是彩色图像还是灰度图像。
test_bar= tqdm(test_loader)
seg_metric = SegmentationMetric(opts.class_nb, device=device)
lb_ignore = [255]
## define save dir
Fusion_save_dir = os.path.join('./Fusion_results', opts.dataname, opts.name)
# Fusion_save_dir = os.path.join('./Biseg/PSFusion_5180/', 'train')
# Fusion_save_dir = os.path.join('/data/timer/Segmentation/SegNext/datasets/MSRS/PSFusion')
Fusion_save_dir = os.path.join(opts.result_dir, opts.dataname)
os.makedirs(Fusion_save_dir, exist_ok=True)
# Seg_save_dir = os.path.join(save_root, 'Segmentation', opts.dataname)
# os.makedirs(Seg_save_dir, exist_ok=True)
with torch.no_grad(): # operations inside don't track history
for it, (img_ir, img_vi, img_names, widths, heights) in enumerate(test_bar):
img_ir = img_ir.to(device)
Expand All @@ -50,17 +43,11 @@ def multi_task_tester(test_loader, multi_task_model, device, opts):
vi_Y = vi_Y.to(device)
vi_Cb = vi_Cb.to(device)
vi_Cr = vi_Cr.to(device)
if it == 0:
flops, params = profile(multi_task_model,inputs=(img_vi, img_ir))
print('flops: {:.2f} G | params: {:.2f} M'.format(flops / (1024* 1024 * 1024), params / (1024* 1024)))
Seg_pred, _, _, fused_img, _, _ = multi_task_model(img_vi, img_ir)
# seg_result = torch.argmax(Seg_pred, dim=1, keepdim=True)
fused_img = YCbCr2RGB(fused_img, vi_Cb, vi_Cr)
for i in range(len(img_names)):
img_name = img_names[i]
# seg_save_name = os.path.join(Seg_save_dir, img_name)
fusion_save_name = os.path.join(Fusion_save_dir, img_name)
# seg_visualize(seg_result[i, ::].unsqueeze(0).squeeze(dim=1), seg_save_name)
save_img_single(fused_img[i, ::], fusion_save_name, widths[i], heights[i])
test_bar.set_description('Image: {} '.format(img_name))

Expand Down

0 comments on commit ecdc3dd

Please sign in to comment.