Skip to content

Commit 58aeea1

Browse files
jasonkuen
authored and committed
Channel-Wise RReLU
1 parent 316a0cf commit 58aeea1

File tree

2 files changed

+52
-43
lines changed

2 files changed

+52
-43
lines changed

lib/THNN/generic/RReLU.c

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ void THNN_(RReLU_updateOutput)(
1717
if (channelwise && train)
1818
{
1919
long bs, ks;
20-
THIndex_t nOutputPlane;
20+
long nOutputPlane;
2121
{
2222
long input_ndim = THTensor_(nDimension)(input);
2323
switch (input_ndim)
@@ -39,7 +39,7 @@ void THNN_(RReLU_updateOutput)(
3939
ks = input->size[2] * input->size[3];
4040
break;
4141
}
42-
nOutputPlane = input->size[(input_ndim + 1) % 2]
42+
nOutputPlane = input->size[(input_ndim + 1) % 2];
4343
}
4444
// get default random generator
4545
if (inplace)
@@ -51,14 +51,17 @@ void THNN_(RReLU_updateOutput)(
5151
real *input_data = THTensor_(data)(input);
5252
real *noise_data = THTensor_(data)(noise);
5353
if (!inplace)
54+
{
55+
THTensor_(resizeAs)(output, input);
5456
output_data = THTensor_(data)(output);
57+
}
5558
THTensor *channel_noise = THTensor_(newWithSize1d)(nOutputPlane);
5659
real *channel_noise_data = THTensor_(data)(channel_noise);
5760

5861
THIndex_t i, j, k;
5962
#pragma omp parallel for private(j)
6063
for (j = 0; j < nOutputPlane; ++j)
61-
channel_noise_data[j] = (real)THRandom_uniform(generator, lower, upper)
64+
channel_noise_data[j] = (real)THRandom_uniform(generator, lower, upper);
6265
#pragma omp parallel for private(j,k)
6366
for (i = 0; i < bs; ++i)
6467
{
@@ -72,17 +75,22 @@ void THNN_(RReLU_updateOutput)(
7275
for (j = 0; j < nOutputPlane; ++j)
7376
{
7477
const real r = channel_noise_data[j];
75-
for (k = 0; k < ks; ++k)
76-
if (inplace)
77-
if n_input_data[k] <= 0
78-
{
79-
n_input_data[k] = r * n_input_data[k];
80-
n_noise_data[k] = r;
81-
}
82-
else
83-
n_noise_data[k] = 1;
78+
for (k = 0; k < ks; ++k)
79+
if (inplace)
80+
if (n_input_data[k] <= 0)
81+
{
82+
n_input_data[k] = r * n_input_data[k];
83+
n_noise_data[k] = r;
84+
}
8485
else
85-
n_output_data[k] = (n_input_data[k] > 0) ? n_input_data[k] : r * n_input_data[k];
86+
n_noise_data[k] = 1;
87+
else
88+
n_output_data[k] = (n_input_data[k] > 0) ? n_input_data[k] : r * n_input_data[k];
89+
n_input_data += ks;
90+
if (inplace)
91+
n_noise_data += ks;
92+
else
93+
n_output_data += ks;
8694
}
8795
}
8896
if (inplace)
@@ -172,7 +180,7 @@ void THNN_(RReLU_updateGradInput)(
172180
if (channelwise && !inplace)
173181
{
174182
long bs, ks;
175-
THIndex_t nOutputPlane;
183+
long nOutputPlane;
176184
{
177185
long input_ndim = THTensor_(nDimension)(input);
178186
switch (input_ndim)
@@ -194,12 +202,12 @@ void THNN_(RReLU_updateGradInput)(
194202
ks = input->size[2] * input->size[3];
195203
break;
196204
}
197-
nOutputPlane = input->size[(input_ndim + 1) % 2]
205+
nOutputPlane = input->size[(input_ndim + 1) % 2];
198206
}
199207

200-
const real *output_data = output_data = THTensor_(data)(output);
201208
const real *input_data = THTensor_(data)(input);
202209
const real *gradOutput_data = THTensor_(data)(gradOutput);
210+
THTensor_(resizeAs)(gradInput, input);
203211
real *gradInput_data = THTensor_(data)(gradInput);
204212
const real *noise_data = THTensor_(data)(noise);
205213

@@ -215,12 +223,10 @@ void THNN_(RReLU_updateGradInput)(
215223
{
216224
const real r = noise_data[j];
217225
for (k = 0; k < ks; ++k)
218-
{
219226
if (n_input_data[k] > 0)
220227
n_gradInput_data[k] = n_gradOutput_data[k];
221228
else
222229
n_gradInput_data[k] = n_gradOutput_data[k] * r;
223-
}
224230
n_input_data += ks;
225231
n_gradInput_data += ks;
226232
n_gradOutput_data += ks;

test.lua

Lines changed: 28 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -484,32 +484,35 @@ function nntest.RReLU()
484484
for _,train in ipairs({true,false}) do
485485
-- test with separate output buffer and inplace
486486
for _,inplace in ipairs({false,true}) do
487-
module = nn.RReLU(l, u, inplace)
488-
if train then
489-
module:training()
490-
else
491-
module:evaluate()
492-
end
493-
input = torch.rand(nframe, size, kW, kH) - 0.5
494-
input:storage()[1] = -1
495-
local original_input = input:clone()
496-
local output = module:forward(input)
497-
mytester:assert(output:sign():eq(original_input:sign()):all(), 'sign flipped forward ')
498-
local gradOutput = torch.ones(output:size())
499-
local gradInput = module:backward(input, gradOutput)
500-
mytester:assert(gradInput:gt(0):eq(input:ne(0)):all(), 'gradient ')
501-
mytester:assert(gradInput:lt(1):eq(input:le(0)):all(), 'backward negative inputs ')
502-
mytester:assert(gradInput:eq(1):eq(input:gt(0)):all(), 'backward positive inputs ')
503-
if not train then
504-
local err = gradInput[input:le(0)]:mean()-(module.lower+module.upper)/2
505-
mytester:assertlt(err, precision, 'error on gradient ')
506-
end
487+
-- test with channel-wise
488+
for _,cw in ipairs({true,false}) do
489+
module = nn.RReLU(l, u, inplace, cw)
490+
if train then
491+
module:training()
492+
else
493+
module:evaluate()
494+
end
495+
input = torch.rand(nframe, size, kW, kH) - 0.5
496+
input:storage()[1] = -1
497+
local original_input = input:clone()
498+
local output = module:forward(input)
499+
mytester:assert(output:sign():eq(original_input:sign()):all(), 'sign flipped forward ')
500+
local gradOutput = torch.ones(output:size())
501+
local gradInput = module:backward(input, gradOutput)
502+
mytester:assert(gradInput:gt(0):eq(input:ne(0)):all(), 'gradient ')
503+
mytester:assert(gradInput:lt(1):eq(input:le(0)):all(), 'backward negative inputs ')
504+
mytester:assert(gradInput:eq(1):eq(input:gt(0)):all(), 'backward positive inputs ')
505+
if not train then
506+
local err = gradInput[input:le(0)]:mean()-(module.lower+module.upper)/2
507+
mytester:assertlt(err, precision, 'error on gradient ')
508+
end
507509

508-
input = -torch.rand(1000)
509-
module:forward(input) -- fill internal noise tensor
510-
local g = module:backward(input, torch.ones(1000))
511-
local err = math.abs(g[input:le(0)]:mean()-(module.lower+module.upper)/2)
512-
mytester:assertlt(err, 0.05, 'mean deviation of gradient for negative inputs ')
510+
input = -torch.rand(1000)
511+
module:forward(input) -- fill internal noise tensor
512+
local g = module:backward(input, torch.ones(1000))
513+
local err = math.abs(g[input:le(0)]:mean()-(module.lower+module.upper)/2)
514+
mytester:assertlt(err, 0.05, 'mean deviation of gradient for negative inputs ')
515+
end
513516
end
514517
end
515518
end

0 commit comments

Comments
 (0)