import torch
import torch.nn as nn


# Channel attention (Transformer branch): CBAM-style, a shared MLP applied
# to global average- and max-pooled channel descriptors.
class ChannelAttention(nn.Module):
    def __init__(self, in_channels):
        super(ChannelAttention, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)  # global average pooling
        self.max_pool = nn.AdaptiveMaxPool2d(1)  # global max pooling
        # Shared MLP as two 1x1 convolutions with reduction ratio 16.
        self.fc = nn.Sequential(
            nn.Conv2d(in_channels, in_channels // 16, 1, bias=False),
            nn.ReLU(),
            nn.Conv2d(in_channels // 16, in_channels, 1, bias=False),
        )
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        avg_out = self.fc(self.avg_pool(x))
        max_out = self.fc(self.max_pool(x))
        out = avg_out + max_out
        sigmoid_out = self.sigmoid(out)  # channel attention coefficients
        out_CA = x * sigmoid_out
        return out_CA

# Spatial attention (CNN branch): channel-wise average and max maps are
# stacked and passed through a 3x3 then a 1x1 convolution.
class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=3):
        super(SpatialAttention, self).__init__()
        self.conv33 = nn.Conv2d(2, 1, kernel_size=kernel_size,
                                padding=kernel_size // 2, bias=False)  # 3x3 conv
        self.conv11 = nn.Conv2d(1, 1, kernel_size=1, bias=False)  # 1x1 conv
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # Pool along the channel dimension, giving two 1 x H x W maps.
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        x1 = torch.cat([avg_out, max_out], dim=1)
        x2 = self.conv33(x1)
        x3 = self.conv11(x2)
        sigmoid_out = self.sigmoid(x3)  # spatial attention map
        out_SA = x * sigmoid_out
        return out_SA

# Correlation enhancement: project the Transformer feature T to the channel
# width of the CNN feature C, then modulate C by element-wise product.
class CorrelationEnhancement(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(CorrelationEnhancement, self).__init__()
        self.reshape = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, T, C):
        out = self.reshape(T)  # 1x1 projection of T
        out_CE = out * C
        return out_CE

# Feature fusion block: concatenate the three branch outputs, fuse them with
# a 1x1 convolution, then project to out_channels. Assumes out_CE, out_CA and
# out_SA each carry in_channels channels.
class FeatureFusionBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(FeatureFusionBlock, self).__init__()
        self.conv11 = nn.Conv2d(3 * in_channels, in_channels,
                                kernel_size=1, bias=False)  # 1x1 conv
        self.relu = nn.ReLU()
        self.bn = nn.BatchNorm2d(in_channels)
        self.reshape = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, out_CE, out_CA, out_SA):
        out = torch.cat([out_CE, out_CA, out_SA], dim=1)
        out = self.conv11(out)
        out = self.relu(out)
        out = self.bn(out)
        out = self.reshape(out)
        return out
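
# Minimal shape check for the four modules above, with illustrative sizes
# (the 64-channel, 32x32 inputs are assumptions, not from the original file):
#   ca, sa = ChannelAttention(64), SpatialAttention(3)
#   ce, ffb = CorrelationEnhancement(64, 64), FeatureFusionBlock(64, 64)
#   T = C = torch.randn(2, 64, 32, 32)
#   ffb(ce(T, C), ca(C), sa(C)).shape  # -> torch.Size([2, 64, 32, 32])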

# Stacks the channel-wise max and mean maps into a 2-channel descriptor.
class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat(
            (torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)),
            dim=1,
        )
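
# BiFusion_block below uses `Conv` and `Residual`, which are not defined in
# this file (in the TransFuse codebase they come from a separate utils
# module). The following are minimal stand-in sketches, assuming Conv is a
# conv + optional BN + optional ReLU helper and Residual is a residual block
# with a 1x1 skip projection; swap in the original implementations if
# available.
class Conv(nn.Module):
    def __init__(self, inp_dim, out_dim, kernel_size=3, stride=1,
                 bn=False, relu=True, bias=True):
        super(Conv, self).__init__()
        self.conv = nn.Conv2d(inp_dim, out_dim, kernel_size, stride,
                              padding=(kernel_size - 1) // 2, bias=bias)
        self.bn = nn.BatchNorm2d(out_dim) if bn else None
        self.relu = nn.ReLU(inplace=True) if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Residual(nn.Module):
    def __init__(self, inp_dim, out_dim):
        super(Residual, self).__init__()
        # 1x1 skip projection when channel widths differ (assumption).
        self.skip = (nn.Conv2d(inp_dim, out_dim, kernel_size=1)
                     if inp_dim != out_dim else nn.Identity())
        self.conv1 = Conv(inp_dim, out_dim, 3, bn=True, relu=True)
        self.conv2 = Conv(out_dim, out_dim, 3, bn=True, relu=False)
        self.relu = nn.ReLU(inplace=True)

    def forward(self, x):
        return self.relu(self.conv2(self.conv1(x)) + self.skip(x))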

# BiFusion block (TransFuse-style): fuses a CNN feature g and a Transformer
# feature x via bilinear pooling plus spatial and channel attention.
class BiFusion_block(nn.Module):
    def __init__(self, ch_1, ch_2, r_2, ch_int, ch_out, drop_rate=0.):
        super(BiFusion_block, self).__init__()
        # Channel attention for F_g: SE block with reduction ratio r_2.
        self.fc1 = nn.Conv2d(ch_2, ch_2 // r_2, kernel_size=1)
        self.relu = nn.ReLU(inplace=True)
        self.fc2 = nn.Conv2d(ch_2 // r_2, ch_2, kernel_size=1)
        self.sigmoid = nn.Sigmoid()

        # Spatial attention for F_l.
        self.compress = ChannelPool()
        self.spatial = Conv(2, 1, 7, bn=True, relu=False, bias=False)

        # Bilinear modelling for both branches.
        self.W_g = Conv(ch_1, ch_int, 1, bn=True, relu=False)
        self.W_x = Conv(ch_2, ch_int, 1, bn=True, relu=False)
        self.W = Conv(ch_int, ch_int, 3, bn=True, relu=True)

        self.residual = Residual(ch_1 + ch_2 + ch_int, ch_out)
        self.dropout = nn.Dropout2d(drop_rate)
        self.drop_rate = drop_rate

    def forward(self, g, x):
        # Bilinear pooling.
        W_g = self.W_g(g)
        W_x = self.W_x(x)
        bp = self.W(W_g * W_x)

        # Spatial attention for the CNN branch.
        g_in = g
        g = self.compress(g)
        g = self.spatial(g)
        g = self.sigmoid(g) * g_in

        # Channel attention for the Transformer branch.
        x_in = x
        x = x.mean((2, 3), keepdim=True)  # global average pooling
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x) * x_in

        fuse = self.residual(torch.cat([g, x, bp], 1))
        if self.drop_rate > 0:
            return self.dropout(fuse)
        else:
            return fuse
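
# A quick smoke test for BiFusion_block; the channel and spatial sizes below
# are illustrative assumptions, not values from the original file.
if __name__ == "__main__":
    g = torch.randn(2, 256, 32, 32)  # CNN branch feature
    x = torch.randn(2, 384, 32, 32)  # Transformer branch feature
    fusion = BiFusion_block(ch_1=256, ch_2=384, r_2=4, ch_int=256, ch_out=256)
    print(fusion(g, x).shape)  # expected: torch.Size([2, 256, 32, 32])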