Skip to content

Commit

Permalink
Simplify model saving function
Browse files Browse the repository at this point in the history
  • Loading branch information
farrajota committed May 30, 2017
1 parent 23aa994 commit d5267e1
Showing 1 changed file with 5 additions and 3 deletions.
8 changes: 5 additions & 3 deletions utils/model.lua
Original file line number Diff line number Diff line change
Expand Up @@ -254,11 +254,13 @@ local function store(model, modelParameters, optimState, epoch, opt, flag)
print('Saving model snapshot to: ' .. filename_model)
torch.save(filename_optimstate, optimState)
torch.save(opt.curr_save_configs, snapshot_configs(filename_model, epoch, opt))

if opt.clear_buffers then
torch.save(filename_model, {resetDataParallel(model:clearState(), opt.GPU), modelParameters, info})
else
torch.save(filename_model, {resetDataParallel(model, opt.GPU), modelParameters, info})
model = model:clearState()
end
setDataParallel(model, opt.GPU, 1) -- set nn.DataParallelTable to use only 1 GPU
torch.save(filename_model, {model, modelParameters, info})
setDataParallel(model, opt.GPU, opt.nGPU)

-- make a symlink to the last trained model
local filename_symlink = paths.concat(opt.savedir,'model_final.t7')
Expand Down

0 comments on commit d5267e1

Please sign in to comment.