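"""ResUNet++ segmentation network (res_unet_plus.py).

Combines residual convolution blocks, squeeze-and-excite units on the
encoder, an ASPP bridge and output head, and attention-gated decoder
skips, using the building blocks imported from core.modules.
"""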
import torch
import torch.nn as nn

from core.modules import (
    ResidualConv,
    ASPP,
    AttentionBlock,
    Upsample_,
    Squeeze_Excite_Block,
)

class ResUnetPlusPlus(nn.Module):
    def __init__(self, channel, filters=[32, 64, 128, 256, 512]):
        super(ResUnetPlusPlus, self).__init__()

        # Stem: two 3x3 convs plus a 1-conv projection skip; the two paths
        # are summed in forward() to form the first residual block at full
        # resolution.
        self.input_layer = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
            nn.BatchNorm2d(filters[0]),
            nn.ReLU(),
            nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
        )
        self.input_skip = nn.Sequential(
            nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
        )
        # Encoder: each stage recalibrates channels with squeeze-and-excite,
        # then halves resolution with a stride-2 residual block.
        self.squeeze_excite1 = Squeeze_Excite_Block(filters[0])
        self.residual_conv1 = ResidualConv(filters[0], filters[1], 2, 1)

        self.squeeze_excite2 = Squeeze_Excite_Block(filters[1])
        self.residual_conv2 = ResidualConv(filters[1], filters[2], 2, 1)

        self.squeeze_excite3 = Squeeze_Excite_Block(filters[2])
        self.residual_conv3 = ResidualConv(filters[2], filters[3], 2, 1)

        # Bridge: Atrous Spatial Pyramid Pooling at the coarsest resolution.
        self.aspp_bridge = ASPP(filters[3], filters[4])
        # Decoder: each stage attention-weights the coarse features using
        # the matching encoder skip, upsamples 2x, concatenates with the
        # skip, and fuses with a stride-1 residual block.
        self.attn1 = AttentionBlock(filters[2], filters[4], filters[4])
        self.upsample1 = Upsample_(2)
        self.up_residual_conv1 = ResidualConv(filters[4] + filters[2], filters[3], 1, 1)

        self.attn2 = AttentionBlock(filters[1], filters[3], filters[3])
        self.upsample2 = Upsample_(2)
        self.up_residual_conv2 = ResidualConv(filters[3] + filters[1], filters[2], 1, 1)

        self.attn3 = AttentionBlock(filters[0], filters[2], filters[2])
        self.upsample3 = Upsample_(2)
        self.up_residual_conv3 = ResidualConv(filters[2] + filters[0], filters[1], 1, 1)

        # Output head: ASPP refinement, then a 1x1 conv down to a single
        # sigmoid-activated channel (binary segmentation mask).
        self.aspp_out = ASPP(filters[1], filters[0])
        self.output_layer = nn.Sequential(nn.Conv2d(filters[0], 1, 1), nn.Sigmoid())
    def forward(self, x):
        # Encoder: x1 at full resolution, x2..x4 at 1/2, 1/4, 1/8.
        x1 = self.input_layer(x) + self.input_skip(x)

        x2 = self.squeeze_excite1(x1)
        x2 = self.residual_conv1(x2)

        x3 = self.squeeze_excite2(x2)
        x3 = self.residual_conv2(x3)

        x4 = self.squeeze_excite3(x3)
        x4 = self.residual_conv3(x4)

        # Bridge
        x5 = self.aspp_bridge(x4)

        # Decoder stage 1: gate bridge features with skip x3, upsample to
        # x3's resolution, concatenate along channels, fuse.
        x6 = self.attn1(x3, x5)
        x6 = self.upsample1(x6)
        x6 = torch.cat([x6, x3], dim=1)
        x6 = self.up_residual_conv1(x6)

        # Decoder stage 2: same pattern against skip x2.
        x7 = self.attn2(x2, x6)
        x7 = self.upsample2(x7)
        x7 = torch.cat([x7, x2], dim=1)
        x7 = self.up_residual_conv2(x7)

        # Decoder stage 3: back to full resolution against skip x1.
        x8 = self.attn3(x1, x7)
        x8 = self.upsample3(x8)
        x8 = torch.cat([x8, x1], dim=1)
        x8 = self.up_residual_conv3(x8)

        # Output head
        x9 = self.aspp_out(x8)
        out = self.output_layer(x9)

        return out
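

# Minimal usage sketch, assuming core.modules provides the blocks imported
# above; the 3-channel input and 256x256 size are arbitrary examples. Input
# H and W should be divisible by 8 (three stride-2 encoder stages).
if __name__ == "__main__":
    model = ResUnetPlusPlus(channel=3)
    dummy = torch.randn(1, 3, 256, 256)  # one 256x256 RGB image
    with torch.no_grad():
        mask = model(dummy)
    print(mask.shape)  # expected: torch.Size([1, 1, 256, 256])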