
Commit f613412

Merge pull request #1259 from wickedfoo/feature_lp_pooling
CPU implementation of L_p feature pooling
2 parents 14cedef + 9c5ddcc commit f613412

File tree

6 files changed: +759 −1 lines changed


FeatureLPPooling.lua

Lines changed: 74 additions & 0 deletions
@@ -0,0 +1,74 @@
local FeatureLPPooling, parent =
   torch.class('nn.FeatureLPPooling', 'nn.Module')

--[[
Possible inputs that we handle:

#### `batch_mode = false`
The dimensionality of the input chooses between the following modes:

```
[feature dim]
[feature dim][opt dim 1]
[feature dim][opt dim 1][opt dim 2]
```

#### `batch_mode = true`
The dimensionality of the input chooses between the following modes:
```
[batch dim][feature dim]
[batch dim][feature dim][opt dim 1]
[batch dim][feature dim][opt dim 1][opt dim 2]
```

The output has the same number of dimensions as the input, except the feature
dimension size is reduced to ((`input` - `width`) / `stride`) + 1
]]
function FeatureLPPooling:__init(width, stride, power, batch_mode)
   parent.__init(self)

   if (width < 2 or width > 16) then
      error('width must be within 2 to 16')
   end

   if (stride < 1 or stride > 4) then
      error('stride must be within 1 to 4')
   end

   self.width = width
   self.stride = stride
   self.power = power
   self.batch_mode = batch_mode

   self.output = torch.Tensor()
   self.gradInput = torch.Tensor()
end

function FeatureLPPooling:updateOutput(input)
   input.THNN.FeatureLPPooling_updateOutput(input:cdata(),
                                            self.output:cdata(),
                                            self.power,
                                            self.width,
                                            self.stride,
                                            self.batch_mode)
   return self.output
end

function FeatureLPPooling:updateGradInput(input, gradOutput)
   input.THNN.FeatureLPPooling_updateGradInput(gradOutput:cdata(),
                                               input:cdata(),
                                               self.output:cdata(),
                                               self.gradInput:cdata(),
                                               self.power,
                                               self.width,
                                               self.stride,
                                               self.batch_mode)
   return self.gradInput
end

function FeatureLPPooling:__tostring__()
   return string.format('%s(w%d s%d power %g batch %s)',
                        torch.type(self),
                        self.width, self.stride, self.power,
                        tostring(self.batch_mode))
end
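
For reference, a minimal usage sketch (not part of the commit), assuming the conventional L_p pooling definition y[j] = (sum over a window of `width` features, stepped by `stride`, of |x|^`power`)^(1/`power`); the exact behaviour of the CPU kernel added by this commit lives in the C sources not shown in this section. The parameter values (width 3, stride 2, power 2) are arbitrary example choices.

```
-- Minimal sketch, assuming the conventional L_p pooling definition above.
require 'nn'

local width, stride, power = 3, 2, 2.0
local pool = nn.FeatureLPPooling(width, stride, power, false)

local x = torch.rand(16)   -- batch_mode = false, [feature dim] = 16
local y = pool:forward(x)

-- Feature dimension shrinks to ((16 - 3) / 2) + 1 = 7
print(y:size(1))

-- Hand-rolled check of the first output element under the assumed definition
local acc = 0
for i = 0, width - 1 do
   acc = acc + math.abs(x[i + 1]) ^ power
end
print(acc ^ (1 / power), y[1])
```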

init.lua

Lines changed: 2 additions & 0 deletions
@@ -160,6 +160,8 @@ require('nn.VolumetricAveragePooling')
 require('nn.VolumetricBatchNormalization')
 require('nn.VolumetricReplicationPadding')
 
+require('nn.FeatureLPPooling')
+
 require('nn.GPU')
 
 require('nn.ParallelTable')
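
Because init.lua now requires the module, it is available directly after `require 'nn'`. Below is a hedged sketch (not taken from this commit's own tests) of the standard torch/nn finite-difference gradient check using the stock `nn.Jacobian` helper, which exercises both `updateOutput` and `updateGradInput` above.

```
-- Hedged sketch of a torch/nn gradient check for the new layer:
-- nn.Jacobian compares the analytic Jacobian (via updateGradInput)
-- against a finite-difference estimate.
require 'nn'

local pool  = nn.FeatureLPPooling(3, 2, 2.0, false)
local input = torch.rand(16)

-- Restrict inputs to positive values so the L_p norm stays smooth.
local err = nn.Jacobian.testJacobian(pool, input, 0.1, 2)
print('max gradient error: ' .. err)   -- expected to be tiny (finite-difference noise)
```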
