Skip to content

Commit f2550e3

Browse files
fix bugs
1 parent 831a606 commit f2550e3

File tree

3 files changed

+9
-8
lines changed

3 files changed

+9
-8
lines changed

exp/transfer/train_cihp_from_pascal.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,7 @@ def get_parser():
6969
parser.add_argument('--numworker',default=12,type=int)
7070
parser.add_argument('--freezeBN', choices=dict(true=True, false=False), default=True, action=LookupChoices)
7171
parser.add_argument('--step', default=10, type=int)
72-
parser.add_argument('--classes', default=7, type=int)
72+
parser.add_argument('--classes', default=20, type=int)
7373
parser.add_argument('--testInterval', default=10, type=int)
7474
parser.add_argument('--loadmodel',default='',type=str)
7575
parser.add_argument('--pretrainedModel', default='', type=str)
@@ -176,7 +176,7 @@ def main(opts):
176176

177177
# Network definition
178178
if backbone == 'xception':
179-
net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=20, os=16,
179+
net_ = deeplab_xception_transfer.deeplab_xception_transfer_projection_savemem(n_classes=opts.classes, os=16,
180180
hidden_layers=opts.hidden_layers, source_classes=7, )
181181
elif backbone == 'resnet':
182182
# net_ = deeplab_resnet.DeepLabv3_plus(nInputChannels=3, n_classes=7, os=16, pretrained=True)
@@ -195,7 +195,7 @@ def main(opts):
195195
if not model_path == '':
196196
x = torch.load(model_path)
197197
net_.load_state_dict_new(x)
198-
print('load pretrainedModel.')
198+
print('load pretrainedModel:', model_path)
199199
else:
200200
print('no pretrainedModel.')
201201
if not opts.loadmodel =='':
@@ -320,7 +320,7 @@ def main(opts):
320320
# One testing epoch
321321
if useTest and epoch % nTestInterval == (nTestInterval - 1):
322322
val_cihp(net_,testloader=testloader, testloader_flip=testloader_flip, test_graph=test_graph,
323-
epoch=epoch,writer=writer,criterion=criterion)
323+
epoch=epoch,writer=writer,criterion=criterion, classes=opts.classes)
324324
torch.cuda.empty_cache()
325325

326326

networks/deeplab_xception_transfer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -775,7 +775,7 @@ def forward(self, input,adj1_target=None, adj2_source=None,adj3_transfer=None ):
775775
source_graph = self.source_featuremap_2_graph(x)
776776
source_graph1 = self.source_graph_conv1.forward(source_graph,adj=adj2_source, relu=True)
777777
source_graph2 = self.source_graph_conv2.forward(source_graph1, adj=adj2_source, relu=True)
778-
source_graph3 = self.source_graph_conv2.forward(source_graph2, adj=adj2_source, relu=True)
778+
source_graph3 = self.source_graph_conv3.forward(source_graph2, adj=adj2_source, relu=True)
779779

780780
source_2_target_graph1_v5 = self.transpose_graph.forward(source_graph1, adj=adj3_transfer, relu=True)
781781
source_2_target_graph2_v5 = self.transpose_graph.forward(source_graph2, adj=adj3_transfer, relu=True)

networks/gcn.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,7 @@ def norm_trans_adj(self,adj): # maybe can use softmax
266266
if __name__ == '__main__':
267267

268268
graph = torch.randn((7,128))
269-
pred = (torch.rand((7,7))*7).int()
270-
# a = en.forward(graph,pred)
271-
# print(a.size())
269+
en = GraphConvolution(128,12)
270+
# pred = (torch.rand((7,7))*7).int()
271+
a = en.forward(graph)
272+
print(a)

0 commit comments

Comments
 (0)