Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 40 additions & 0 deletions L2Normalize.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@

--[[
This layer expects an [n x d] Tensor and normalizes each
row to have unit L2 norm.
]]--
local L2Normalize, parent = torch.class('nn.L2Normalize', 'nn.Module')
function L2Normalize:__init()
parent.__init(self)
end
function L2Normalize:updateOutput(input)
assert(input:dim() == 2, 'only mini-batch supported (2D tensor), got '
.. input:dim() .. 'D tensor instead')
self.output:resizeAs(input)
self.buffer = self.buffer or input.new()
self.normSquared = self.normSquared or input.new()
self.normSquared:sum(self.buffer:cmul(input, input), 2)
self.buffer:sqrt(self.normSquared)
self.output:copy(input):cdiv(self.buffer:expandAs(input))
return self.output
end

function L2Normalize:updateGradInput(input, gradOutput)
assert(input:dim() == 2, 'only mini-batch supported')
assert(gradOutput:dim() == 2, 'only mini-batch supported')
local n = input:size(1) -- batch size
local d = input:size(2) -- dimensionality of vectors
-- compute diagonal term
self.eye = self.eye or torch.eye(d):typeAs(input):repeatTensor(n,1):view(n,d,d)
self.diag = self.diag or self.eye.new()
self.diag:cmul(self.eye, self.normSquared:view(n,1,1):expand(n,d,d))
-- compute cross term
local b1 = input:view(n,d,1)
local b2 = input:view(n,1,d)
self.diag:add(-torch.bmm(b1,b2))
-- compute the local gradient of the L2 transformation
self.diag:cdiv(torch.pow(self.buffer,3):view(n,1,1):expand(n,d,d))
-- chain the gradient
self.gradInput:resize(n,d,1):bmm(self.diag, gradOutput:view(n,d,1)):resize(n,d)
return self.gradInput
end
1 change: 1 addition & 0 deletions init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
include('PairwiseDistance.lua')
include('CosineDistance.lua')
include('DotProduct.lua')
include('L2Normalize.lua')

include('Exp.lua')
include('Log.lua')
Expand Down
23 changes: 23 additions & 0 deletions test.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3554,6 +3554,29 @@ function nntest.Padding()
mytester:assertTensorEq(gradInput, input, 0.00001, "Padding backward error")
end

function nntest.L2Normalize()
local ini = math.random(6,8)
local inj = math.random(3,5)
local input = torch.randn(ini, inj)

local module = nn.L2Normalize()

-- test correctness of output
local output = module:forward(input)
local norms = torch.norm(output, 2, 2)
local desired_norms = torch.ones(ini)
mytester:assertTensorEq(norms, desired_norms, 0.000001, 'L2Normalize forward err')

-- test the Jacobian
local err = jac.testJacobian(module,input)
mytester:assertlt(err, precision, 'error on state ')

-- test IO correctness
local ferr, berr = jac.testIO(module,input)
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
end

mytester:add(nntest)

if not nn then
Expand Down