|
| 1 | + |
| 2 | +local FeatureLPPooling, parent = |
| 3 | + torch.class('nn.FeatureLPPooling', 'nn.Module') |
| 4 | + |
| 5 | +--[[ |
| 6 | + Possible inputs that we handle: |
| 7 | +
|
| 8 | + #### `batch_mode = false` |
| 9 | + The dimensionality of the input chooses between the following modes: |
| 10 | +
|
| 11 | + ``` |
| 12 | + [feature dim] |
| 13 | + [feature dim][opt dim 1] |
| 14 | + [feature dim][opt dim 1][opt dim 2] |
| 15 | + ``` |
| 16 | +
|
| 17 | + #### `batch_mode = true` |
| 18 | + The dimensionality of the input chooses between the following modes: |
| 19 | + ``` |
| 20 | + [batch dim][feature dim] |
| 21 | + [batch dim][feature dim][opt dim 1] |
| 22 | + [batch dim][feature dim][opt dim 1][opt dim 2] |
| 23 | + ``` |
| 24 | +
|
| 25 | + The output has the same number of dimensions as the input, except the feature |
| 26 | + dimension size is reduced to ((`input` - `width`) / `stride`) + 1 |
| 27 | +]] |
| 28 | +function FeatureLPPooling:__init(width, stride, power, batch_mode) |
| 29 | + parent.__init(self) |
| 30 | + |
| 31 | + if (width < 2 or width > 16) then |
| 32 | + error('width must be within 2 to 16') |
| 33 | + end |
| 34 | + |
| 35 | + if (stride < 1 or stride > 4) then |
| 36 | + error('stride must be within 1 to 4') |
| 37 | + end |
| 38 | + |
| 39 | + self.width = width |
| 40 | + self.stride = stride |
| 41 | + self.power = power |
| 42 | + self.batch_mode = batch_mode |
| 43 | + |
| 44 | + self.output = torch.Tensor() |
| 45 | + self.gradInput = torch.Tensor() |
| 46 | +end |
| 47 | + |
| 48 | +function FeatureLPPooling:updateOutput(input) |
| 49 | + input.THNN.FeatureLPPooling_updateOutput(input:cdata(), |
| 50 | + self.output:cdata(), |
| 51 | + self.power, |
| 52 | + self.width, |
| 53 | + self.stride, |
| 54 | + self.batch_mode) |
| 55 | + return self.output |
| 56 | +end |
| 57 | + |
| 58 | +function FeatureLPPooling:updateGradInput(input, gradOutput) |
| 59 | + input.THNN.FeatureLPPooling_updateGradInput(gradOutput:cdata(), |
| 60 | + input:cdata(), |
| 61 | + self.output:cdata(), |
| 62 | + self.gradInput:cdata(), |
| 63 | + self.power, |
| 64 | + self.width, |
| 65 | + self.stride, |
| 66 | + self.batch_mode) |
| 67 | + return self.gradInput |
| 68 | +end |
| 69 | + |
| 70 | +function FeatureLPPooling:__tostring__() |
| 71 | + return string.format('%s(w%d s%d power %d batch %d', |
| 72 | + torch.type(self), |
| 73 | + self.width, self.stride, self.power, self.batch_mode) |
| 74 | +end |
0 commit comments