Skip to content

Commit 75dd61c

Browse files
authored
Merge pull request #15 from elliottzheng/master
fix:In pytorch 0.4 accuracy always zero
2 parents 3700c44 + 27d3413 commit 75dd61c

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

train_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
optimizer.step()
8181
pred_choice = pred.data.max(1)[1]
8282
correct = pred_choice.eq(target.data).cpu().sum()
83-
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.data[0], correct/float(opt.batchSize)))
83+
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.item(),correct.item() / float(opt.batchSize)))
8484

8585
if i % 10 == 0:
8686
j, data = next(enumerate(testdataloader, 0))
@@ -92,6 +92,6 @@
9292
loss = F.nll_loss(pred, target)
9393
pred_choice = pred.data.max(1)[1]
9494
correct = pred_choice.eq(target.data).cpu().sum()
95-
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.data[0], correct/float(opt.batchSize)))
95+
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize)))
9696

9797
torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))

train_segmentation.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@
8080
optimizer.step()
8181
pred_choice = pred.data.max(1)[1]
8282
correct = pred_choice.eq(target.data).cpu().sum()
83-
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.data[0], correct/float(opt.batchSize * 2500)))
83+
print('[%d: %d/%d] train loss: %f accuracy: %f' %(epoch, i, num_batch, loss.item(), correct.item()/float(opt.batchSize * 2500)))
8484

8585
if i % 10 == 0:
8686
j, data = next(enumerate(testdataloader, 0))
@@ -95,6 +95,6 @@
9595
loss = F.nll_loss(pred, target)
9696
pred_choice = pred.data.max(1)[1]
9797
correct = pred_choice.eq(target.data).cpu().sum()
98-
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.data[0], correct/float(opt.batchSize * 2500)))
98+
print('[%d: %d/%d] %s loss: %f accuracy: %f' %(epoch, i, num_batch, blue('test'), loss.item(), correct.item()/float(opt.batchSize * 2500)))
9999

100100
torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))

0 commit comments

Comments
 (0)