nn.clearState
szagoruyko committed Feb 9, 2016
1 parent 4a43346 commit b4ebdf2
Showing 50 changed files with 380 additions and 67 deletions.
11 changes: 11 additions & 0 deletions BatchNormalization.lua
@@ -154,3 +154,14 @@ function BN:accGradParameters(input, gradOutput, scale)
self.gradBias:add(scale, self.buffer)
end
end

function BN:clearState()
nn.utils.clear(self, {
'buffer',
'buffer2',
'centered',
'std',
'normalized',
})
return parent.clearState(self)
end
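
All of the new clearState implementations below lean on nn.utils.clear, which is defined in utils.lua and is not among the files shown on this page. Judging from how it is called throughout this diff (with either a table of field names or plain vararg strings, and with its result returned directly), a minimal sketch of such a helper might look like the following; treat it as a hypothetical reconstruction, not the actual utils.lua code:

-- Hypothetical sketch of an nn.utils.clear-style helper: release tensor fields
-- in place, empty table fields, and nil out anything else, then return self.
function nn.utils.clear(self, ...)
   local fields = {...}
   if #fields > 0 and type(fields[1]) == 'table' then
      fields = fields[1]                 -- accept nn.utils.clear(self, {...}) too
   end
   local function clear(f)
      if self[f] then
         if torch.isTensor(self[f]) then
            self[f]:set()                -- keep the tensor object, drop its storage
         elseif type(self[f]) == 'table' then
            self[f] = {}
         else
            self[f] = nil
         end
      end
   end
   for _, f in ipairs(fields) do clear(f) end
   return self
end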
5 changes: 5 additions & 0 deletions Bilinear.lua
@@ -140,3 +140,8 @@ function Bilinear:__tostring__()
(self.bias == nil and ' without bias' or '')
)
end

function Bilinear:clearState()
if self.buff then self.buff:set() end
return parent.clearState(self)
end
21 changes: 14 additions & 7 deletions CMul.lua
@@ -116,13 +116,20 @@ end

function CMul:type(type, tensorCache)
if type then
-      self._input = nil
-      self._output = nil
-      self._weight = nil
-      self._gradWeight = nil
-      self._expand = nil
-      self._repeat = nil
-      self._sum = nil
+      self:clearState()
end
return parent.type(self, type, tensorCache)
end

function CMul:clearState()
nn.utils.clear(self, {
'_input',
'_output',
'_weight',
'_gradWeight',
'_expand',
'_repeat',
'_sum',
})
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions CMulTable.lua
@@ -48,3 +48,8 @@ function CMulTable:updateGradInput(input, gradOutput)

return self.gradInput
end

function CMulTable:clearState()
if self.tout then self.tout:set() end
return parent.clearState(self)
end
23 changes: 23 additions & 0 deletions Container.lua
@@ -75,3 +75,26 @@ function Container:parameters()
end
return w,gw
end

function Container:clearState()
-- don't call set because it might reset referenced tensors
local function clear(f)
if self[f] then
if torch.isTensor(self[f]) then
self[f] = self[f].new()
elseif type(self[f]) == 'table' then
self[f] = {}
else
self[f] = nil
end
end
end
clear('output')
clear('gradInput')
if self.modules then
for i,module in pairs(self.modules) do
module:clearState()
end
end
return self
end
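
Because Container:clearState() also walks self.modules, clearing the outermost container clears every child module as well; the fresh new() tensors (instead of :set()) presumably avoid resetting tensors that children or callers still reference, as the comment notes. A small usage sketch with arbitrary layer sizes:

-- Clearing a whole network through its container (layer sizes are arbitrary).
local net = nn.Sequential()
net:add(nn.Linear(10, 20))
net:add(nn.Tanh())

net:forward(torch.randn(8, 10))
print(net:get(1).output:nElement())   -- 160: the first layer still holds activations
net:clearState()
print(net:get(1).output:nElement())   -- 0: children were cleared recursively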
12 changes: 12 additions & 0 deletions Cosine.lua
@@ -161,3 +161,15 @@ function Cosine:type(type, tensorCache)
end
return parent.type(self, type, tensorCache)
end

function Cosine:clearState()
nn.utils.clear(self, {
'_input',
'_weight',
'_gradOutput',
'_sum',
'_inputNorm',
'_weightNorm',
})
return parent.clearState(self)
end
17 changes: 17 additions & 0 deletions CosineDistance.lua
@@ -73,6 +73,11 @@ function CosineDistance:updateGradInput(input, gradOutput)
not_batch = true
end

if #self.gradInput ~= 2 then
self.gradInput[1] = self.gradInput[1] or v1.new()
self.gradInput[2] = self.gradInput[2] or v1.new()
end

local gw1 = self.gradInput[1]
local gw2 = self.gradInput[2]
gw1:resizeAs(v1):copy(v2)
@@ -97,3 +102,15 @@ function CosineDistance:updateGradInput(input, gradOutput)

return self.gradInput
end

function CosineDistance:clearState()
nn.utils.clear(self, {
'buffer',
'w1',
'w22',
'w',
'w32',
'ones',
})
return parent.clearState(self)
end
10 changes: 10 additions & 0 deletions DotProduct.lua
@@ -26,6 +26,11 @@ function DotProduct:updateGradInput(input, gradOutput)
local v2 = input[2]
local not_batch = false

if #self.gradInput ~= 2 then
self.gradInput[1] = self.gradInput[1] or input[1].new()
self.gradInput[2] = self.gradInput[2] or input[2].new()
end

if v1:dim() == 1 then
v1 = v1:view(1,-1)
v2 = v2:view(1,-1)
@@ -49,3 +54,8 @@ function DotProduct:updateGradInput(input, gradOutput)

return self.gradInput
end

function DotProduct:clearState()
if self.buffer then self.buffer:set() end
return parent.clearState(self)
end
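
The guards added above in CosineDistance and DotProduct appear to be needed because clearState() resets the table-valued gradInput to an empty table, so the two per-input slots must be re-created on the next backward pass. A hypothetical round-trip exercising that path (sizes made up):

-- Hypothetical round-trip: backward still works after clearState thanks to the guard.
local m = nn.DotProduct()
local x = {torch.randn(5), torch.randn(5)}
m:forward(x)
m:backward(x, torch.Tensor{1})
m:clearState()                        -- gradInput becomes an empty table
m:forward(x)
m:backward(x, torch.Tensor{1})        -- guard re-creates gradInput[1] and gradInput[2]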
8 changes: 8 additions & 0 deletions Dropout.lua
@@ -64,3 +64,11 @@ end
function Dropout:__tostring__()
return string.format('%s(%f)', torch.type(self), self.p)
end


function Dropout:clearState()
if self.noise then
self.noise:set()
end
return Parent.clearState(self)
end
29 changes: 18 additions & 11 deletions Euclidean.lua
@@ -174,17 +174,24 @@ end
function Euclidean:type(type, tensorCache)
if type then
-- prevent premature memory allocations
-      self._input = nil
-      self._output = nil
-      self._gradOutput = nil
-      self._weight = nil
-      self._div = nil
-      self._sum = nil
-      self._expand = nil
-      self._expand2 = nil
-      self._expand3 = nil
-      self._repeat = nil
-      self._repeat2 = nil
+      self:clearState()
end
return parent.type(self, type, tensorCache)
end

function Euclidean:clearState()
nn.utils.clear(self, {
'_input',
'_output',
'_gradOutput',
'_weight',
'_div',
'_sum',
'_expand',
'_expand2',
'_expand3',
'_repeat',
'_repeat2',
})
return parent.clearState(self)
end
6 changes: 5 additions & 1 deletion FlattenTable.lua
@@ -97,6 +97,10 @@ end
function FlattenTable:type(type, tensorCache)
-- This function just stores references so we don't need to do any type
-- conversions. Just force the tables to be empty.
-   self.output = {}
+   self:clearState()
end

function FlattenTable:clearState()
self.input_map = {}
return parent.clearState(self)
end
2 changes: 1 addition & 1 deletion GradientReversal.lua
@@ -1,7 +1,7 @@
local GradientReversal = torch.class('nn.GradientReversal', 'nn.Module')

function GradientReversal:updateOutput(input)
-   self.output = input
+   self.output:set(input)
return self.output
end

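The GradientReversal change swaps a plain assignment for :set(input): the module keeps its own tensor object that merely shares the input's storage, so a later clearState(), which releases self.output, cannot wipe the caller's tensor. A hypothetical sketch of the difference:

-- Hypothetical sketch: assignment aliases the object, :set() only shares storage.
local inputA = torch.randn(3)
local out = torch.Tensor()
out = inputA             -- 'out' is now the very same tensor object as 'inputA'
out:set()                -- clearing it also empties 'inputA'

local inputB = torch.randn(3)
local out2 = torch.Tensor()
out2:set(inputB)         -- a separate tensor object viewing inputB's storage
out2:set()               -- releases only out2; 'inputB' keeps its data
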
18 changes: 18 additions & 0 deletions Identity.lua
@@ -10,3 +10,21 @@ function Identity:updateGradInput(input, gradOutput)
self.gradInput = gradOutput
return self.gradInput
end

function Identity:clearState()
-- don't call set because it might reset referenced tensors
local function clear(f)
if self[f] then
if torch.isTensor(self[f]) then
self[f] = self[f].new()
elseif type(self[f]) == 'table' then
self[f] = {}
else
self[f] = nil
end
end
end
clear('output')
clear('gradInput')
return self
end
1 change: 1 addition & 0 deletions Jacobian.lua
@@ -305,6 +305,7 @@ function nn.Jacobian.testIO(module,input, minval, maxval)
-- write module
local filename = os.tmpname()
local f = torch.DiskFile(filename, 'w'):binary()
module:clearState()
f:writeObject(module)
f:close()
-- read module
5 changes: 5 additions & 0 deletions L1Cost.lua
@@ -23,3 +23,8 @@ function L1Cost:updateGradInput(input)
)
return self.gradInput
end

function L1Cost:clearState()
if self.output_tensor then self.output_tensor:set() end
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions L1Penalty.lua
@@ -41,3 +41,7 @@ function L1Penalty:updateGradInput(input, gradOutput)
return self.gradInput
end

function L1Penalty:clearState()
if self.loss then self.loss:set() end
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions Linear.lua
@@ -95,6 +95,10 @@ end
-- we do not need to accumulate parameters when sharing
Linear.sharedAccUpdateGradParameters = Linear.accUpdateGradParameters

function Linear:clearState()
if self.addBuffer then self.addBuffer:set() end
return parent.clearState(self)
end

function Linear:__tostring__()
return torch.type(self) ..
6 changes: 6 additions & 0 deletions LogSigmoid.lua
@@ -23,3 +23,9 @@ function LogSigmoid:updateGradInput(input, gradOutput)
)
return self.gradInput
end

function LogSigmoid:clearState()
if self.buffer then self.buffer:set() end
return parent.clearState(self)
end

4 changes: 4 additions & 0 deletions LookupTable.lua
@@ -100,5 +100,9 @@ function LookupTable:type(type, tensorCache)
return self
end

function LookupTable:clearState()
return self
end

-- we do not need to accumulate parameters when sharing
LookupTable.sharedAccUpdateGradParameters = LookupTable.accUpdateGradParameters
5 changes: 5 additions & 0 deletions Max.lua
@@ -63,3 +63,8 @@ function Max:type(type, tensorCache)
end
return self
end

function Max:clearState()
nn.utils.clear(self, '_indices', '_output')
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions Min.lua
@@ -63,3 +63,8 @@ function Min:type(type, tensorCache)
end
return self
end

function Min:clearState()
nn.utils.clear(self, '_indices', '_output')
return parent.clearState(self)
end
12 changes: 12 additions & 0 deletions MixtureTable.lua
@@ -156,3 +156,15 @@ function MixtureTable:type(type, tensorCache)
self._expertView2 = nil
return parent.type(self, type, tensorCache)
end

function MixtureTable:clearState()
nn.utils.clear(self, {
'_gaterView',
'_expert',
'_expertView',
'_sum',
'_expert2',
'_expertView2',
})
return parent.clearState(self)
end
4 changes: 4 additions & 0 deletions Module.lua
@@ -364,3 +364,7 @@ function Module:listModules()
end
return modules
end

function Module:clearState()
return nn.utils.clear(self, 'output', 'gradInput')
end
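
Module:clearState() is the generic fallback that the overrides above delegate to through parent.clearState: it drops output and gradInput, and each module clears its own scratch buffers on top. The typical use, as in the Jacobian.testIO change earlier in this diff, is to strip intermediate state before serializing a trained module; a sketch with a made-up file name:

-- Sketch: clear intermediate state before saving (file name is made up).
local model = nn.Linear(100, 10)
model:forward(torch.randn(32, 100))   -- output now holds a 32x10 activation buffer

model:clearState()                    -- drops output, gradInput and addBuffer
torch.save('model.t7', model)         -- the saved file no longer carries activations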
13 changes: 13 additions & 0 deletions Normalize.lua
@@ -140,3 +140,13 @@ function Normalize:type(type, tensorCache)
end
return self
end

function Normalize:clearState()
nn.utils.clear(self, {
'_output',
'_indices',
'_gradInput',
'buffer',
'norm',
'normp',
'cross',
})
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions PReLU.lua
@@ -45,3 +45,8 @@ function PReLU:accGradParameters(input, gradOutput, scale)
)
return self.gradWeight
end

function PReLU:clearState()
nn.utils.clear(self, 'gradWeightBuf', 'gradWeightBuf2')
return parent.clearState(self)
end
6 changes: 6 additions & 0 deletions PairwiseDistance.lua
@@ -10,6 +10,7 @@ function PairwiseDistance:__init(p)
end

function PairwiseDistance:updateOutput(input)
self.output:resize(1)
if input[1]:dim() == 1 then
self.output:resize(1)
self.output[1]=input[1]:dist(input[2],self.norm)
@@ -83,3 +84,8 @@ function PairwiseDistance:updateGradInput(input, gradOutput)
self.gradInput[2]:zero():add(-1, self.gradInput[1])
return self.gradInput
end

function PairwiseDistance:clearState()
nn.utils.clear(self, 'diff', 'outExpand', 'grad', 'ones')
return parent.clearState(self)
end
5 changes: 5 additions & 0 deletions RReLU.lua
@@ -43,3 +43,8 @@ end
function RReLU:__tostring__()
return string.format('%s (l:%f, u:%f)', torch.type(self), self.lower, self.upper)
end

function RReLU:clearState()
if self.noise then self.noise:set() end
return parent.clearState(self)
end
(The remaining changed files in this commit are not shown on this page.)