Skip to content

Commit

Permalink
Merge pull request torch#622 from torch/mulfix
Browse files Browse the repository at this point in the history
fix for nn.Mul for cuda defaults
  • Loading branch information
soumith committed Feb 9, 2016
2 parents bf6e425 + f5acc4e commit 4a43346
Showing 1 changed file with 7 additions and 7 deletions.
14 changes: 7 additions & 7 deletions Mul.lua
Original file line number Diff line number Diff line change
Expand Up @@ -2,37 +2,37 @@ local Mul, parent = torch.class('nn.Mul', 'nn.Module')

function Mul:__init()
parent.__init(self)

self.weight = torch.Tensor(1)
self.gradWeight = torch.Tensor(1)

self:reset()
end


function Mul:reset(stdv)
if stdv then
stdv = stdv * math.sqrt(3)
else
stdv = 1./math.sqrt(self.weight:size(1))
end

self.weight[1] = torch.uniform(-stdv, stdv);
self.weight:uniform(-stdv, stdv);
end

function Mul:updateOutput(input)
self.output:resizeAs(input):copy(input);
self.output:mul(self.weight[1]);
return self.output
return self.output
end

function Mul:updateGradInput(input, gradOutput)
function Mul:updateGradInput(input, gradOutput)
self.gradInput:resizeAs(input):zero()
self.gradInput:add(self.weight[1], gradOutput)
return self.gradInput
end

function Mul:accGradParameters(input, gradOutput, scale)
function Mul:accGradParameters(input, gradOutput, scale)
scale = scale or 1
self.gradWeight[1] = self.gradWeight[1] + scale*input:dot(gradOutput);
end

0 comments on commit 4a43346

Please sign in to comment.