@@ -59,24 +59,36 @@ def __init__(self, in_channels, inter_channels=None, dimension=3, mode='embedded
59
59
self .phi = None
60
60
self .concat_project = None
61
61
62
- if mode in ['embedded_gaussian' , 'dot_product' , 'concatenation' ]:
63
- self .theta = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
64
- kernel_size = 1 , stride = 1 , padding = 0 )
65
- self .phi = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
66
- kernel_size = 1 , stride = 1 , padding = 0 )
67
-
68
- if mode == 'embedded_gaussian' :
69
- self .operation_function = self ._embedded_gaussian
70
- elif mode == 'dot_product' :
71
- self .operation_function = self ._dot_product
72
- elif mode == 'concatenation' :
73
- self .operation_function = self ._concatenation
74
- self .concat_project = nn .Sequential (
75
- nn .Conv2d (self .inter_channels * 2 , 1 , 1 , 1 , 0 , bias = False ),
76
- nn .ReLU ()
77
- )
78
- elif mode == 'gaussian' :
79
- self .operation_function = self ._gaussian
62
+ # if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
63
+ self .theta = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
64
+ kernel_size = 1 , stride = 1 , padding = 0 )
65
+
66
+ self .phi = conv_nd (in_channels = self .in_channels , out_channels = self .inter_channels ,
67
+ kernel_size = 1 , stride = 1 , padding = 0 )
68
+ # elif mode == 'concatenation':
69
+ self .concat_project = nn .Sequential (
70
+ nn .Conv2d (self .inter_channels * 2 , 1 , 1 , 1 , 0 , bias = False ),
71
+ nn .ReLU ()
72
+ )
73
+
74
+ # if mode in ['embedded_gaussian', 'dot_product', 'concatenation']:
75
+ # self.theta = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
76
+ # kernel_size=1, stride=1, padding=0)
77
+ # self.phi = conv_nd(in_channels=self.in_channels, out_channels=self.inter_channels,
78
+ # kernel_size=1, stride=1, padding=0)
79
+ #
80
+ # if mode == 'embedded_gaussian':
81
+ # self.operation_function = self._embedded_gaussian
82
+ # elif mode == 'dot_product':
83
+ # self.operation_function = self._dot_product
84
+ # elif mode == 'concatenation':
85
+ # self.operation_function = self._concatenation
86
+ # self.concat_project = nn.Sequential(
87
+ # nn.Conv2d(self.inter_channels * 2, 1, 1, 1, 0, bias=False),
88
+ # nn.ReLU()
89
+ # )
90
+ # elif mode == 'gaussian':
91
+ # self.operation_function = self._gaussian
80
92
81
93
if sub_sample :
82
94
self .g = nn .Sequential (self .g , max_pool (kernel_size = 2 ))
@@ -91,7 +103,15 @@ def forward(self, x):
91
103
:return:
92
104
'''
93
105
94
- output = self .operation_function (x )
106
+ if self .mode == 'embedded_gaussian' :
107
+ output = self ._embedded_gaussian (x )
108
+ elif mode == 'dot_product' :
109
+ output = self ._dot_product (x )
110
+ elif mode == 'concatenation' :
111
+ output = self ._concatenation (x )
112
+ elif mode == 'gaussian' :
113
+ output = self ._gaussian (x )
114
+
95
115
return output
96
116
97
117
def _embedded_gaussian (self , x ):
0 commit comments