@@ -27,10 +27,10 @@ def reset_parameters(self):
         # self.bias.data.uniform_(-stdv,stdv)
 
     def forward(self, input,adj=None,relu=False):
-        support = torch.matmul(input,self.weight)
+        support = torch.matmul(input, self.weight)
         # print(support.size(),adj.size())
         if adj is not None:
-            output = torch.matmul(adj,support)
+            output = torch.matmul(adj, support)
         else:
             output = support
         # print(output.size())
@@ -97,7 +97,7 @@ def forward(self, input, source_pre_fea):
         fea_node = torch.matmul(input1,self.pre_fea) # n x hw x n_classes
         weight_node = torch.matmul(input1,self.weight) # n x hw x hidden_layer
         # softmax fea_node
-        fea_node = F.softmax(fea_node,dim=-1)
+        fea_node = F.softmax(fea_node,dim=1)
         # print(fea_node.size(),weight_node.size())
         graph_node = F.relu(torch.matmul(fea_node.transpose(1,2),weight_node))
         return graph_node # n x n_class x hidden_layer
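
The hunk above moves the softmax over fea_node from the class axis (dim=-1) to the spatial axis (dim=1). A minimal stand-alone sketch of what that changes, assuming the n x hw x n_classes and n x hw x hidden_layer shapes stated in the in-line comments (the concrete sizes are made up):

    import torch
    import torch.nn.functional as F

    n, hw, n_classes, hidden = 2, 6, 12, 4
    fea_node = torch.randn(n, hw, n_classes)    # n x hw x n_classes
    weight_node = torch.randn(n, hw, hidden)    # n x hw x hidden_layer

    # dim=1: every class column sums to 1 across the hw locations ...
    attn = F.softmax(fea_node, dim=1)
    print(attn.sum(dim=1))                      # all ones, shape n x n_classes

    # ... so the transpose-and-matmul step becomes a per-class weighted average
    # of pixel features; with dim=-1 it would instead weight each pixel's
    # contribution by its class probabilities.
    graph_node = torch.matmul(attn.transpose(1, 2), weight_node)
    print(graph_node.size())                    # n x n_classes x hidden
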
@@ -145,6 +145,9 @@ def forward(self, input, res_feature):
         new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
         new_weight = torch.matmul(input, self.weight) # batch x node x channel
         new_node = new_node.view(batch, hi * wi, nodes)
+        # 0721
+        new_node = F.softmax(new_node, dim=-1)
+        #
         feature_out = torch.matmul(new_node,new_weight)
         # print(feature_out.size())
         feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
@@ -194,6 +197,9 @@ def forward(self, input, res_feature):
         # new_node = torch.matmul(new_fea, self.node_fea) # batch x hw x nodes x 1
         new_weight = torch.matmul(input, self.weight) # batch x node x channel
         new_node = new_node.view(batch, hi * wi, nodes)
+        # 0721
+        new_node = F.softmax(new_node, dim=-1)
+        #
         feature_out = torch.matmul(new_node,new_weight)
         # print(feature_out.size())
         feature_out = feature_out.transpose(2,3).contiguous().view(res_feature.size())
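
The two hunks above add the same step to both projection branches: the pixel-to-node scores are normalised over the node axis before the node features are projected back onto the feature map. A small sketch of that step, using the batch x hw x nodes and batch x nodes x channel shapes from the in-line comments (the sizes are placeholders):

    import torch
    import torch.nn.functional as F

    batch, hw, nodes, channel = 2, 8, 7, 16
    new_node = torch.randn(batch, hw, nodes)          # batch x hw x nodes
    new_weight = torch.randn(batch, nodes, channel)   # batch x nodes x channel

    # Softmax over the last axis: each pixel's node scores sum to 1, so the
    # matmul below forms a convex combination of node features per pixel.
    new_node = F.softmax(new_node, dim=-1)
    print(new_node.sum(dim=-1))                       # all ones

    feature_out = torch.matmul(new_node, new_weight)  # batch x hw x channel
    print(feature_out.size())
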
@@ -217,7 +223,7 @@ def __init__(self,in_features,out_features,begin_nodes=7,end_nodes=2,bias=False,
             self.bias = Parameter(torch.FloatTensor(out_features))
         else:
             self.register_parameter('bias',None)
-        self.reset_parameters()
+        # self.reset_parameters()
 
     def reset_parameters(self):
         # stdv = 1./math.sqrt(self.weight(1))
@@ -266,7 +272,8 @@ def norm_trans_adj(self,adj): # maybe can use softmax
 if __name__ == '__main__':
 
     graph = torch.randn((7,128))
-    en = GraphConvolution(128,12)
-    # pred = (torch.rand((7,7))*7).int()
+    en = GraphConvolution(128,128)
     a = en.forward(graph)
-    print(a)
+    print(a)
+    # a = en.forward(graph,pred)
+    # print(a.size())
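
For reference, a self-contained sketch of what the updated smoke test exercises. GraphConvolutionStub below is a hypothetical stand-in that only mirrors the forward(input, adj=None, relu=False) logic from the first hunk, not the full class defined in this file:

    import torch
    import torch.nn as nn

    class GraphConvolutionStub(nn.Module):
        def __init__(self, in_features, out_features):
            super().__init__()
            self.weight = nn.Parameter(torch.randn(in_features, out_features) * 0.01)

        def forward(self, input, adj=None, relu=False):
            support = torch.matmul(input, self.weight)
            # Without an adjacency matrix the layer reduces to a plain projection.
            output = torch.matmul(adj, support) if adj is not None else support
            return torch.relu(output) if relu else output

    graph = torch.randn(7, 128)
    en = GraphConvolutionStub(128, 128)
    print(en(graph).size())                  # torch.Size([7, 128])

    adj = torch.eye(7)                       # hypothetical 7-node adjacency
    print(en(graph, adj, relu=True).size())  # torch.Size([7, 128])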