Skip to content

Commit

Permalink
fix save_current_results
Browse files Browse the repository at this point in the history
  • Loading branch information
junyanz committed Jun 18, 2017
1 parent 13164a2 commit 55cb596
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 29 deletions.
3 changes: 2 additions & 1 deletion options.lua
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ function options.parse_options(mode)
-- save opt to checkpoints
paths.mkdir(opt.checkpoints_dir)
paths.mkdir(paths.concat(opt.checkpoints_dir, opt.name))

opt.visual_dir = paths.concat(opt.checkpoints_dir, opt.name, 'visuals')
paths.mkdir(opt.visual_dir)
-- save opt to the disk
fd = io.open(paths.concat(opt.checkpoints_dir, opt.name, 'opt_' .. mode .. '.txt'), 'w')
for i,k in ipairs(keyset) do
Expand Down
5 changes: 4 additions & 1 deletion train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,10 @@ end

function save_current_results(epoch, counter)
local visuals = model:GetCurrentVisuals(opt)
visualizer.save_results(visuals, opt, epoch, counter)
for i,visual in ipairs(visuals) do
output_path = paths.concat(opt.visual_dir, 'train_epoch' .. epoch .. '_iter' .. counter .. '_' .. visual.label .. '.jpg')
visualizer.save_results(visual.img, output_path)
end
end

function print_current_errors(epoch, counter_in_epoch)
Expand Down
44 changes: 17 additions & 27 deletions util/visualizer.lua
Original file line number Diff line number Diff line change
Expand Up @@ -15,47 +15,37 @@ require 'image'

-- function visualizer
function visualizer.disp_image(img_data, win_size, display_id, title)
local tensortype = torch.getdefaulttensortype()
disp.image(util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size)), {win=display_id, title=title})
torch.setdefaulttensortype(tensortype)
images = util.deprocess_batch(util.scaleBatch(img_data:float(),win_size,win_size))
disp.image(images, {win=display_id, title=title})
end

function visualizer.disp_images(imgs, opt)
local tensortype = torch.getdefaulttensortype()
disp_imgs = {}
for i,img in ipairs(imgs) do
disp_img = util.deprocess_batch(util.scaleBatch(img:float(), opt.win_size, opt.win_size))
table.insert(disp_imgs, disp_img[1])
end
disp.images(disp_imgs, {opt.win_size*3, labels=opt.labels, win=opt.display_id})
torch.setdefaulttensortype(tensortype)
end
--
function visualizer.save_results(visuals, opt, epoch, counter)
function visualizer.save_results(img_data, output_path)
local tensortype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
local image_out = nil
local win_size = opt.display_winsize
for i,visual in ipairs(visuals) do
im = torch.squeeze(util.deprocess_batch(util.scaleBatch(visual.img:float(), win_size, win_size)))

if image_out == nil then
image_out = im
else
image_out = torch.cat(image_out, im)
images = torch.squeeze(util.deprocess_batch(util.scaleBatch(img_data:float(), win_size, win_size)))

if images:dim() == 3 then
image_out = images
else
for i = 1,images:size(1) do
img = images[i]
if image_out == nil then
image_out = img
else
image_out = torch.cat(image_out, img)
end
end
end

out_path = paths.concat(opt.checkpoints_dir, opt.name, 'epoch' .. epoch .. '_iter' .. counter .. '_train_res.png')
image.save(out_path, image_out)
image.save(output_path, image_out)
torch.setdefaulttensortype(tensortype)
end

function visualizer.save_images(imgs, save_dir, impaths, s1, s2)
local tensortype = torch.getdefaulttensortype()
torch.setdefaulttensortype('torch.FloatTensor')
print('saving images', save_dir)
batchSize = imgs:size(1)
batchSize = imgs:size(0)
imgs_f = util.deprocess_batch(imgs):float()
paths.mkdir(save_dir)
for i = 1, batchSize do -- imgs_f[i]:size(2), imgs_f[i]:size(3)/opt.aspect_ratio
Expand Down

0 comments on commit 55cb596

Please sign in to comment.