24 | 24 | parser = argparse.ArgumentParser()
25 | 25 |
26 | 26 | parser.add_argument('--model', type=str, default = '', help='model path')
   | 27 | +parser.add_argument('--num_points', type=int, default=2500, help='number of input points')
27 | 28 |
28 | 29 |
29 | 30 | opt = parser.parse_args()
30 | 31 | print (opt)
31 | 32 |
32 |    | -test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0' , train = False, classification = True)
   | 33 | +test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0' , train = False, classification = True, npoints = opt.num_points)
33 | 34 |
34 |    | -testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle = False)
   | 35 | +testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle = True)
35 | 36 |
36 | 37 |
37 |    | -classifier = PointNetCls(k = len(test_dataset.classes))
   | 38 | +classifier = PointNetCls(k = len(test_dataset.classes), num_points = opt.num_points)
38 | 39 | classifier.cuda()
39 | 40 | classifier.load_state_dict(torch.load(opt.model))
40 | 41 | classifier.eval()
41 | 42 |
   | 43 | +
42 | 44 | for i, data in enumerate(testdataloader, 0):
43 | 45 |     points, target = data
44 |    | -    points, target = Variable(points), Variable(target[:,0])
45 |    | -    points = points.transpose(2,1)
   | 46 | +    points, target = Variable(points), Variable(target[:, 0])
   | 47 | +    points = points.transpose(2, 1)
46 | 48 |     points, target = points.cuda(), target.cuda()
47 | 49 |     pred, _ = classifier(points)
48 | 50 |     loss = F.nll_loss(pred, target)
49 |    | -    from IPython import embed; embed()
   | 51 | +
50 | 52 |     pred_choice = pred.data.max(1)[1]
51 | 53 |     correct = pred_choice.eq(target.data).cpu().sum()
52 | 54 |     print('i:%d loss: %f accuracy: %f' %(i, loss.data[0], correct/float(32)))
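
A note on the unchanged context above: the per-batch accuracy printed on new line 54 divides by a hardcoded 32, so the figure is wrong whenever the final batch is smaller than the batch size, and with shuffle = True the per-batch numbers also vary from run to run. Below is a minimal whole-test-set accuracy sketch, assuming the same pre-0.4 PyTorch API (Variable, .data) the script already targets; total_correct and total_samples are illustrative names, not part of the diff.

total_correct = 0
total_samples = 0
for i, data in enumerate(testdataloader, 0):
    points, target = data
    # volatile=True skips autograd bookkeeping during pure evaluation
    # (this PyTorch generation predates torch.no_grad())
    points, target = Variable(points, volatile=True), Variable(target[:, 0])
    points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
    pred, _ = classifier(points)
    pred_choice = pred.data.max(1)[1]
    total_correct += pred_choice.eq(target.data).cpu().sum()
    total_samples += points.size(0)  # actual batch size, not a hardcoded 32
print('overall accuracy: %f' % (total_correct / float(total_samples)))

Dividing by points.size(0) keeps the result correct even when len(test_dataset) is not a multiple of the DataLoader batch size.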
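For reference, the new flag would be passed on the command line roughly as follows; the script filename and checkpoint path are placeholders, not taken from the diff:

python show_cls.py --model cls_model.pth --num_points 2500  # filenames hypothetical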