From bf91ed3a6bd2f133ea7151c43bddbe7af972b527 Mon Sep 17 00:00:00 2001
From: fxia22
Date: Tue, 5 Mar 2019 22:21:35 -0800
Subject: [PATCH] add feature transform

---
 README.md                     |  10 ++--
 pointnet/model.py             | 100 ++++++++++++++++++++++++++++------
 utils/show_cls.py             |   7 ++-
 utils/show_seg.py             |   5 +-
 utils/train_classification.py |  13 +++--
 utils/train_segmentation.py   |  14 +++--
 6 files changed, 112 insertions(+), 37 deletions(-)

diff --git a/README.md b/README.md
index 755059706..db58a919b 100644
--- a/README.md
+++ b/README.md
@@ -25,10 +25,9 @@ python train_classification.py --dataset <dataset path> --nepoch=<number epochs> --dataset_type <modelnet40 | shapenet>
 python train_segmentation.py --dataset <dataset path> --nepoch=<number epochs>
 ```
 
-# Performance
+Use the `--feature_transform` flag to enable the feature transform.
 
-Sample segmentation result:
-![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)
+# Performance
 
 ## Classification performance
 
@@ -37,7 +36,7 @@ On ModelNet40:
 | | Overall Acc |
 | :---: | :---: |
 | Original implementation | 89.2 |
-| this implementation(w/o feature transform) | TBA |
+| this implementation(w/o feature transform) | 86.4 |
 | this implementation(w/ feature transform) | TBA |
 
 On [A subset of shapenet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html)
@@ -60,6 +59,9 @@ Segmentation on [A subset of shapenet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html)
 
 Note that this implementation trains each class separately, so classes with fewer data will have slightly lower performance than reference implementation.
 
+Sample segmentation result:
+![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)
+
 # Links
 
 - [Project Page](http://stanford.edu/~rqi/pointnet/)
diff --git a/pointnet/model.py b/pointnet/model.py
index 7be4b7cfa..9e5838e96 100644
--- a/pointnet/model.py
+++ b/pointnet/model.py
@@ -46,8 +46,46 @@ def forward(self, x):
         return x
 
 
+class STNkd(nn.Module):
+    def __init__(self, k=64):
+        super(STNkd, self).__init__()
+        self.conv1 = torch.nn.Conv1d(k, 64, 1)
+        self.conv2 = torch.nn.Conv1d(64, 128, 1)
+        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
+        self.fc1 = nn.Linear(1024, 512)
+        self.fc2 = nn.Linear(512, 256)
+        self.fc3 = nn.Linear(256, k*k)
+        self.relu = nn.ReLU()
+
+        self.bn1 = nn.BatchNorm1d(64)
+        self.bn2 = nn.BatchNorm1d(128)
+        self.bn3 = nn.BatchNorm1d(1024)
+        self.bn4 = nn.BatchNorm1d(512)
+        self.bn5 = nn.BatchNorm1d(256)
+
+        self.k = k
+
+    def forward(self, x):
+        batchsize = x.size()[0]
+        x = F.relu(self.bn1(self.conv1(x)))
+        x = F.relu(self.bn2(self.conv2(x)))
+        x = F.relu(self.bn3(self.conv3(x)))
+        x = torch.max(x, 2, keepdim=True)[0]
+        x = x.view(-1, 1024)
+
+        x = F.relu(self.bn4(self.fc1(x)))
+        x = F.relu(self.bn5(self.fc2(x)))
+        x = self.fc3(x)
+
+        iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
+        if x.is_cuda:
+            iden = iden.cuda()
+        x = x + iden
+        x = x.view(-1, self.k, self.k)
+        return x
+
 class PointNetfeat(nn.Module):
-    def __init__(self, global_feat = True):
+    def __init__(self, global_feat = True, feature_transform = False):
         super(PointNetfeat, self).__init__()
         self.stn = STN3d()
         self.conv1 = torch.nn.Conv1d(3, 64, 1)
@@ -57,7 +95,9 @@ def __init__(self, global_feat = True):
         self.bn2 = nn.BatchNorm1d(128)
         self.bn3 = nn.BatchNorm1d(1024)
         self.global_feat = global_feat
-
+        self.feature_transform = feature_transform
+        if self.feature_transform:
+            self.fstn = STNkd(k=64)
 
     def forward(self, x):
         n_pts = x.size()[2]
@@ -66,21 +106,31 @@ def forward(self, x):
         x = torch.bmm(x, trans)
         x = x.transpose(2, 1)
         x = F.relu(self.bn1(self.conv1(x)))
+
+        if self.feature_transform:
+            trans_feat = self.fstn(x)
+            x = x.transpose(2,1)
+            x = torch.bmm(x, trans_feat)
+            x = x.transpose(2,1)
+        else:
+            trans_feat = None
+
         pointfeat = x
         x = F.relu(self.bn2(self.conv2(x)))
         x = self.bn3(self.conv3(x))
         x = torch.max(x, 2, keepdim=True)[0]
         x = x.view(-1, 1024)
         if self.global_feat:
-            return x, trans
+            return x, trans, trans_feat
         else:
             x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
-            return torch.cat([x, pointfeat], 1), trans
+            return torch.cat([x, pointfeat], 1), trans, trans_feat
 
 class PointNetCls(nn.Module):
-    def __init__(self, k = 2):
+    def __init__(self, k=2, feature_transform=False):
         super(PointNetCls, self).__init__()
-        self.feat = PointNetfeat(global_feat=True)
+        self.feature_transform = feature_transform
+        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
         self.fc1 = nn.Linear(1024, 512)
         self.fc2 = nn.Linear(512, 256)
         self.fc3 = nn.Linear(256, k)
@@ -90,17 +140,18 @@ def __init__(self, k = 2):
         self.relu = nn.ReLU()
 
     def forward(self, x):
-        x, trans = self.feat(x)
+        x, trans, trans_feat = self.feat(x)
         x = F.relu(self.bn1(self.fc1(x)))
         x = F.relu(self.bn2(self.dropout(self.fc2(x))))
         x = self.fc3(x)
-        return F.log_softmax(x, dim=1), trans
+        return F.log_softmax(x, dim=1), trans, trans_feat
 
 class PointNetDenseCls(nn.Module):
-    def __init__(self, k = 2):
+    def __init__(self, k = 2, feature_transform=False):
         super(PointNetDenseCls, self).__init__()
         self.k = k
-        self.feat = PointNetfeat(global_feat=False)
+        self.feature_transform=feature_transform
+        self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
         self.conv1 = torch.nn.Conv1d(1088, 512, 1)
         self.conv2 = torch.nn.Conv1d(512, 256, 1)
         self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -112,7 +163,7 @@ def __init__(self, k = 2):
     def forward(self, x):
         batchsize = x.size()[0]
         n_pts = x.size()[2]
-        x, trans = self.feat(x)
+        x, trans, trans_feat = self.feat(x)
         x = F.relu(self.bn1(self.conv1(x)))
         x = F.relu(self.bn2(self.conv2(x)))
         x = F.relu(self.bn3(self.conv3(x)))
@@ -120,27 +171,42 @@ def forward(self, x):
         x = x.transpose(2,1).contiguous()
         x = F.log_softmax(x.view(-1,self.k), dim=-1)
         x = x.view(batchsize, n_pts, self.k)
-        return x, trans
+        return x, trans, trans_feat
 
+def feature_transform_reguliarzer(trans):
+    # Orthogonality penalty on the predicted transform: mean over the batch of ||A A^T - I||_F
+    d = trans.size()[1]
+    batchsize = trans.size()[0]
+    I = torch.eye(d)[None, :, :]
+    if trans.is_cuda:
+        I = I.cuda()
+    loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2, 1)) - I, dim=(1, 2)))
+    return loss
 
 if __name__ == '__main__':
     sim_data = Variable(torch.rand(32,3,2500))
     trans = STN3d()
     out = trans(sim_data)
     print('stn', out.size())
-
+    print('loss', feature_transform_reguliarzer(out))
+
+    sim_data_64d = Variable(torch.rand(32, 64, 2500))
+    trans = STNkd(k=64)
+    out = trans(sim_data_64d)
+    print('stn64d', out.size())
+    print('loss', feature_transform_reguliarzer(out))
+
     pointfeat = PointNetfeat(global_feat=True)
-    out, _ = pointfeat(sim_data)
+    out, _, _ = pointfeat(sim_data)
     print('global feat', out.size())
 
     pointfeat = PointNetfeat(global_feat=False)
-    out, _ = pointfeat(sim_data)
+    out, _, _ = pointfeat(sim_data)
     print('point feat', out.size())
 
     cls = PointNetCls(k = 5)
-    out, _ = cls(sim_data)
+    out, _, _ = cls(sim_data)
     print('class', out.size())
 
     seg = PointNetDenseCls(k = 3)
-    out, _ = seg(sim_data)
+    out, _, _ = seg(sim_data)
     print('seg', out.size())
diff --git a/utils/show_cls.py b/utils/show_cls.py
index ed9348986..d52ef0878 100644
--- a/utils/show_cls.py
+++ b/utils/show_cls.py
@@ -22,9 +22,10 @@
 
 test_dataset = ShapeNetDataset(
     root='shapenetcore_partanno_segmentation_benchmark_v0',
-    train=False,
+    split='test',
     classification=True,
-    npoints=opt.num_points)
+    npoints=opt.num_points,
+    data_augmentation=False)
 
 testdataloader = torch.utils.data.DataLoader(
     test_dataset, batch_size=32, shuffle=True)
@@ -40,7 +41,7 @@
     points, target = Variable(points), Variable(target[:, 0])
     points = points.transpose(2, 1)
    points, target = points.cuda(), target.cuda()
-    pred, _ = classifier(points)
+    pred, _, _ = classifier(points)
     loss = F.nll_loss(pred, target)
 
     pred_choice = pred.data.max(1)[1]
diff --git a/utils/show_seg.py b/utils/show_seg.py
index 3926876c7..a849751a3 100644
--- a/utils/show_seg.py
+++ b/utils/show_seg.py
@@ -26,7 +26,8 @@
 d = ShapeNetDataset(
     root=opt.dataset,
     class_choice=[opt.class_choice],
-    train=False)
+    split='test',
+    data_augmentation=False)
 
 idx = opt.idx
 
@@ -47,7 +48,7 @@
 point = point.transpose(1, 0).contiguous()
 
 point = Variable(point.view(1, point.size()[0], point.size()[1]))
-pred, _ = classifier(point)
+pred, _, _ = classifier(point)
 pred_choice = pred.data.max(2)[1]
 print(pred_choice)
 
diff --git a/utils/train_classification.py b/utils/train_classification.py
index 5427cb579..f85edbe19 100644
--- a/utils/train_classification.py
+++ b/utils/train_classification.py
@@ -7,7 +7,7 @@
 import torch.optim as optim
 import torch.utils.data
 from pointnet.dataset import ShapeNetDataset, ModelNetDataset
-from pointnet.model import PointNetCls
+from pointnet.model import PointNetCls, feature_transform_reguliarzer
 import torch.nn.functional as F
 from tqdm import tqdm
 
@@ -25,6 +25,7 @@
 parser.add_argument('--model', type=str, default='', help='model path')
 parser.add_argument('--dataset', type=str, required=True, help="dataset path")
 parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
+parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
 
 opt = parser.parse_args()
 print(opt)
@@ -84,7 +85,7 @@
 except OSError:
     pass
 
-classifier = PointNetCls(k=num_classes)
+classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
 
 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))
@@ -105,8 +106,10 @@
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
         classifier = classifier.train()
-        pred, _ = classifier(points)
+        pred, trans, trans_feat = classifier(points)
         loss = F.nll_loss(pred, target)
+        if opt.feature_transform:
+            loss += feature_transform_reguliarzer(trans_feat) * 0.001
         loss.backward()
         optimizer.step()
         pred_choice = pred.data.max(1)[1]
@@ -120,7 +123,7 @@
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
-            pred, _ = classifier(points)
+            pred, _, _ = classifier(points)
             loss = F.nll_loss(pred, target)
             pred_choice = pred.data.max(1)[1]
             correct = pred_choice.eq(target.data).cpu().sum()
@@ -136,7 +139,7 @@
     points = points.transpose(2, 1)
     points, target = points.cuda(), target.cuda()
     classifier = classifier.eval()
-    pred, _ = classifier(points)
+    pred, _, _ = classifier(points)
     pred_choice = pred.data.max(1)[1]
     correct = pred_choice.eq(target.data).cpu().sum()
     total_correct += correct.item()
diff --git a/utils/train_segmentation.py b/utils/train_segmentation.py
index 0ae5ca263..237c298c6 100644
--- a/utils/train_segmentation.py
+++ b/utils/train_segmentation.py
@@ -7,7 +7,7 @@
 import torch.optim as optim
 import torch.utils.data
 from pointnet.dataset import ShapeNetDataset
-from pointnet.model import PointNetDenseCls
+from pointnet.model import PointNetDenseCls, feature_transform_reguliarzer
 import torch.nn.functional as F
 from tqdm import tqdm
 import numpy as np
@@ -24,7 +24,7 @@
 parser.add_argument('--model', type=str, default='', help='model path')
 parser.add_argument('--dataset', type=str, required=True, help="dataset path")
 parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")
-
+parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
 
 opt = parser.parse_args()
 print(opt)
@@ -66,7 +66,7 @@
 blue = lambda x: '\033[94m' + x + '\033[0m'
 
 
-classifier = PointNetDenseCls(k=num_classes)
+classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform)
 
 if opt.model != '':
     classifier.load_state_dict(torch.load(opt.model))
@@ -85,11 +85,13 @@
         points, target = points.cuda(), target.cuda()
         optimizer.zero_grad()
         classifier = classifier.train()
-        pred, _ = classifier(points)
+        pred, trans, trans_feat = classifier(points)
         pred = pred.view(-1, num_classes)
         target = target.view(-1, 1)[:, 0] - 1
         #print(pred.size(), target.size())
         loss = F.nll_loss(pred, target)
+        if opt.feature_transform:
+            loss += feature_transform_reguliarzer(trans_feat) * 0.001
         loss.backward()
         optimizer.step()
         pred_choice = pred.data.max(1)[1]
@@ -102,7 +104,7 @@
             points = points.transpose(2, 1)
             points, target = points.cuda(), target.cuda()
             classifier = classifier.eval()
-            pred, _ = classifier(points)
+            pred, _, _ = classifier(points)
             pred = pred.view(-1, num_classes)
             target = target.view(-1, 1)[:, 0] - 1
             loss = F.nll_loss(pred, target)
@@ -119,7 +121,7 @@
     points = points.transpose(2, 1)
     points, target = points.cuda(), target.cuda()
     classifier = classifier.eval()
-    pred, _ = classifier(points)
+    pred, _, _ = classifier(points)
     pred_choice = pred.data.max(2)[1]
 
     pred_np = pred_choice.cpu().data.numpy()
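
The sketch below shows how the pieces added by this patch fit together outside of the training scripts: `PointNetCls` is built with `feature_transform=True`, the forward pass now returns `(pred, trans, trans_feat)`, and `feature_transform_reguliarzer` penalises the predicted 64x64 feature transform for drifting away from an orthogonal matrix (the mean Frobenius norm of A Aᵀ - I), weighted by 0.001 as in `utils/train_classification.py`. This is a minimal, hypothetical example, not part of the patch: the class count, batch of random points, and learning rate are placeholders, and it assumes the patched `pointnet` package is importable.

```
# Minimal sketch (not part of the patch): one classification training step
# with the new feature transform enabled. Shapes, class count and learning
# rate are illustrative; only the 0.001 regularizer weight mirrors the
# training scripts.
import torch
import torch.nn.functional as F
from pointnet.model import PointNetCls, feature_transform_reguliarzer

classifier = PointNetCls(k=16, feature_transform=True)
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

points = torch.rand(8, 3, 2500)       # batch of 8 clouds, 2500 points each, xyz channels first
target = torch.randint(0, 16, (8,))   # random class labels, for illustration only

classifier = classifier.train()
optimizer.zero_grad()
pred, trans, trans_feat = classifier(points)   # forward now returns three values
loss = F.nll_loss(pred, target)
if trans_feat is not None:                     # trans_feat is None unless feature_transform=True
    loss += feature_transform_reguliarzer(trans_feat) * 0.001
loss.backward()
optimizer.step()
```

From the command line, the same behaviour is switched on by passing the new `--feature_transform` flag to `utils/train_classification.py` or `utils/train_segmentation.py`.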