import torch
import torch.nn as nn


class CBAM_Module(nn.Module):
    def __init__(self, channels=512, reduction=2):
        super(CBAM_Module, self).__init__()
        # Channel attention: a shared MLP (two 1x1 convs) applied to both the
        # average-pooled and max-pooled descriptors
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1,
                             padding=0)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1,
                             padding=0)
        self.sigmoid_channel = nn.Sigmoid()
        # Spatial attention: 7x7 conv over the concatenated channel-wise avg and max maps
        self.conv_after_concat = nn.Conv2d(2, 1, kernel_size=7, stride=1, padding=3)
        self.sigmoid_spatial = nn.Sigmoid()

    def forward(self, x):
        # Channel Attention module
        module_input = x
        avg = self.avg_pool(x)
        mx = self.max_pool(x)
        avg = self.fc1(avg)
        mx = self.fc1(mx)
        avg = self.relu(avg)
        mx = self.relu(mx)
        avg = self.fc2(avg)
        mx = self.fc2(mx)
        x = avg + mx
        x = self.sigmoid_channel(x)
        # Spatial Attention module
        x = module_input * x + module_input
        module_input = x
        avg = torch.mean(x, 1, True)
        mx, _ = torch.max(x, 1, True)
        x = torch.cat((avg, mx), 1)
        x = self.conv_after_concat(x)
        x = self.sigmoid_spatial(x)
        x = module_input * x + module_input
        return x
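
# Illustrative sketch (assumed input shape): CBAM_Module refines a feature map
# without changing its size, so a [B, channels, H, W] tensor comes back as
# [B, channels, H, W]. The 512-channel, 16x16 input here is only an example.
def _demo_cbam_module():
    cbam = CBAM_Module(channels=512, reduction=2)
    feat = torch.randn(2, 512, 16, 16)
    out = cbam(feat)
    assert out.shape == feat.shape  # attention reweights features, it does not reshape them
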

class SpatialAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(SpatialAttentionBlock, self).__init__()
        # Asymmetric 1x3 / 3x1 convolutions produce the query and key projections
        self.query = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=(1, 3), padding=(0, 1)),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True)
        )
        self.key = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 8, kernel_size=(3, 1), padding=(1, 0)),
            nn.BatchNorm2d(in_channels // 8),
            nn.ReLU(inplace=True)
        )
        self.value = nn.Conv2d(in_channels, in_channels, kernel_size=1)
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input (B x C x H x W)
        :return: affinity-weighted value + x
        """
        B, C, H, W = x.size()
        # flatten the spatial dims: query -> [B, H*W, C//8], key -> [B, C//8, H*W]
        proj_query = self.query(x).view(B, -1, W * H).permute(0, 2, 1)
        proj_key = self.key(x).view(B, -1, W * H)
        affinity = torch.matmul(proj_query, proj_key)
        affinity = self.softmax(affinity)
        proj_value = self.value(x).view(B, -1, H * W)
        weights = torch.matmul(proj_value, affinity.permute(0, 2, 1))
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out
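
# Illustrative sketch (assumed shapes): the spatial block builds an (H*W) x (H*W)
# affinity matrix, so memory grows quadratically with the spatial resolution.
# The 64-channel, 32x32 input is only an example.
def _demo_spatial_attention_block():
    sab = SpatialAttentionBlock(in_channels=64)
    feat = torch.randn(1, 64, 32, 32)
    out = sab(feat)
    assert out.shape == feat.shape  # residual output keeps the input shape
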

class ChannelAttentionBlock(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttentionBlock, self).__init__()
        self.gamma = nn.Parameter(torch.zeros(1))
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x):
        """
        :param x: input (B x C x H x W)
        :return: affinity-weighted value + x
        """
        B, C, H, W = x.size()
        proj_query = x.view(B, C, -1)
        proj_key = x.view(B, C, -1).permute(0, 2, 1)
        affinity = torch.matmul(proj_query, proj_key)
        # subtract the affinity from its row-wise max before the softmax
        affinity_new = torch.max(affinity, -1, keepdim=True)[0].expand_as(affinity) - affinity
        affinity_new = self.softmax(affinity_new)
        proj_value = x.view(B, C, -1)
        weights = torch.matmul(affinity_new, proj_value)
        weights = weights.view(B, C, H, W)
        out = self.gamma * weights + x
        return out
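
# Illustrative sketch (assumed shapes): the channel block builds a C x C affinity
# matrix directly from the flattened input, so its cost depends on the channel
# count rather than the spatial resolution. The sizes below are only an example.
def _demo_channel_attention_block():
    cab = ChannelAttentionBlock(in_channels=64)
    feat = torch.randn(1, 64, 32, 32)
    out = cab(feat)
    assert out.shape == feat.shape  # gamma starts at zero, so out equals feat at initialization
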

class AffinityAttention(nn.Module):
    """ Affinity attention module """
    def __init__(self, in_channels):
        super(AffinityAttention, self).__init__()
        self.sab = SpatialAttentionBlock(in_channels)
        self.cab = ChannelAttentionBlock(in_channels)
        # self.conv1x1 = nn.Conv2d(in_channels * 2, in_channels, kernel_size=1)

    def forward(self, x):
        """
        sab: spatial attention block
        cab: channel attention block
        :param x: input tensor
        :return: sab + cab(sab)
        """
        sab = self.sab(x)
        cab = self.cab(sab)
        out = sab + cab
        return out
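
# Illustrative usage sketch (assumed shapes): AffinityAttention chains the spatial
# and channel blocks (cab is applied to the sab output) and adds the two results.
# The feature shape below is only an example.
if __name__ == "__main__":
    attention = AffinityAttention(in_channels=64)
    feat = torch.randn(1, 64, 32, 32)
    out = attention(feat)
    print(out.shape)  # torch.Size([1, 64, 32, 32])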