# doconv.py -- 383 lines (318 loc), 15.2 KB
# (Web-page chrome and the 1-383 line-number gutter captured with this listing were removed.)
import math
import paddle
import numpy as np
import paddle.nn as nn
import paddle.nn.functional as F
def _calculate_fan_in_and_fan_out(tensor):
dimensions = len(tensor.shape)
if dimensions < 2:
raise ValueError(
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions"
)
num_input_fmaps = tensor.shape[1]
num_output_fmaps = tensor.shape[0]
receptive_field_size = 1
if len(tensor.shape) > 2:
receptive_field_size = paddle.numel(tensor[0][0])
fan_in = num_input_fmaps * receptive_field_size
fan_out = num_output_fmaps * receptive_field_size
return fan_in, fan_out
def _calculate_correct_fan(tensor, mode):
mode = mode.lower()
valid_modes = ['fan_in', 'fan_out']
if mode not in valid_modes:
raise ValueError("Mode {} not supported, please use one of {}".format(
mode, valid_modes))
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
return fan_in if mode == 'fan_in' else fan_out
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    ================= ====================================================

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function; only read for
            ``'leaky_relu'``, where it is the negative slope (0.01 if omitted)

    Returns:
        The gain as a number.

    Raises:
        ValueError: if ``nonlinearity`` is not supported, or if ``param`` is
            not an int/float (booleans are rejected) for ``'leaky_relu'``.
    """
    # NOTE: the docstring is a raw string on purpose. It contains LaTeX
    # backslashes, and in a normal string "\f" / "\t" turn into form-feed /
    # tab while "\s" and "\_" are invalid escape sequences.
    linear_fns = [
        'linear', 'conv1d', 'conv2d', 'conv3d', 'conv_transpose1d',
        'conv_transpose2d', 'conv_transpose3d'
    ]
    if nonlinearity in linear_fns or nonlinearity == 'sigmoid':
        return 1
    elif nonlinearity == 'tanh':
        return 5.0 / 3
    elif nonlinearity == 'relu':
        return math.sqrt(2.0)
    elif nonlinearity == 'leaky_relu':
        if param is None:
            negative_slope = 0.01
        elif (not isinstance(param, bool) and isinstance(param, int)) \
                or isinstance(param, float):
            # True/False are instances of int, hence the explicit bool check
            negative_slope = param
        else:
            raise ValueError(
                "negative_slope {} not a valid number".format(param))
        return math.sqrt(2.0 / (1 + negative_slope**2))
    else:
        raise ValueError("Unsupported nonlinearity {}".format(nonlinearity))
@paddle.no_grad()
def uniform_(x, a=-1., b=1.):
    """Overwrite ``x`` with values drawn by ``paddle.uniform`` and return it.

    Args:
        x: tensor to fill; its ``shape`` determines how many values are drawn
            and it is updated through ``set_value``.
        a: value passed to ``paddle.uniform`` as ``min``.
        b: value passed to ``paddle.uniform`` as ``max``.

    Returns:
        The same ``x`` object that was passed in.
    """
    x.set_value(paddle.uniform(x.shape, min=a, max=b))
    return x
@paddle.no_grad()
def kaiming_normal_(x, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    normal distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        x: an n-dimensional `paddle.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        The same ``x`` object, after its value has been replaced.
    """
    # NOTE: the docstring is raw because of its LaTeX backslashes; as a normal
    # string "\t" / "\f" became tab / form-feed and "\m", "\s", "\_" were
    # invalid escape sequences.
    fan = _calculate_correct_fan(x, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    temp_value = paddle.normal(0, std, shape=x.shape)
    x.set_value(temp_value)
    return x
@paddle.no_grad()
def kaiming_uniform_(x, a=0, mode='fan_in', nonlinearity='leaky_relu'):
    r"""Fills the input `Tensor` with values according to the method
    described in `Delving deep into rectifiers: Surpassing human-level
    performance on ImageNet classification` - He, K. et al. (2015), using a
    uniform distribution. The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    Args:
        x: an n-dimensional `paddle.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).

    Returns:
        The same ``x`` object, after its value has been replaced.
    """
    # NOTE: the docstring is raw because of its LaTeX backslashes; as a normal
    # string "\t" / "\f" became tab / form-feed and "\m", "\s", "\_" were
    # invalid escape sequences.
    fan = _calculate_correct_fan(x, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    # Calculate uniform bounds from standard deviation
    bound = math.sqrt(3.0) * std
    temp_value = paddle.uniform(x.shape, min=-bound, max=bound)
    x.set_value(temp_value)
    return x
class simam_module(nn.Layer):
    """Rescale a 4-D tensor by a sigmoid gate built from its per-map statistics.

    For every (axis-0, axis-1) slice the squared deviation from the mean over
    the last two axes is normalised by that slice's summed squared deviation
    (divided by ``n = h * w - 1``) plus ``e_lambda``; the result goes through
    a sigmoid and multiplies the input.

    Args:
        e_lambda: constant added to the denominator of the normalisation.
    """

    def __init__(self, e_lambda=1e-4):
        super(simam_module, self).__init__()
        # The attribute name (including its spelling) is kept as-is so that
        # code accessing `activaton` from outside keeps working.
        self.activaton = nn.Sigmoid()
        self.e_lambda = e_lambda

    def forward(self, x):
        """Return ``x`` multiplied element-wise by its gate.

        Args:
            x: tensor whose ``shape`` unpacks into exactly four entries.

        Returns:
            Tensor with the same shape as ``x``.
        """
        _, _, h, w = x.shape
        # NOTE(review): n is 0 when h * w == 1, making the division below
        # divide by zero -- confirm callers never pass 1x1 maps.
        n = w * h - 1
        # Paddle reductions take `axis=`; `dim=` is the PyTorch spelling.
        x_minus_mu_square = (x - x.mean(axis=[2, 3], keepdim=True)).pow(2)
        variance = x_minus_mu_square.sum(axis=[2, 3], keepdim=True) / n
        y = x_minus_mu_square / (4 * (variance + self.e_lambda)) + 0.5
        return x * self.activaton(y)
class DOConv2d(nn.Layer):
    """2-D convolution whose kernel is composed on the fly from ``D`` and ``W``.

    DOConv2d can be used as an alternative for paddle.nn.Conv2D.
    The interface is similar to that of Conv2D, with one exception:
        1. D_mul: the depth multiplier for the over-parameterization.
    Note that the groups parameter switches between DO-Conv (groups=1),
    DO-DConv (groups=in_channels), DO-GConv (otherwise).

    Args:
        in_channels: number of input channels; must be divisible by ``groups``.
        out_channels: number of output channels; must be divisible by ``groups``.
        kernel_size: a single number, used for both kernel height and width.
        D_mul: depth multiplier; ``kernel_size ** 2`` is used when it is
            ``None`` or when the kernel is 1x1.
        stride: a single number, used for both spatial axes.
        padding: a single number, used for both spatial axes.
        dilation: a single number, used for both spatial axes.
        groups: number of convolution groups.
        bias: when true, a bias parameter with ``out_channels`` entries is created.
        padding_mode: one of ``'zeros'``, ``'reflect'``, ``'replicate'``,
            ``'circular'``.
        simam: when true, the first half (along kernel height) of the composed
            kernel is passed through ``simam_module`` before the convolution.

    Raises:
        ValueError: if a channel count is not divisible by ``groups`` or
            ``padding_mode`` is not one of the supported names.
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']

    def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
        super(DOConv2d, self).__init__()

        # Each scalar setting is duplicated into a (height, width) pair.
        kernel_size = (kernel_size, kernel_size)
        stride = (stride, stride)
        padding = (padding, padding)
        dilation = (dilation, dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        # Explicit padding amounts handed to F.pad for non-'zeros' modes.
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))
        self.simam = simam

        # Initialization of D & W
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        self.D_mul = M * N if D_mul is None or M * N <= 1 else D_mul
        self.W = paddle.create_parameter(
            shape=[out_channels, in_channels // groups, self.D_mul], dtype='float32')
        kaiming_normal_(self.W, a=math.sqrt(5))

        if M * N > 1:
            # D starts at zero. The zero fill goes through the initializer
            # instead of assigning to `.data` after creation, so it does not
            # depend on how Paddle treats that attribute.
            self.D = paddle.create_parameter(
                shape=[in_channels, M * N, self.D_mul], dtype='float32',
                default_initializer=nn.initializer.Constant(0.0))

            # D_diag: an (M*N x M*N) identity tiled over the channels and as
            # many whole times as fit along D_mul, zero-padded to D_mul
            # columns when D_mul is not a multiple of M*N. With D == 0 and
            # D_mul == M*N the composed kernel therefore equals W.
            identity = np.eye(M * N, dtype=np.float32)[np.newaxis]
            diag = np.tile(identity, (in_channels, 1, self.D_mul // (M * N)))
            remainder = self.D_mul % (M * N)
            if remainder != 0:
                filler = np.zeros([in_channels, M * N, remainder], dtype=np.float32)
                diag = np.concatenate([diag, filler], axis=2)
            # create_parameter needs an explicit shape and dtype; the values
            # are supplied through the Assign initializer in both cases.
            self.D_diag = paddle.create_parameter(
                shape=list(diag.shape), dtype='float32',
                default_initializer=nn.initializer.Assign(diag))
            # D_diag is a fixed offset for the learnable D, so it is kept out
            # of gradient updates. NOTE(review): this freezes a tensor that was
            # previously trainable when D_mul was a multiple of M*N -- confirm.
            self.D_diag.stop_gradient = True

        if simam:
            self.simam_block = simam_module()
        if bias:
            self.bias = paddle.create_parameter(shape=[out_channels], dtype='float32')
            fan_in, _ = _calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            uniform_(self.bias, -bound, bound)
        else:
            self.bias = None

    def extra_repr(self):
        """Return a summary string listing the settings that differ from defaults."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore state, defaulting ``padding_mode`` to ``'zeros'`` if absent."""
        super(DOConv2d, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight`` using this layer's settings."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, after
            # which the convolution itself runs without padding.
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Compose the kernel from ``D``/``W`` and convolve ``input`` with it."""
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        DoW_shape = (self.out_channels, self.in_channels // self.groups, M, N)
        if M * N > 1:
            # Compose D (in_channels, M * N, D_mul) with W
            D = self.D + self.D_diag
            W = paddle.reshape(self.W, (self.out_channels // self.groups, self.in_channels, self.D_mul))

            # einsum outputs (out_channels // groups, in_channels, M * N),
            # which is reshaped to (out_channels, in_channels // groups, M, N)
            DoW = paddle.reshape(paddle.einsum('ims,ois->oim', D, W), DoW_shape)
        else:
            DoW = paddle.reshape(self.W, DoW_shape)

        if self.simam:
            # Split along kernel height with explicit section sizes so that an
            # odd height (e.g. 3 -> 2 + 1) works; an equal-parts split cannot
            # divide it. For even heights this is the same halving as before.
            first_rows = (M + 1) // 2
            DoW_h1, DoW_h2 = paddle.split(DoW, [first_rows, M - first_rows], axis=2)
            DoW = paddle.concat([self.simam_block(DoW_h1), DoW_h2], axis=2)

        return self._conv_forward(input, DoW)
class DOConv2d_eval(nn.Layer):
    """Plain 2-D convolution with the same constructor signature as ``DOConv2d``.

    Instead of composing the kernel from two parameters, this layer stores a
    single kernel ``W`` of shape
    ``(out_channels, in_channels // groups, kernel_size, kernel_size)`` and
    convolves the input with it directly.

    Args:
        in_channels: number of input channels; must be divisible by ``groups``.
        out_channels: number of output channels; must be divisible by ``groups``.
        kernel_size: a single number, used for both kernel height and width.
        D_mul: accepted for signature compatibility; not used by this layer.
        stride: a single number, used for both spatial axes.
        padding: a single number, used for both spatial axes.
        dilation: a single number, used for both spatial axes.
        groups: number of convolution groups.
        bias: when true, a bias parameter with ``out_channels`` entries is created.
        padding_mode: one of ``'zeros'``, ``'reflect'``, ``'replicate'``,
            ``'circular'``.
        simam: stored on the instance; not used in ``forward``.

    Raises:
        ValueError: if a channel count is not divisible by ``groups`` or
            ``padding_mode`` is not one of the supported names.
    """
    __constants__ = ['stride', 'padding', 'dilation', 'groups',
                     'padding_mode', 'output_padding', 'in_channels',
                     'out_channels', 'kernel_size', 'D_mul']

    def __init__(self, in_channels, out_channels, kernel_size=3, D_mul=None, stride=1,
                 padding=1, dilation=1, groups=1, bias=False, padding_mode='zeros', simam=False):
        super(DOConv2d_eval, self).__init__()

        # Each scalar setting is duplicated into a (height, width) pair.
        kernel_size = (kernel_size, kernel_size)
        stride = (stride, stride)
        padding = (padding, padding)
        dilation = (dilation, dilation)

        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}
        if padding_mode not in valid_padding_modes:
            raise ValueError("padding_mode must be one of {}, but got padding_mode='{}'".format(
                valid_padding_modes, padding_mode))

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.padding_mode = padding_mode
        # Explicit padding amounts handed to F.pad for non-'zeros' modes.
        self._padding_repeated_twice = tuple(x for x in self.padding for _ in range(2))
        self.simam = simam

        # Initialization of W
        M = self.kernel_size[0]
        N = self.kernel_size[1]
        # create_parameter requires an explicit dtype.
        self.W = paddle.create_parameter(
            shape=[out_channels, in_channels // groups, M, N], dtype='float32')
        kaiming_uniform_(self.W, a=math.sqrt(5))

        # `bias` must always exist as an attribute: extra_repr and
        # _conv_forward both read it. It is initialised the same way as in
        # DOConv2d when requested.
        if bias:
            self.bias = paddle.create_parameter(shape=[out_channels], dtype='float32')
            fan_in, _ = _calculate_fan_in_and_fan_out(self.W)
            bound = 1 / math.sqrt(fan_in)
            uniform_(self.bias, -bound, bound)
        else:
            self.bias = None

    def extra_repr(self):
        """Return a summary string listing the settings that differ from defaults."""
        s = ('{in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        if self.padding_mode != 'zeros':
            s += ', padding_mode={padding_mode}'
        return s.format(**self.__dict__)

    def __setstate__(self, state):
        """Restore state, defaulting ``padding_mode`` to ``'zeros'`` if absent."""
        # super() must name this class: it is not a subclass of DOConv2d, so
        # super(DOConv2d, self) would raise TypeError.
        super(DOConv2d_eval, self).__setstate__(state)
        if not hasattr(self, 'padding_mode'):
            self.padding_mode = 'zeros'

    def _conv_forward(self, input, weight):
        """Convolve ``input`` with ``weight`` using this layer's settings."""
        if self.padding_mode != 'zeros':
            # Non-zero padding modes are applied explicitly with F.pad, after
            # which the convolution itself runs without padding.
            return F.conv2d(F.pad(input, self._padding_repeated_twice, mode=self.padding_mode),
                            weight, self.bias, self.stride,
                            (0, 0), self.dilation, self.groups)
        return F.conv2d(input, weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

    def forward(self, input):
        """Convolve ``input`` with the stored kernel ``W``."""
        return self._conv_forward(input, self.W)
if __name__ == "__main__":
    # Smoke test: push one random batch through a DOConv2d layer and show the
    # shapes on both sides of it.
    img = paddle.randn(shape=[2, 3, 63, 63])
    net = DOConv2d(in_channels=3, out_channels=3, kernel_size=3, padding=1, stride=1, bias=True, groups=1)
    out = net(img)
    print(img.shape)
    # The result was previously computed but never inspected.
    print(out.shape)