linear_block.py
import chainer
from chainer import functions
from chainer import links


class LinearBlock(chainer.Chain):
    """Linear layer followed by optional batch normalization,
    activation and dropout.

    Args:
        in_size (int or None): Dimension of input vectors. If None,
            weight initialization is deferred to the first forward pass.
        out_size (int or None): Dimension of output vectors. If None,
            ``in_size`` is treated as the output dimension.
        nobias (bool): If True, the linear layer has no bias term.
        initialW: Initializer for the weight matrix.
        initial_bias: Initializer for the bias vector.
        use_bn (bool): If True, apply batch normalization after the
            linear layer.
        activation: Activation function, or None to disable.
        dropout_ratio (float): Dropout ratio; a negative value (the
            default) disables dropout.
        residual (bool): Residual connection (not implemented yet).
    """

    def __init__(self, in_size, out_size=None, nobias=False,
                 initialW=None, initial_bias=None, use_bn=True,
                 activation=functions.relu, dropout_ratio=-1, residual=False):
        super(LinearBlock, self).__init__()
        with self.init_scope():
            self.linear = links.Linear(
                in_size, out_size=out_size, nobias=nobias,
                initialW=initialW, initial_bias=initial_bias)
            if use_bn:
                # links.Linear treats in_size as the output dimension
                # when out_size is None; batch normalization must match.
                bn_size = out_size if out_size is not None else in_size
                self.bn = links.BatchNormalization(bn_size)
        self.activation = activation
        self.use_bn = use_bn
        self.dropout_ratio = dropout_ratio
        self.residual = residual

    def __call__(self, x):
        if self.use_bn:
            h = self.bn(self.linear(x))
        else:
            h = self.linear(x)
        if self.activation is not None:
            h = self.activation(h)
        if self.residual:
            raise NotImplementedError('not implemented yet')
        if self.dropout_ratio >= 0:
            h = functions.dropout(h, ratio=self.dropout_ratio)
        return h
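
A minimal usage sketch, not part of the original file: it shows one forward pass through LinearBlock. The input/output sizes, batch size, and dropout ratio below are illustrative assumptions, not values from the repo.

import numpy as np

import chainer
from linear_block import LinearBlock

# Hypothetical dimensions chosen for illustration.
block = LinearBlock(in_size=64, out_size=32, dropout_ratio=0.5)

x = np.random.rand(8, 64).astype(np.float32)  # batch of 8 input vectors

# Batch normalization and dropout behave differently in training and
# test mode; chainer.using_config('train', ...) selects the mode.
with chainer.using_config('train', True):
    y = block(x)
print(y.shape)  # (8, 32)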