add some benchmarking code
fxia22 committed Mar 2, 2019
1 parent 0ce9a4b commit 523f07f
Showing 7 changed files with 92 additions and 32 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -7,3 +7,4 @@ shapenetcore_partanno_segmentation_benchmark_v0/
 .idea*
 cls/
 seg/
+*.egg-info/
Empty file added pointnet/__init__.py
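Note: the empty pointnet/__init__.py turns pointnet/ into an importable package, which the new `from pointnet.dataset import PartDataset` imports in the scripts below rely on, and the `*.egg-info/` ignore suggests the package is installed with setuptools (e.g. `pip install -e .`). A minimal sketch of what such a setup script could look like — this file is not part of the diff and all metadata below is assumed:

```python
# setup.py -- illustrative sketch only; not included in this commit
from setuptools import setup, find_packages

setup(
    name='pointnet',                               # assumed package name
    packages=find_packages(),
    install_requires=['torch', 'numpy', 'tqdm'],   # assumed dependencies
)
```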
16 changes: 15 additions & 1 deletion pointnet/dataset.py
@@ -12,11 +12,13 @@ def __init__(self,
                  npoints=2500,
                  classification=False,
                  class_choice=None,
-                 train=True):
+                 train=True,
+                 data_augmentation=True):
         self.npoints = npoints
         self.root = root
         self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
         self.cat = {}
+        self.data_augmentation = data_augmentation

         self.classification = classification

@@ -73,10 +75,22 @@ def __getitem__(self, index):
         choice = np.random.choice(len(seg), self.npoints, replace=True)
         # resample
         point_set = point_set[choice, :]

+        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
+        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
+        point_set = point_set / dist  # scale
+
+        if self.data_augmentation:
+            theta = np.random.uniform(0, np.pi * 2)
+            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
+            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
+            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter
+
         seg = seg[choice]
         point_set = torch.from_numpy(point_set)
         seg = torch.from_numpy(seg)
         cls = torch.from_numpy(np.array([cls]).astype(np.int64))

         if self.classification:
             return point_set, cls
         else:
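Note: the new augmentation rotates each cloud by a uniform random angle about the vertical (y) axis — the 2x2 rotation mixes only the x and z columns — then adds per-point Gaussian jitter. A standalone sketch of the same transform (the function name and test array are illustrative; sigma=0.02 matches the diff):

```python
import numpy as np

def augment(point_set, sigma=0.02):
    # random rotation about the y (up) axis: mixes x and z, leaves y fixed
    theta = np.random.uniform(0, np.pi * 2)
    rotation = np.array([[np.cos(theta), -np.sin(theta)],
                         [np.sin(theta),  np.cos(theta)]])
    point_set = point_set.copy()
    point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation)
    # per-point Gaussian jitter
    point_set += np.random.normal(0, sigma, size=point_set.shape)
    return point_set

points = np.random.rand(2500, 3)  # illustrative (n_pts, 3) cloud
print(augment(points).shape)      # (2500, 3)
```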
5 changes: 3 additions & 2 deletions pointnet/model.py
@@ -84,16 +84,17 @@ def __init__(self, k = 2):
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, k)
+        self.dropout = nn.Dropout(p=0.3)
         self.bn1 = nn.BatchNorm1d(512)
         self.bn2 = nn.BatchNorm1d(256)
         self.relu = nn.ReLU()

     def forward(self, x):
         x, trans = self.feat(x)
         x = F.relu(self.bn1(self.fc1(x)))
-        x = F.relu(self.bn2(self.fc2(x)))
+        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
         x = self.fc3(x)
-        return F.log_softmax(x, dim=0), trans
+        return F.log_softmax(x, dim=1), trans

 class PointNetDenseCls(nn.Module):
     def __init__(self, k = 2):
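Note: the `dim=0` → `dim=1` fix matters because the classifier output has shape (batch_size, k); `dim=0` normalizes across the batch rather than across class scores. A minimal demonstration (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)  # (batch_size, num_classes)

wrong = F.log_softmax(logits, dim=0)  # normalizes over the batch dimension
right = F.log_softmax(logits, dim=1)  # normalizes over the class dimension

# each row of `right` is a proper log-distribution over classes
print(right.exp().sum(dim=1))  # ~tensor([1., 1., 1., 1.])
print(wrong.exp().sum(dim=1))  # generally not 1
```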
17 changes: 7 additions & 10 deletions utils/show_seg.py
@@ -17,34 +17,31 @@
 parser.add_argument('--model', type=str, default='', help='model path')
 parser.add_argument('--idx', type=int, default=0, help='model index')
-
-
+parser.add_argument('--dataset', type=str, default='', help='dataset path')
+parser.add_argument('--class_choice', type=str, default='', help='class choice')

 opt = parser.parse_args()
 print(opt)

 d = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
-    class_choice=['Airplane'],
+    root=opt.dataset,
+    class_choice=[opt.class_choice],
     train=False)

 idx = opt.idx

 print("model %d/%d" % (idx, len(d)))

 point, seg = d[idx]
 print(point.size(), seg.size())

 point_np = point.numpy()

-
-
 cmap = plt.cm.get_cmap("hsv", 10)
 cmap = np.array([cmap(i) for i in range(10)])[:, :3]
 gt = cmap[seg.numpy() - 1, :]

-classifier = PointNetDenseCls(k=4)
-classifier.load_state_dict(torch.load(opt.model))
+state_dict = torch.load(opt.model)
+classifier = PointNetDenseCls(k=state_dict['conv4.weight'].size()[0])
+classifier.load_state_dict(state_dict)
 classifier.eval()

 point = point.transpose(1, 0).contiguous()
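Note: inferring k from the checkpoint (rather than hard-coding k=4) works because `conv4` is the output layer of PointNetDenseCls, whose weight has shape (k, channels, 1). The same idea as a reusable helper — a sketch; the `conv4.weight` key is specific to this repo's model:

```python
import torch
from pointnet.model import PointNetDenseCls

def load_seg_model(path):
    # read the number of part classes off the output conv's weight shape
    state_dict = torch.load(path, map_location='cpu')
    k = state_dict['conv4.weight'].size(0)
    model = PointNetDenseCls(k=k)
    model.load_state_dict(state_dict)
    return model.eval()
```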
35 changes: 26 additions & 9 deletions utils/train_classification.py
@@ -6,11 +6,10 @@
 import torch.nn.parallel
 import torch.optim as optim
 import torch.utils.data
-from torch.autograd import Variable
 from pointnet.dataset import PartDataset
 from pointnet.model import PointNetCls
 import torch.nn.functional as F
-
+from tqdm import tqdm


 parser = argparse.ArgumentParser()
@@ -21,9 +20,10 @@
 parser.add_argument(
     '--workers', type=int, help='number of data loading workers', default=4)
 parser.add_argument(
-    '--nepoch', type=int, default=25, help='number of epochs to train for')
+    '--nepoch', type=int, default=250, help='number of epochs to train for')
 parser.add_argument('--outf', type=str, default='cls', help='output folder')
 parser.add_argument('--model', type=str, default='', help='model path')
+parser.add_argument('--dataset', type=str, required=True, help="dataset path")

 opt = parser.parse_args()
 print(opt)
@@ -36,7 +36,7 @@
 torch.manual_seed(opt.manualSeed)

 dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=True,
     npoints=opt.num_points)
 dataloader = torch.utils.data.DataLoader(
@@ -46,7 +46,7 @@
     num_workers=int(opt.workers))

 test_dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=True,
     train=False,
     npoints=opt.num_points)
@@ -65,22 +65,23 @@
 except OSError:
     pass

-
 classifier = PointNetCls(k=num_classes)

 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))


-optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
+optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
 classifier.cuda()

 num_batch = len(dataset) / opt.batchSize

 for epoch in range(opt.nepoch):
+    scheduler.step()
     for i, data in enumerate(dataloader, 0):
         points, target = data
-        points, target = Variable(points), Variable(target[:, 0])
+        target = target[:, 0]
         points = points.transpose(2, 1)
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
@@ -96,7 +97,7 @@
         if i % 10 == 0:
             j, data = next(enumerate(testdataloader, 0))
             points, target = data
-            points, target = Variable(points), Variable(target[:, 0])
+            target = target[:, 0]
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
@@ -107,3 +108,19 @@
             print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(opt.batchSize)))

     torch.save(classifier.state_dict(), '%s/cls_model_%d.pth' % (opt.outf, epoch))
+
+total_correct = 0
+total_testset = 0
+for i, data in tqdm(enumerate(testdataloader, 0)):
+    points, target = data
+    target = target[:, 0]
+    points = points.transpose(2, 1)
+    points, target = points.cuda(), target.cuda()
+    classifier = classifier.eval()
+    pred, _ = classifier(points)
+    pred_choice = pred.data.max(1)[1]
+    correct = pred_choice.eq(target.data).cpu().sum()
+    total_correct += correct.item()
+    total_testset += points.size()[0]
+
+print("final accuracy {}".format(total_correct / float(total_testset)))
50 changes: 40 additions & 10 deletions utils/train_segmentation.py
@@ -6,10 +6,11 @@
 import torch.nn.parallel
 import torch.optim as optim
 import torch.utils.data
-from torch.autograd import Variable
 from pointnet.dataset import PartDataset
 from pointnet.model import PointNetDenseCls
 import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np


 parser = argparse.ArgumentParser()
@@ -21,6 +22,8 @@
     '--nepoch', type=int, default=25, help='number of epochs to train for')
 parser.add_argument('--outf', type=str, default='seg', help='output folder')
 parser.add_argument('--model', type=str, default='', help='model path')
+parser.add_argument('--dataset', type=str, required=True, help="dataset path")
+parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")


 opt = parser.parse_args()
@@ -32,19 +35,19 @@
 torch.manual_seed(opt.manualSeed)

 dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=False,
-    class_choice=['Chair'])
+    class_choice=[opt.class_choice])
 dataloader = torch.utils.data.DataLoader(
     dataset,
     batch_size=opt.batchSize,
     shuffle=True,
     num_workers=int(opt.workers))

 test_dataset = PartDataset(
-    root='shapenetcore_partanno_segmentation_benchmark_v0',
+    root=opt.dataset,
     classification=False,
-    class_choice=['Chair'],
+    class_choice=[opt.class_choice],
     train=False)
 testdataloader = torch.utils.data.DataLoader(
     test_dataset,
@@ -67,15 +70,16 @@
 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))

-optimizer = optim.SGD(classifier.parameters(), lr=0.01, momentum=0.9)
+optimizer = optim.Adam(classifier.parameters(), lr=0.001, betas=(0.9, 0.999))
+scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=20, gamma=0.5)
 classifier.cuda()

 num_batch = len(dataset) / opt.batchSize

 for epoch in range(opt.nepoch):
+    scheduler.step()
     for i, data in enumerate(dataloader, 0):
         points, target = data
-        points, target = Variable(points), Variable(target)
         points = points.transpose(2, 1)
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
@@ -94,17 +98,43 @@
         if i % 10 == 0:
             j, data = next(enumerate(testdataloader, 0))
             points, target = data
-            points, target = Variable(points), Variable(target)
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
             pred, _ = classifier(points)
             pred = pred.view(-1, num_classes)
             target = target.view(-1, 1)[:, 0] - 1

             loss = F.nll_loss(pred, target)
             pred_choice = pred.data.max(1)[1]
             correct = pred_choice.eq(target.data).cpu().sum()
             print('[%d: %d/%d] %s loss: %f accuracy: %f' % (epoch, i, num_batch, blue('test'), loss.item(), correct.item() / float(opt.batchSize * 2500)))

-    torch.save(classifier.state_dict(), '%s/seg_model_%d.pth' % (opt.outf, epoch))
+    torch.save(classifier.state_dict(), '%s/seg_model_%s_%d.pth' % (opt.outf, opt.class_choice, epoch))
+
+## benchmark mIOU
+shape_ious = []
+for i, data in tqdm(enumerate(testdataloader, 0)):
+    points, target = data
+    points = points.transpose(2, 1)
+    points, target = points.cuda(), target.cuda()
+    classifier = classifier.eval()
+    pred, _ = classifier(points)
+    pred_choice = pred.data.max(2)[1]
+
+    pred_np = pred_choice.cpu().data.numpy()
+    target_np = target.cpu().data.numpy() - 1
+
+    for shape_idx in range(target_np.shape[0]):
+        parts = np.unique(target_np[shape_idx])
+        part_ious = []
+        for part in parts:
+            I = np.sum(np.logical_and(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+            U = np.sum(np.logical_or(pred_np[shape_idx] == part, target_np[shape_idx] == part))
+            if U == 0:
+                iou = 0
+            else:
+                iou = I / float(U)
+            part_ious.append(iou)
+        shape_ious.append(np.mean(part_ious))
+
+print("mIOU for class {}: {}".format(opt.class_choice, np.mean(shape_ious)))
