diff --git a/utils/model.lua b/utils/model.lua index 681e23f..bb4ecc0 100644 --- a/utils/model.lua +++ b/utils/model.lua @@ -254,11 +254,13 @@ local function store(model, modelParameters, optimState, epoch, opt, flag) print('Saving model snapshot to: ' .. filename_model) torch.save(filename_optimstate, optimState) torch.save(opt.curr_save_configs, snapshot_configs(filename_model, epoch, opt)) + if opt.clear_buffers then - torch.save(filename_model, {resetDataParallel(model:clearState(), opt.GPU), modelParameters, info}) - else - torch.save(filename_model, {resetDataParallel(model, opt.GPU), modelParameters, info}) + model = model:clearState() end + setDataParallel(model, opt.GPU, 1) -- set nn.DataParallelTable to use only 1 GPU + torch.save(filename_model, {model, modelParameters, info}) + setDataParallel(model, opt.GPU, opt.nGPU) -- make a symlink to the last trained model local filename_symlink = paths.concat(opt.savedir,'model_final.t7')