Skip to content

Commit dc9cbe0

Browse files
committed
Adds Normalize
1 parent 33196ac commit dc9cbe0

File tree

4 files changed

+149
-1
lines changed

4 files changed

+149
-1
lines changed

Normalize.lua

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,87 @@
1+
local Normalize, parent = torch.class('nn.Normalize', 'nn.Module')
2+
3+
function Normalize:__init(p,eps)
4+
parent.__init(self)
5+
assert(p,'p-norm not provided')
6+
assert(p > 0, p..'-norm not supported')
7+
self.p = p
8+
self.eps = eps or 1e-10
9+
end
10+
11+
function Normalize:updateOutput(input)
12+
assert(input:dim() <= 2, 'only 1d layer supported')
13+
local is_batch = true
14+
if input:dim() == 1 then
15+
input = input:view(1,-1)
16+
is_batch = false
17+
end
18+
19+
self.output:resizeAs(input)
20+
21+
self.norm = self.norm or input.new()
22+
self.normp = self.normp or input.new()
23+
self.buffer = self.buffer or input.new()
24+
25+
if self.p % 2 ~= 0 then
26+
self.buffer:abs(input):pow(self.p)
27+
else
28+
self.buffer:pow(input,self.p)
29+
end
30+
self.normp:sum(self.buffer,2):add(self.eps)
31+
self.norm:pow(self.normp,1/self.p)
32+
self.output:cdiv(input,self.norm:view(-1,1):expandAs(self.output))
33+
34+
if not is_batch then
35+
self.output = self.output[1]
36+
end
37+
return self.output
38+
end
39+
40+
function Normalize:updateGradInput(input, gradOutput)
41+
assert(input:dim() <= 2, 'only 1d layer supported')
42+
assert(gradOutput:dim() <= 2, 'only 1d layer supported')
43+
44+
local is_batch = true
45+
if input:dim() == 1 then
46+
input = input:view(1,-1)
47+
is_batch = false
48+
end
49+
50+
local n = input:size(1) -- batch size
51+
local d = input:size(2) -- dimensionality of vectors
52+
-- compute diagonal term
53+
self.eye = self.eye or torch.eye(d):typeAs(input):view(1,d,d)
54+
local eyeExpand = self.eye:expand(n,d,d)
55+
self.diag = self.diag or self.eye.new()
56+
self.diag:cmul(eyeExpand, self.normp:view(n,1,1):expand(n,d,d))
57+
-- compute cross term
58+
self.buffer:abs(input):pow(self.p-2):cmul(input)
59+
local b1 = self.buffer:view(n,d,1)
60+
local b2 = input:view(n,1,d)
61+
62+
self.diag:baddbmm(-1,b1,b2)
63+
-- compute the local gradient of the Lp transformation
64+
self.buffer:cmul(self.normp,self.norm)
65+
self.diag:cdiv(self.buffer:view(n,1,1):expand(n,d,d))
66+
-- chain the gradient
67+
self.gradInput:resize(n,d,1)
68+
self.gradInput:bmm(self.diag, gradOutput:view(n,d,1))
69+
self.gradInput = self.gradInput:view(n,d)
70+
71+
if not is_batch then
72+
self.gradInput = self.gradInput[1]
73+
end
74+
75+
return self.gradInput
76+
end
77+
78+
function Normalize:__tostring__()
79+
local s
80+
-- different prints if the norm is integer
81+
if self.p % 1 == 0 then
82+
s = '%s(%d)'
83+
else
84+
s = '%s(%f)'
85+
end
86+
return string.format(s,torch.type(self),self.p)
87+
end

doc/simple.md

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
<a name="nn.simplelayers.dok"></a>
22
# Simple layers #
33
Simple Modules are used for various tasks like adapting Tensor methods and providing affine transformations :
4-
54
* Parameterized Modules :
65
* [Linear](#nn.Linear) : a linear transformation ;
76
* [SparseLinear](#nn.SparseLinear) : a linear transformation with sparse inputs ;
@@ -28,6 +27,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
2827
* [Power](#nn.Power) : an element-wise [pow](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchpowres-x) operation ;
2928
* [Square](#nn.Square) : an element-wise square operation ;
3029
* [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ;
30+
* [Normalize](#nn.Normalize) : normalizes the input to have unit `L_p` norm ;
3131
* [MM](#nn.MM) : matrix-matrix multiplication (also supports batches of matrices) ;
3232
* Miscellaneous Modules :
3333
* [BatchNormalization](#nn.BatchNormalization) - mean/std normalization over the mini-batch inputs (with an optional affine transform) ;
@@ -886,6 +886,23 @@ gnuplot.grid(true)
886886

887887
![](image/power.png)
888888

889+
<a name="nn.Normalize"></a>
890+
## Normalize ##
891+
892+
```lua
893+
module = nn.Normalize(p, [eps])
894+
```
895+
Normalizes the input Tensor to have unit `L_p` norm. The smoothing parameter `eps` prevents division by zero when the input contains all zero elements (default = `1e-10`).
896+
897+
Input can be 1D or 2D (in which case it's considered as in batch mode)
898+
899+
```lua
900+
A = torch.randn(3, 5)
901+
m = nn.Normalize(2)
902+
B = m:forward(A) -- B is also 3 x 5
903+
-- take the L2 norm over the second axis:
904+
print(torch.norm(B, 2, 2)) -- norms is [1, 1, 1]
905+
```
889906

890907
<a name="nn.MM"></a>
891908
## MM ##

init.lua

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ include('WeightedEuclidean.lua')
4646
include('PairwiseDistance.lua')
4747
include('CosineDistance.lua')
4848
include('DotProduct.lua')
49+
include('Normalize.lua')
4950

5051
include('Exp.lua')
5152
include('Log.lua')

test.lua

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -412,6 +412,49 @@ function nntest.Power()
412412
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
413413
end
414414

415+
function nntest.Normalize()
416+
-- compare forward against torch implementation
417+
-- and check gradient
418+
for _,p in pairs({1,2,1.5}) do
419+
local ini = math.random(3,10)
420+
local input = torch.randn(ini)
421+
local module = nn.Normalize(p)
422+
local out = module:forward(input)
423+
local expected = torch.div(input,input:norm(p))
424+
mytester:assertTensorEq(out, expected, 1e-7,
425+
torch.typename(module) ..' (' .. p ..') - forward err ')
426+
427+
local err = jac.testJacobian(module, input, -2, 2)
428+
mytester:assertlt(err, precision, 'error norm '..p..' on state ')
429+
end
430+
431+
-- batch mode
432+
for _,p in pairs({1,2,torch.uniform()*math.random(1,10)}) do
433+
local ini = math.random(3,5)
434+
local inj = math.random(3,5)
435+
local ink = math.random(3,5)
436+
local input = torch.Tensor(inj, ini):zero()
437+
438+
local module = nn.Normalize(p)
439+
440+
local err = jac.testJacobian(module, input, -2, 2)
441+
mytester:assertlt(err, precision, 'error norm '..p..' on state ')
442+
end
443+
444+
-- test IO correctness
445+
local ini = math.random(3,5)
446+
local inj = math.random(3,5)
447+
local ink = math.random(3,5)
448+
local input = torch.Tensor(inj, ini):zero()
449+
450+
local module = nn.Normalize(2)
451+
452+
local ferr, berr = jac.testIO(module,input, 0.1, 2)
453+
mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
454+
mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
455+
456+
end
457+
415458
function nntest.Square()
416459
local in1 = torch.rand(5,7)
417460
local module = nn.Square()

0 commit comments

Comments
 (0)