-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathlayers.py
103 lines (65 loc) · 2.98 KB
/
layers.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import math
import random
import torch
import torch.nn as nn
import torch.nn.functional as F
def conv1x1(in_channels, out_channels, stride=1, groups=1):
    """Return a bias-free pointwise (1x1) ``nn.Conv2d``.

    Thin wrapper over ``convnxn`` with ``kernel_size=1`` and no padding.
    """
    return convnxn(in_channels, out_channels, 1, stride=stride, groups=groups)
def conv3x3(in_channels, out_channels, stride=1, groups=1):
    """Return a bias-free 3x3 ``nn.Conv2d`` with one pixel of padding.

    With ``stride=1`` the spatial size is preserved; with ``stride=2`` each
    spatial dimension becomes ``ceil(size / 2)``.
    """
    return convnxn(in_channels, out_channels, 3, stride=stride, groups=groups, padding=1)
def convnxn(in_channels, out_channels, kernel_size, stride=1, groups=1, padding=0):
    """Build an ``nn.Conv2d`` without a bias term.

    Args:
        in_channels: input channel count.
        out_channels: output channel count.
        kernel_size: kernel size (int or pair), forwarded to ``nn.Conv2d``.
        stride: convolution stride.
        groups: number of channel groups.
        padding: implicit zero padding on each side.

    The bias is omitted, presumably because a normalization layer follows
    (TODO confirm against callers).
    """
    conv_kwargs = dict(
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        groups=groups,
        bias=False,
    )
    return nn.Conv2d(in_channels, out_channels, **conv_kwargs)
def relu():
    """Return a new, non-inplace ``nn.ReLU`` activation module."""
    return nn.ReLU(inplace=False)
def bn(channels):
    """Return an ``nn.BatchNorm2d`` over ``channels`` feature maps (PyTorch defaults)."""
    return nn.BatchNorm2d(num_features=channels)
def dense(in_features, out_features):
    """Return a fully connected ``nn.Linear`` layer that includes a bias."""
    return nn.Linear(in_features=in_features, out_features=out_features, bias=True)
def blur(in_filters, sfilter=(1, 1), pad_mode="constant"):
    """Return a blur layer for ``in_filters`` channels.

    A ``(1, 1)`` box filter with zero padding (``pad_mode`` "constant" or
    "zero") is served by a parameter-free ``nn.AvgPool2d``. Every other
    configuration is delegated to the ``Blur`` module.

    NOTE(review): the AvgPool2d fast path (kernel 2, stride 1, padding 1)
    returns maps one pixel larger in H and W than its input, whereas ``Blur``
    with ``sfilter=(1, 1)`` keeps the input size. Confirm callers expect this
    before relying on the two paths being interchangeable.
    """
    is_zero_padded_box = tuple(sfilter) == (1, 1) and pad_mode in ("constant", "zero")
    if not is_zero_padded_box:
        return Blur(in_filters, sfilter=sfilter, pad_mode=pad_mode)
    return nn.AvgPool2d(kernel_size=2, stride=1, padding=1)
class SamePad(nn.Module):
    """Pad the last two dimensions so a stride-1 convolution keeps their size.

    For a filter of ``filter_size`` taps, ``filter_size - 1`` pixels are added
    along each of the last two dimensions. When that total is odd (even
    ``filter_size``), the extra pixel goes to the right/bottom side.

    Args:
        filter_size: number of taps of the filter that follows, applied
            along both of the last two dimensions.
        pad_mode: a mode accepted by ``F.pad`` ("constant", "reflect",
            "replicate", "circular"), or "zero" as an alias for "constant".
        **kwargs: accepted and ignored.
    """

    def __init__(self, filter_size, pad_mode="constant", **kwargs):
        super(SamePad, self).__init__()
        half = (filter_size - 1) / 2.0
        before, after = int(half), int(math.ceil(half))
        # F.pad lists pads starting from the last dimension:
        # (left, right, top, bottom).
        self.pad_size = [before, after, before, after]
        # blur() accepts "zero" for zero padding, but F.pad does not know that
        # name. Map it to the equivalent "constant" mode (fill value 0).
        self.pad_mode = "constant" if pad_mode == "zero" else pad_mode

    def forward(self, x):
        return F.pad(x, self.pad_size, mode=self.pad_mode)

    def extra_repr(self):
        return "pad_size=%s, pad_mode=%s" % (self.pad_size, self.pad_mode)
class Blur(nn.Module):
    """Per-channel smoothing with a normalised separable kernel.

    The 2-D kernel is the outer product of ``sfilter`` with itself, divided by
    its sum, and replicated once per channel. ``forward`` pads the input with
    ``SamePad``, then runs a grouped convolution with one group per input
    channel, so each channel is filtered independently.

    Args:
        in_filters: channel count the kernel is replicated for. The input's
            channel dimension is expected to match.
        sfilter: 1-D filter taps.
        pad_mode: padding mode handed to ``SamePad``.
        **kwargs: accepted and ignored.
    """

    def __init__(self, in_filters, sfilter=(1, 1), pad_mode="replicate", **kwargs):
        super(Blur, self).__init__()
        taps = len(sfilter)
        self.pad = SamePad(taps, pad_mode=pad_mode)
        self.filter_proto = torch.tensor(sfilter, dtype=torch.float, requires_grad=False)

        proto = self.filter_proto
        kernel = torch.einsum("i, j -> i j", proto, proto)
        kernel = kernel / torch.sum(kernel)
        # (k, k) -> (in_filters, 1, k, k): one k x k filter per channel, the
        # weight layout F.conv2d expects when groups == channel count.
        kernel = kernel.repeat([in_filters, 1, 1, 1])
        # Frozen Parameter: no gradient is computed for it, but it moves with
        # the module (.to / .cuda) and is saved in state_dict as "filter".
        self.filter = torch.nn.Parameter(kernel, requires_grad=False)

    def forward(self, x):
        padded = self.pad(x)
        return F.conv2d(padded, self.filter, groups=padded.size()[1])

    def extra_repr(self):
        proto_values = self.filter_proto.tolist()
        return "pad=%s, filter_proto=%s" % (self.pad, proto_values)
class Downsample(nn.Module):
    """Shrink the last two dimensions by ``strides`` via nearest resampling.

    The output spatial size is ``ceil(H / strides[0]) x ceil(W / strides[1])``
    for an ``(N, C, H, W)`` input.

    NOTE(review): ``F.interpolate(mode='nearest')`` with an explicit ``size``
    picks source index ``floor(i * H / out)``. That equals strided slicing
    ``x[..., ::s, ::s]`` only when H (or W) is divisible by the stride; for
    example, H=7 with s=2 picks rows 0, 1, 3, 5. Confirm this is intended.

    Args:
        strides: an int applied to both dimensions, or a (row, col) pair.
        **kwargs: accepted and ignored.
    """

    def __init__(self, strides=(2, 2), **kwargs):
        super(Downsample, self).__init__()
        self.strides = (strides, strides) if isinstance(strides, int) else strides

    @staticmethod
    def _ceil_div(numerator, denominator):
        # Integer ceiling division without going through floats.
        return -(-numerator // denominator)

    def forward(self, x):
        height, width = x.size()[2], x.size()[3]
        target = (
            self._ceil_div(height, self.strides[0]),
            self._ceil_div(width, self.strides[1]),
        )
        return F.interpolate(x, size=target, mode='nearest')

    def extra_repr(self):
        return "strides=%s" % repr(self.strides)