-
Notifications
You must be signed in to change notification settings - Fork 0
/
S
59 lines (44 loc) · 1.64 KB
/
S
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import torch
import torch.nn as nn
from torchvision.models import resnet34 as resnet
#from torchvision.models import resnet50
from torch.nn import CrossEntropyLoss, Dropout, Softmax, Linear, Conv2d, LayerNorm
import torch.nn.functional as F
import numpy as np

from SwinTransformer import *
from FFM import *
from EEM import *
from EIM import *
#def Downsample():
# return nn.MaxPool2d(kernel_size=2, stride=2)
class CTE(nn.Module):
    """Dual-encoder network: a ResNet-34 CNN branch plus a SwinTransformer branch.

    The two encoders' features are combined by three ``FFM`` fusion modules,
    and the CNN features are additionally routed through the ``EEM`` / ``EIM``
    edge modules before the result is upsampled and squashed with a sigmoid.

    Args:
        num_classes: Accepted for interface compatibility; not used by the
            layers built here.
        drop_rate: Accepted for interface compatibility; currently unused.
        normal_init: Accepted for interface compatibility; currently unused.
        pretrained: When true, load ResNet-34 weights from
            ``pretrained/resnet34-333f7ec4.pth`` into the CNN branch.
    """

    def __init__(
        self,
        num_classes: int = 1,
        drop_rate: float = 0.2,
        normal_init: bool = True,
        pretrained: bool = False,
    ) -> None:
        super().__init__()

        # The file imports the constructor as `resnet`
        # (`from torchvision.models import resnet34 as resnet`), so that is
        # the name that must be called here. The backbone is deliberately
        # exposed under both `cnn` and `resnet`, as in the original code.
        self.cnn = self.resnet = resnet()
        if pretrained:
            # map_location keeps loading usable on CPU-only hosts;
            # load_state_dict copies the values into the existing parameters.
            state = torch.load(
                'pretrained/resnet34-333f7ec4.pth', map_location="cpu"
            )
            self.resnet.load_state_dict(state)

        self.transformer = SwinTransformer()

        self.fusion1 = FFM(128, 256)
        self.fusion2 = FFM(256, 512)
        self.fusion3 = FFM(512, 1024)

        self.edge1 = EEM()
        self.edge2 = EIM()

        # self.predict = nn.Conv2d(64, self.num_class, 1)
        # A bare `sigmoid()` is not defined by any visible import;
        # nn.Sigmoid is the module form of the intended activation.
        self.sigmoid = nn.Sigmoid()

    def _cnn_features(self, x: torch.Tensor):
        """Return the layer2, layer3 and layer4 outputs of the CNN backbone.

        The stock torchvision ``ResNet.forward`` only returns classification
        logits and takes no ``feature`` keyword, so the stages are run
        explicitly. For ResNet-34 these stages have 128, 256 and 512
        channels, which matches the channel counts noted in ``forward``.
        """
        backbone = self.cnn
        x = backbone.conv1(x)
        x = backbone.bn1(x)
        x = backbone.relu(x)
        x = backbone.maxpool(x)
        x = backbone.layer1(x)
        f1 = backbone.layer2(x)
        f2 = backbone.layer3(f1)
        f3 = backbone.layer4(f2)
        return f1, f2, f3

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run both encoders, fuse their features and return a sigmoid map.

        Args:
            x: Input batch fed to both encoders.
                # assumes an image batch of shape (N, 3, H, W) - TODO confirm

        Returns:
            The output of ``edge2`` upsampled 4x (bilinear) and passed
            through a sigmoid.
        """
        # dual-encoder
        g1, g2, g3 = self.transformer(x)   # 128, 256, 512
        f1, f2, f3 = self._cnn_features(x)  # 128, 256, 512

        # feature fusion: the CNN features are named f1..f3 (the previous
        # s1..s3 names were never assigned).
        # NOTE(review): m1 and m2 are computed but not consumed below; only
        # m3 reaches the output. Confirm whether the fusion stages were
        # meant to be chained.
        m1 = self.fusion1(g1, f1, g1, f1)  # 128 ==> 256
        m2 = self.fusion2(g1, f1, g2, f2)  # 256 ==> 512
        m3 = self.fusion3(g2, f2, g3, f3)  # 512 ==> 1024

        # edge enhancement
        e1 = self.edge1(f1, f2)

        # edge injection
        e2 = f3 + e1
        out = self.edge2(e2, m3)

        # decoder
        out = F.interpolate(out, scale_factor=4, mode='bilinear')
        out = self.sigmoid(out)
        return out