Commit 0f12cf6: Add gradient clipping

farrajota committed Jul 7, 2017
1 parent f276b62

Showing 2 changed files with 27 additions and 8 deletions.
Options.lua: 3 additions & 1 deletion

@@ -25,7 +25,7 @@ function Options:parse(opts)
      opt.nThreads, opt.verbose, opt.progressbar, opt.printConfusion,
      opt.LR, opt.LRdecay, opt.momentum, opt.weightDecay, opt.optMethod,
      opt.threshold, opt.trainIters, opt.epochStart, opt.schedule, opt.continue,
-     opt.clear_buffers, opt.snapshot, opt.optimize, opt.testInter, opt.frcnn_scales,
+     opt.clear_buffers, opt.snapshot, opt.optimize, opt.testInter, opt.grad_clip, opt.frcnn_scales,
      opt.frcnn_max_size, opt.frcnn_imgs_per_batch, opt.frcnn_rois_per_img,
      opt.frcnn_fg_fraction, opt.frcnn_bg_fraction, opt.frcnn_fg_thresh,
      opt.frcnn_bg_thresh_hi, opt.frcnn_bg_thresh_lo, opt.frcnn_bbox_thresh,
@@ -91,6 +91,8 @@ function Options:parse(opts)
          help='Optimize network memory usage using optnet.'},
      {arg='testInter', type='boolean', default=true,
          help='If true, does intermediate testing of the model. Else it only tests the network at the end of the train.'},
+     {arg='grad_clip', type='number', default=0,
+         help='Gradient clipping (to prevent exploding gradients).'},

      -------------------------------------------------------------------------------
      -- FRCNN Training options
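A positive grad_clip value is used by train.lua as the maximum allowed L2 norm of the gradients; the default of 0 leaves clipping disabled. A minimal usage sketch follows; the require path and the Options() constructor call are assumptions for illustration, since only Options:parse and the grad_clip arg appear in this commit:

    -- Hypothetical usage sketch (not part of this commit).
    local Options = require 'Options'              -- assumed module path
    local opt = Options():parse({grad_clip = 5})   -- cap gradient L2 norm at 5
    assert(opt.grad_clip == 5)                     -- clipping now active in train.lua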
train.lua: 24 additions & 7 deletions

@@ -167,6 +167,16 @@ local function train(data_gen, rois, model, modelParameters, opts)
end


+ -- copy sample to GPU buffer:
+ local samples = {}
+ engine.hooks.onSample = function(state)
+    cutorch.synchronize(); collectgarbage();
+    utils.table.recursiveCast(samples, state.sample, 'torch.CudaTensor')
+    state.sample.input = samples[1]
+    state.sample.target = samples[2]
+ end


engine.hooks.onForwardCriterion = function(state)
   if state.training then
      meters.train_conf:add(state.network.output[1], state.sample.target[1])
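The hook added above reuses a single samples table across iterations, so the CUDA tensors it holds are allocated once and refilled in place rather than re-created on every batch. utils.table.recursiveCast itself is not shown in this commit; the sketch below is an assumption about what a cast along these lines could look like, not the repository's code:

    require 'torch'   -- and require 'cutorch' when casting to torch.CudaTensor

    -- Assumed sketch of a recursive cast into a reusable buffer table.
    -- Destination tensors are resized and refilled in place, so repeated
    -- calls avoid fresh GPU allocations on every sample.
    local function recursiveCast(dst, src, tensorType)
       for k, v in pairs(src) do
          if torch.type(v) == 'table' then
             dst[k] = dst[k] or {}
             recursiveCast(dst[k], v, tensorType)
          else
             dst[k] = dst[k] or torch.Tensor():type(tensorType)
             dst[k]:resize(v:size()):copy(v)
          end
       end
    end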
@@ -203,13 +213,20 @@
end
end

- -- copy sample to GPU buffer:
- local samples = {}
- engine.hooks.onSample = function(state)
-    cutorch.synchronize(); collectgarbage();
-    utils.table.recursiveCast(samples, state.sample, 'torch.CudaTensor')
-    state.sample.input = samples[1]
-    state.sample.target = samples[2]

+ --[[ Gradient clipping to try to prevent the gradient from exploding. ]]--
+ -- ref: https://github.com/facebookresearch/torch-rnnlib/blob/master/examples/word-language-model/word_lm.lua#L216-L233
+ local function clipGradients(grads, norm)
+    local totalnorm = grads:norm()
+    if totalnorm > norm then
+       local coeff = norm / math.max(totalnorm, 1e-6)
+       grads:mul(coeff)
+    end
+ end
+ engine.hooks.onBackward = function(state)
+    if opt.grad_clip > 0 then
+       clipGradients(state.gradParams, opt.grad_clip)
+    end
+ end


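The hunk above rescales the flattened gradient vector so its L2 norm never exceeds opt.grad_clip: whenever ||g|| > c, every element of g is multiplied by c / ||g||, and gradients already under the cap are left untouched. A standalone sanity check of that behavior, in plain Torch outside the training engine:

    -- Standalone check of the norm-based clipping added above.
    require 'torch'

    local function clipGradients(grads, norm)
       local totalnorm = grads:norm()
       if totalnorm > norm then
          grads:mul(norm / math.max(totalnorm, 1e-6))
       end
    end

    local g = torch.Tensor(3, 4):fill(2)   -- ||g|| = sqrt(48), about 6.93
    clipGradients(g, 5)
    print(g:norm())    -- ~5.0: rescaled down to the cap
    clipGradients(g, 10)
    print(g:norm())    -- still ~5.0: already under the cap, left untouched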
