Skip to content

Commit

Permalink
Add skip testing a network after training
Browse files Browse the repository at this point in the history
  • Loading branch information
farrajota committed May 24, 2017
1 parent aa65ad3 commit 1942a7e
Showing 1 changed file with 65 additions and 12 deletions.
77 changes: 65 additions & 12 deletions train.lua
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,65 @@ local function train(data_gen, rois, model, modelParameters, opts)

-- set number of iterations
local nItersTrain = opt.trainIters
local nItersTest = dataLoadTable.test.nfiles/opt.frcnn_imgs_per_batch
local nItersTest
if dataLoadTable.test then
nItersTest = dataLoadTable.test.nfiles/opt.frcnn_imgs_per_batch
end

-- classes
local classes = utils.table.concatTables({'background'}, dataLoadTable.train.classLabel)

-- convert modules to a specified tensor type
local function cast(x) return x:type(opt.dataType) end

--[[
local function getIterator_test(mode)
return tnt.ParallelDatasetIterator{
nthread = opt.nThreads,
init = function(threadid)
require 'torch'
require 'torchnet'
opt = lopt
paths.dofile('/home/mf/Toolkits/Codigo/git/fastrcnn/init.lua')
torch.manualSeed(threadid+opt.manualSeed)
end,
closure = function()
-- data loader/generator
local data_loader = data_gen()
local batchprovider = fastrcnn.BatchROISampler(data_loader[mode], rois[mode], modelParameters, opt, mode)
-- number of iterations per epoch
local nIters = data_loader[mode].nfiles
-- setup dataset iterator
local list_dataset = tnt.ListDataset{
list = torch.range(1, nIters):long(),
load = function(idx)
return batchprovider:getSample(idx)
end
}
return list_dataset
end,
}
end
for _, mode in pairs({'train', 'test'}) do
print('Starting mode: ', mode)
local iter = getIterator_test(mode)
local idx = 1
local data_loader = data_gen()
local nfiles = data_loader[mode].nfiles
for saple in iter() do
xlua.progress(idx, nfiles)
idx = idx + 1
end
end
os.exit()
--]]



--------------------------------------------------------------------------------
-- Setup data generator
Expand Down Expand Up @@ -248,8 +299,8 @@ local function train(data_gen, rois, model, modelParameters, opts)
meters:reset()

-- store model
--modelStorageFn(state.network.modules[1], modelParameters, state.config, state.epoch, state.maxepoch, opt)
modelStorageFn(modelOut.modules[1], modelParameters, state.config, state.epoch, state.maxepoch, opt)
modelStorageFn(state.network.modules[1], modelParameters, state.config, state.epoch, state.maxepoch, opt)
--modelStorageFn(modelOut.modules[1], modelParameters, state.config, state.epoch, state.maxepoch, opt)
state.t = 0
end
end
Expand Down Expand Up @@ -303,16 +354,18 @@ local function train(data_gen, rois, model, modelParameters, opts)
-- Test the model
--------------------------------------------------------------------------------

print('\n')
print('**********************************************')
print('*** Test the network ')
print('**********************************************')
if dataLoadTable.test then
print('\n')
print('**********************************************')
print('*** Test the network ')
print('**********************************************')

engine:test{
network = modelOut,
iterator = getIterator('test'),
criterion = criterion
}
engine:test{
network = modelOut,
iterator = getIterator('test'),
criterion = criterion
}
end


--------------------------------------------------------------------------------
Expand Down

0 comments on commit 1942a7e

Please sign in to comment.