diff --git a/util/visualizer.lua b/util/visualizer.lua index 9d2cb7c..dae04e9 100644 --- a/util/visualizer.lua +++ b/util/visualizer.lua @@ -27,6 +27,7 @@ function visualizer.disp_images(imgs, opt) end -- function visualizer.save_results(visuals, opt, epoch, counter) + local tensortype = torch.getdefaulttensortype() torch.setdefaulttensortype('torch.FloatTensor') local image_out = nil local win_size = opt.display_winsize @@ -42,7 +43,7 @@ function visualizer.save_results(visuals, opt, epoch, counter) out_path = paths.concat(opt.checkpoints_dir, opt.name, 'epoch' .. epoch .. '_iter' .. counter .. '_train_res.png') image.save(out_path, image_out) - torch.setdefaulttensortype('torch.CudaTensor') + torch.setdefaulttensortype(tensortype) end function visualizer.save_images(imgs, save_dir, impaths, s1, s2)