@@ -484,32 +484,35 @@ function nntest.RReLU()
    for _,train in ipairs({true,false}) do
       -- test with separate output buffer and inplace
       for _,inplace in ipairs({false,true}) do
-         module = nn.RReLU(l, u, inplace)
-         if train then
-            module:training()
-         else
-            module:evaluate()
-         end
-         input = torch.rand(nframe, size, kW, kH) - 0.5
-         input:storage()[1] = -1
-         local original_input = input:clone()
-         local output = module:forward(input)
-         mytester:assert(output:sign():eq(original_input:sign()):all(), 'sign flipped forward')
-         local gradOutput = torch.ones(output:size())
-         local gradInput = module:backward(input, gradOutput)
-         mytester:assert(gradInput:gt(0):eq(input:ne(0)):all(), 'gradient')
-         mytester:assert(gradInput:lt(1):eq(input:le(0)):all(), 'backward negative inputs')
-         mytester:assert(gradInput:eq(1):eq(input:gt(0)):all(), 'backward positive inputs')
-         if not train then
-            local err = gradInput[input:le(0)]:mean() - (module.lower + module.upper)/2
-            mytester:assertlt(err, precision, 'error on gradient')
-         end
+         -- test with channel-wise
+         for _,cw in ipairs({true,false}) do
+            module = nn.RReLU(l, u, inplace, cw)
+            if train then
+               module:training()
+            else
+               module:evaluate()
+            end
+            input = torch.rand(nframe, size, kW, kH) - 0.5
+            input:storage()[1] = -1
+            local original_input = input:clone()
+            local output = module:forward(input)
+            mytester:assert(output:sign():eq(original_input:sign()):all(), 'sign flipped forward')
+            local gradOutput = torch.ones(output:size())
+            local gradInput = module:backward(input, gradOutput)
+            mytester:assert(gradInput:gt(0):eq(input:ne(0)):all(), 'gradient')
+            mytester:assert(gradInput:lt(1):eq(input:le(0)):all(), 'backward negative inputs')
+            mytester:assert(gradInput:eq(1):eq(input:gt(0)):all(), 'backward positive inputs')
+            if not train then
+               local err = gradInput[input:le(0)]:mean() - (module.lower + module.upper)/2
+               mytester:assertlt(err, precision, 'error on gradient')
+            end

-         input = -torch.rand(1000)
-         module:forward(input) -- fill internal noise tensor
-         local g = module:backward(input, torch.ones(1000))
-         local err = math.abs(g[input:le(0)]:mean() - (module.lower + module.upper)/2)
-         mytester:assertlt(err, 0.05, 'mean deviation of gradient for negative inputs')
+            input = -torch.rand(1000)
+            module:forward(input) -- fill internal noise tensor
+            local g = module:backward(input, torch.ones(1000))
+            local err = math.abs(g[input:le(0)]:mean() - (module.lower + module.upper)/2)
+            mytester:assertlt(err, 0.05, 'mean deviation of gradient for negative inputs')
+         end
       end
    end
 end
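The new `cw` loop reruns every existing assertion with the fourth constructor argument toggled. As a rough illustration of what that flag exercises, here is a minimal standalone sketch; the bounds `l` and `u` are made-up values, and the channel-wise semantics (one sampled slope per channel rather than per element) are an assumption inferred from the loop variable name, not stated in the patch:

require 'nn'

-- Hypothetical slope bounds; the test file defines its own l and u.
local l, u = 1/8, 1/3

-- Fourth argument: the channel-wise flag this patch adds (assumed to
-- sample one negative slope per channel instead of per element).
local rrelu = nn.RReLU(l, u, false, true)

-- In evaluate() mode RReLU applies the deterministic slope (l + u)/2
-- to negative inputs, the same value the 'error on gradient' and
-- 'mean deviation' assertions above compare against.
rrelu:evaluate()
local x = torch.Tensor(2, 4, 3, 3):fill(-1)  -- all-negative 4D input
local y = rrelu:forward(x)
print(y[1][1][1][1], -(l + u)/2)             -- both print about -0.22917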