forked from braindotai/Watermark-Removal-Pytorch

modules.py · 63 lines (48 loc) · 2.12 KB
import torch
from torch import nn
import numpy as np


class DepthwiseSeperableConv2d(nn.Module):
    """Depthwise-separable convolution: a per-channel (depthwise) convolution
    followed by a 1x1 (pointwise) convolution that mixes channels."""
    def __init__(self, input_channels, output_channels, **kwargs):
        super(DepthwiseSeperableConv2d, self).__init__()

        # groups=input_channels makes each filter act on a single input channel
        self.depthwise = nn.Conv2d(input_channels, input_channels, groups = input_channels, **kwargs)
        # 1x1 convolution combines the per-channel outputs into output_channels maps
        self.pointwise = nn.Conv2d(input_channels, output_channels, kernel_size = 1)

    def forward(self, x):
        x = self.depthwise(x)
        x = self.pointwise(x)

        return x


class Conv2dBlock(nn.Module):
    """Reflection-padded depthwise-separable convolution followed by
    batch normalization and LeakyReLU activation."""
    def __init__(self, in_channels, out_channels, kernel_size, stride = 1, bias = False):
        super(Conv2dBlock, self).__init__()

        self.model = nn.Sequential(
            # reflection padding preserves spatial size for odd kernel sizes
            nn.ReflectionPad2d(int((kernel_size - 1) / 2)),
            DepthwiseSeperableConv2d(in_channels, out_channels, kernel_size = kernel_size, stride = stride, padding = 0, bias = bias),
            nn.BatchNorm2d(out_channels),
            nn.LeakyReLU(0.2)
        )

    def forward(self, x):
        return self.model(x)


class Concat(nn.Module):
    """Runs the same input through several child modules and concatenates their
    outputs along `dim`, center-cropping them to the smallest spatial size when
    the outputs disagree in height or width."""
    def __init__(self, dim, *args):
        super(Concat, self).__init__()

        self.dim = dim

        for idx, module in enumerate(args):
            self.add_module(str(idx), module)

    def forward(self, input):
        inputs = []
        for module in self._modules.values():
            inputs.append(module(input))

        # collect the height and width of every branch output
        inputs_shapes2 = [x.shape[2] for x in inputs]
        inputs_shapes3 = [x.shape[3] for x in inputs]

        if np.all(np.array(inputs_shapes2) == min(inputs_shapes2)) and np.all(np.array(inputs_shapes3) == min(inputs_shapes3)):
            # all outputs already share the same spatial size
            inputs_ = inputs
        else:
            # center-crop every output to the smallest height and width
            target_shape2 = min(inputs_shapes2)
            target_shape3 = min(inputs_shapes3)

            inputs_ = []
            for inp in inputs:
                diff2 = (inp.size(2) - target_shape2) // 2
                diff3 = (inp.size(3) - target_shape3) // 2
                inputs_.append(inp[:, :, diff2: diff2 + target_shape2, diff3: diff3 + target_shape3])

        return torch.cat(inputs_, dim = self.dim)

    def __len__(self):
        return len(self._modules)
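

# A minimal usage sketch, not part of the original file: it builds a Conv2dBlock
# and a two-branch Concat and prints their output shapes. The channel counts,
# kernel sizes, and input tensor size below are illustrative assumptions, not
# values taken from the repository.
if __name__ == "__main__":
    x = torch.randn(1, 3, 64, 64)

    block = Conv2dBlock(in_channels = 3, out_channels = 16, kernel_size = 3)
    print(block(x).shape)  # torch.Size([1, 16, 64, 64]); reflection padding keeps H and W

    # Concat feeds the same input to both branches and joins them along the channel dim
    branches = Concat(1, Conv2dBlock(3, 8, kernel_size = 3), Conv2dBlock(3, 8, kernel_size = 5))
    print(branches(x).shape)  # torch.Size([1, 16, 64, 64])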