
Commit 89a1970 (1 parent: 31d7d2b)

Use custom range in HardTanh and mask it as Clamp

Allows HardTanh to use a custom interval for the linear region when min/max values are specified in the declaration. Its implementation can thus be shared with the new Clamp module.
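Taken together, the changes below enable the following usage; this is a minimal sketch based on the constructors introduced in this diff (the input tensor is illustrative):

```lua
require 'nn'

-- HardTanh keeps its old default linear region of [-1, 1] ...
local ht  = nn.HardTanh()          -- defaults: min_value = -1, max_value = 1
-- ... but now accepts a custom interval.
local ht2 = nn.HardTanh(-0.5, 0.5)

-- Clamp is the same operation under a more descriptive name:
-- it simply inherits HardTanh's implementation.
local cl = nn.Clamp(-0.5, 0.5)

local x = torch.randn(4)
-- The two modules should produce identical outputs.
print(ht2:forward(x))
print(cl:forward(x))
```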

File tree: 7 files changed (+82 / −14 lines)

Clamp.lua (+5)

```diff
@@ -0,0 +1,5 @@
+local Clamp, Parent = torch.class('nn.Clamp', 'nn.HardTanh')
+
+function Clamp:__init(min_value, max_value)
+   Parent.__init(self, min_value, max_value)
+end
```

HardTanh.lua (+10, −1)

```diff
@@ -1,6 +1,15 @@
-local HardTanh = torch.class('nn.HardTanh', 'nn.Module')
+local HardTanh, parent = torch.class('nn.HardTanh', 'nn.Module')
+
+function HardTanh:__init(min_value, max_value)
+   parent.__init(self)
+   self.min_val = min_value or -1
+   self.max_val = max_value or 1
+   assert(self.max_val > self.min_val, 'max_value must be larger than min_value')
+end
 
 function HardTanh:updateOutput(input)
+   self.min_val = self.min_val or -1
+   self.max_val = self.max_val or 1
    return input.nn.HardTanh_updateOutput(self, input)
 end
 
```
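Two details worth noting: the constructor rejects inverted ranges via the `assert`, and `updateOutput` re-applies the `[-1, 1]` defaults, presumably so that HardTanh instances serialized before this change (which lack `min_val`/`max_val` fields) keep working. A small sketch of the assert's behavior:

```lua
-- Valid: max_value is larger than min_value.
local ok = nn.HardTanh(-2, 2)

-- Invalid: fails the constructor's assert.
-- nn.HardTanh(2, -2)  --> error: 'max_value must be larger than min_value'
```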

doc/simple.md (+27)

````diff
@@ -29,6 +29,7 @@ Simple Modules are used for various tasks like adapting Tensor methods and provi
 * [Power](#nn.Power) : an element-wise [pow](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchpowres-x) operation ;
 * [Square](#nn.Square) : an element-wise square operation ;
 * [Sqrt](#nn.Sqrt) : an element-wise [sqrt](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchsqrtres-x) operation ;
+* [Clamp](#nn.Clamp) : an element-wise [clamp](https://github.com/torch/torch7/blob/master/doc/maths.md#res-torchclampres-tensor1-min_value-max_value) operation ;
 * [Normalize](#nn.Normalize) : normalizes the input to have unit `L_p` norm ;
 * [MM](#nn.MM) : matrix-matrix multiplication (also supports batches of matrices) ;
 * Miscellaneous Modules :
@@ -905,6 +906,32 @@ gnuplot.grid(true)
 
 ![](image/power.png)
 
+<a name="nn.Clamp"></a>
+## Clamp ##
+
+```lua
+module = nn.Clamp(min_value, max_value)
+```
+
+Clamps all elements into the range `[min_value, max_value]`.
+Output is identical to the input inside the range; elements less than `min_value` (or greater than `max_value`) are saturated to `min_value` (or `max_value`).
+
+```lua
+A = torch.randn(2, 5)
+m = nn.Clamp(-0.1, 0.5)
+B = m:forward(A)
+
+print(A)  -- input
+-1.1321  0.0227 -0.4672  0.6519 -0.5380
+ 0.9061 -1.0858  0.3697 -0.8120 -1.6759
+[torch.DoubleTensor of size 2x5]
+
+print(B)  -- output
+-0.1000  0.0227 -0.1000  0.5000 -0.1000
+ 0.5000 -0.1000  0.3697 -0.1000 -0.1000
+[torch.DoubleTensor of size 2x5]
+```
+
 <a name="nn.Normalize"></a>
 ## Normalize ##
 
````
doc/transfer.md (+4)

````diff
@@ -14,6 +14,10 @@ thus outputting a Tensor of the same dimension.
 * `f(x)` = `-1, if x <` `-1,`
 * `f(x)` = `x,` `otherwise.`
 
+The range of the linear region `[-1 1]` can be adjusted by specifying arguments in the declaration, e.g. `nn.HardTanh(min_value, max_value)`.
+Otherwise, `[min_value max_value]` is set to `[-1 1]` by default.
+
+
 ```lua
 ii=torch.linspace(-2,2)
 m=nn.HardTanh()
````
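For instance, with a widened linear region, inputs inside `[-2, 2]` should pass through unchanged while values outside saturate (a hypothetical session):

```lua
local m = nn.HardTanh(-2, 2)
print(m:forward(torch.Tensor{-3, -1.5, 0, 1.5, 3}))
-- expected: -2.0  -1.5  0.0  1.5  2.0
```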

generic/HardTanh.c (+17, −13)

```diff
@@ -5,19 +5,21 @@
 static int nn_(HardTanh_updateOutput)(lua_State *L)
 {
   THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  real min_val = luaT_getfieldchecknumber(L, 1, "min_val");
+  real max_val = luaT_getfieldchecknumber(L, 1, "max_val");
   THTensor *output = luaT_getfieldcheckudata(L, 1, "output", torch_Tensor);
 
   THTensor_(resizeAs)(output, input);
 
   if (input->nDimension == 1 || !THTensor_(isContiguous)(input) || !THTensor_(isContiguous)(output))
   {
     TH_TENSOR_APPLY2(real, output, real, input, \
-      if(*input_data < -1) \
-        *output_data = -1; \
-      else if(*input_data <= 1) \
+      if(*input_data < min_val) \
+        *output_data = min_val; \
+      else if(*input_data <= max_val) \
         *output_data = *input_data; \
       else \
-        *output_data = 1;);
+        *output_data = max_val;);
   }
   else
   {
@@ -28,12 +30,12 @@ static int nn_(HardTanh_updateOutput)(lua_State *L)
 #pragma omp parallel for private(i)
     for (i = 0; i < THTensor_(nElement)(input); i++)
     {
-      if(ptr_input[i] < -1)
-        ptr_output[i] = -1;
-      else if (ptr_input[i] <= 1)
-        ptr_output[i] = ptr_input[i];
+      if(ptr_input[i] < min_val)
+        ptr_output[i] = min_val;
+      else if (ptr_input[i] <= max_val)
+        ptr_output[i] = ptr_input[i];
       else
-        ptr_output[i] = 1;
+        ptr_output[i] = max_val;
     }
   }
   return 1;
@@ -42,6 +44,8 @@ static int nn_(HardTanh_updateOutput)(lua_State *L)
 static int nn_(HardTanh_updateGradInput)(lua_State *L)
 {
   THTensor *input = luaT_checkudata(L, 2, torch_Tensor);
+  real min_val = luaT_getfieldchecknumber(L, 1, "min_val");
+  real max_val = luaT_getfieldchecknumber(L, 1, "max_val");
   THTensor *gradOutput = luaT_checkudata(L, 3, torch_Tensor);
   THTensor *gradInput = luaT_getfieldcheckudata(L, 1, "gradInput", torch_Tensor);
 
@@ -53,7 +57,7 @@ static int nn_(HardTanh_updateGradInput)(lua_State *L)
     !THTensor_(isContiguous)(gradInput))
   {
     TH_TENSOR_APPLY3(real, gradInput, real, gradOutput, real, input, \
-      if(*input_data < -1 || *input_data > 1) \
+      if(*input_data < min_val || *input_data > max_val) \
         *gradInput_data = 0; \
       else \
         *gradInput_data = *gradOutput_data;);
@@ -68,10 +72,10 @@ static int nn_(HardTanh_updateGradInput)(lua_State *L)
 #pragma omp parallel for private(i)
     for (i = 0; i < THTensor_(nElement)(input); i++)
     {
-      if(ptr_input[i] < -1 || ptr_input[i] > 1)
-        ptr_gradInput[i] = 0;
+      if(ptr_input[i] < min_val || ptr_input[i] > max_val)
+        ptr_gradInput[i] = 0;
       else
-        ptr_gradInput[i] = ptr_gradOutput[i];
+        ptr_gradInput[i] = ptr_gradOutput[i];
     }
   }
   return 1;
```
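Both the forward and backward kernels apply the same element-wise rule. For reference, here is a rough Lua equivalent of the math implemented above (illustrative only, not the actual code path, which stays in C):

```lua
-- Forward: clamp each element into [min_val, max_val].
local function hardTanhForward(input, min_val, max_val)
   return input:clone():clamp(min_val, max_val)
end

-- Backward: the gradient passes through only where the input lies
-- inside the linear region; it is zero where the output saturates.
local function hardTanhBackward(input, gradOutput, min_val, max_val)
   local mask = input:ge(min_val):cmul(input:le(max_val))
   return gradOutput:clone():cmul(mask:typeAs(gradOutput))
end
```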

init.lua (+1)

```diff
@@ -52,6 +52,7 @@ include('Normalize.lua')
 include('Exp.lua')
 include('Log.lua')
 include('HardTanh.lua')
+include('Clamp.lua')
 include('LogSigmoid.lua')
 include('LogSoftMax.lua')
 include('Sigmoid.lua')
```

test.lua (+18)

```diff
@@ -266,6 +266,24 @@ function nntest.HardTanh()
    mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
 end
 
+function nntest.Clamp()
+   local ini = math.random(3,5)
+   local inj = math.random(3,5)
+   local ink = math.random(3,5)
+   local max_value = math.abs(math.random())
+   local min_value = -math.abs(math.random())
+   local input = torch.Tensor(ink, inj, ini):zero()
+
+   local module = nn.Clamp(min_value, max_value)
+
+   local err = jac.testJacobian(module, input)
+   mytester:assertlt(err, precision, 'error on state ')
+
+   local ferr, berr = jac.testIO(module, input)
+   mytester:asserteq(ferr, 0, torch.typename(module) .. ' - i/o forward err ')
+   mytester:asserteq(berr, 0, torch.typename(module) .. ' - i/o backward err ')
+end
+
 function nntest.Abs()
    local ini = math.random(3,5)
    local inj = math.random(3,5)
```
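The new test mirrors the existing `nntest.HardTanh` above it: as elsewhere in this suite, `jac.testJacobian` compares the module's analytic Jacobian against a finite-difference estimate, and `jac.testIO` checks that the module produces the same outputs and gradients after a serialization round trip.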
