Skip to content

Commit bf91ed3

Browse files
committed
add feature transform
1 parent 44f0112 commit bf91ed3

File tree

6 files changed

+112
-37
lines changed

6 files changed

+112
-37
lines changed

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,10 +25,9 @@ python train_classification.py --dataset <dataset path> --nepoch=<number epochs>
2525
python train_segmentation.py --dataset <dataset path> --nepoch=<number epochs>
2626
```
2727

28-
# Performance
28+
Use `--feature_transform` to use feature transform.
2929

30-
Sample segmentation result:
31-
![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)
30+
# Performance
3231

3332
## Classification performance
3433

@@ -37,7 +36,7 @@ On ModelNet40:
3736
| | Overall Acc |
3837
| :---: | :---: |
3938
| Original implementation | 89.2 |
40-
| this implementation(w/o feature transform) | TBA |
39+
| this implementation(w/o feature transform) | 86.4 |
4140
| this implementation(w/ feature transform) | TBA |
4241

4342
On [A subset of shapenet](http://web.stanford.edu/~ericyi/project_page/part_annotation/index.html)
@@ -60,6 +59,9 @@ Segmentation on [A subset of shapenet](http://web.stanford.edu/~ericyi/project_
6059

6160
Note that this implementation trains each class separately, so classes with fewer data will have slightly lower performance than reference implementation.
6261

62+
Sample segmentation result:
63+
![seg](https://raw.githubusercontent.com/fxia22/pointnet.pytorch/master/misc/show3d.png?token=AE638Oy51TL2HDCaeCF273X_-Bsy6-E2ks5Y_BUzwA%3D%3D)
64+
6365
# Links
6466

6567
- [Project Page](http://stanford.edu/~rqi/pointnet/)

pointnet/model.py

Lines changed: 83 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,46 @@ def forward(self, x):
4646
return x
4747

4848

49+
class STNkd(nn.Module):
50+
def __init__(self, k=64):
51+
super(STNkd, self).__init__()
52+
self.conv1 = torch.nn.Conv1d(k, 64, 1)
53+
self.conv2 = torch.nn.Conv1d(64, 128, 1)
54+
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
55+
self.fc1 = nn.Linear(1024, 512)
56+
self.fc2 = nn.Linear(512, 256)
57+
self.fc3 = nn.Linear(256, k*k)
58+
self.relu = nn.ReLU()
59+
60+
self.bn1 = nn.BatchNorm1d(64)
61+
self.bn2 = nn.BatchNorm1d(128)
62+
self.bn3 = nn.BatchNorm1d(1024)
63+
self.bn4 = nn.BatchNorm1d(512)
64+
self.bn5 = nn.BatchNorm1d(256)
65+
66+
self.k = k
67+
68+
def forward(self, x):
69+
batchsize = x.size()[0]
70+
x = F.relu(self.bn1(self.conv1(x)))
71+
x = F.relu(self.bn2(self.conv2(x)))
72+
x = F.relu(self.bn3(self.conv3(x)))
73+
x = torch.max(x, 2, keepdim=True)[0]
74+
x = x.view(-1, 1024)
75+
76+
x = F.relu(self.bn4(self.fc1(x)))
77+
x = F.relu(self.bn5(self.fc2(x)))
78+
x = self.fc3(x)
79+
80+
iden = Variable(torch.from_numpy(np.eye(self.k).flatten().astype(np.float32))).view(1,self.k*self.k).repeat(batchsize,1)
81+
if x.is_cuda:
82+
iden = iden.cuda()
83+
x = x + iden
84+
x = x.view(-1, self.k, self.k)
85+
return x
86+
4987
class PointNetfeat(nn.Module):
50-
def __init__(self, global_feat = True):
88+
def __init__(self, global_feat = True, feature_transform = False):
5189
super(PointNetfeat, self).__init__()
5290
self.stn = STN3d()
5391
self.conv1 = torch.nn.Conv1d(3, 64, 1)
@@ -57,7 +95,9 @@ def __init__(self, global_feat = True):
5795
self.bn2 = nn.BatchNorm1d(128)
5896
self.bn3 = nn.BatchNorm1d(1024)
5997
self.global_feat = global_feat
60-
98+
self.feature_transform = feature_transform
99+
if self.feature_transform:
100+
self.fstn = STNkd(k=64)
61101

62102
def forward(self, x):
63103
n_pts = x.size()[2]
@@ -66,21 +106,31 @@ def forward(self, x):
66106
x = torch.bmm(x, trans)
67107
x = x.transpose(2, 1)
68108
x = F.relu(self.bn1(self.conv1(x)))
109+
110+
if self.feature_transform:
111+
trans_feat = self.fstn(x)
112+
x = x.transpose(2,1)
113+
x = torch.bmm(x, trans_feat)
114+
x = x.transpose(2,1)
115+
else:
116+
trans_feat = None
117+
69118
pointfeat = x
70119
x = F.relu(self.bn2(self.conv2(x)))
71120
x = self.bn3(self.conv3(x))
72121
x = torch.max(x, 2, keepdim=True)[0]
73122
x = x.view(-1, 1024)
74123
if self.global_feat:
75-
return x, trans
124+
return x, trans, trans_feat
76125
else:
77126
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
78-
return torch.cat([x, pointfeat], 1), trans
127+
return torch.cat([x, pointfeat], 1), trans, trans_feat
79128

80129
class PointNetCls(nn.Module):
81-
def __init__(self, k = 2):
130+
def __init__(self, k=2, feature_transform=False):
82131
super(PointNetCls, self).__init__()
83-
self.feat = PointNetfeat(global_feat=True)
132+
self.feature_transform = feature_transform
133+
self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
84134
self.fc1 = nn.Linear(1024, 512)
85135
self.fc2 = nn.Linear(512, 256)
86136
self.fc3 = nn.Linear(256, k)
@@ -90,17 +140,18 @@ def __init__(self, k = 2):
90140
self.relu = nn.ReLU()
91141

92142
def forward(self, x):
93-
x, trans = self.feat(x)
143+
x, trans, trans_feat = self.feat(x)
94144
x = F.relu(self.bn1(self.fc1(x)))
95145
x = F.relu(self.bn2(self.dropout(self.fc2(x))))
96146
x = self.fc3(x)
97-
return F.log_softmax(x, dim=1), trans
147+
return F.log_softmax(x, dim=1), trans, trans_feat
98148

99149
class PointNetDenseCls(nn.Module):
100-
def __init__(self, k = 2):
150+
def __init__(self, k = 2, feature_transform=False):
101151
super(PointNetDenseCls, self).__init__()
102152
self.k = k
103-
self.feat = PointNetfeat(global_feat=False)
153+
self.feature_transform=feature_transform
154+
self.feat = PointNetfeat(global_feat=False, feature_transform=feature_transform)
104155
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
105156
self.conv2 = torch.nn.Conv1d(512, 256, 1)
106157
self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -112,35 +163,50 @@ def __init__(self, k = 2):
112163
def forward(self, x):
113164
batchsize = x.size()[0]
114165
n_pts = x.size()[2]
115-
x, trans = self.feat(x)
166+
x, trans, trans_feat = self.feat(x)
116167
x = F.relu(self.bn1(self.conv1(x)))
117168
x = F.relu(self.bn2(self.conv2(x)))
118169
x = F.relu(self.bn3(self.conv3(x)))
119170
x = self.conv4(x)
120171
x = x.transpose(2,1).contiguous()
121172
x = F.log_softmax(x.view(-1,self.k), dim=-1)
122173
x = x.view(batchsize, n_pts, self.k)
123-
return x, trans
174+
return x, trans, trans_feat
124175

176+
def feature_transform_reguliarzer(trans):
177+
d = trans.size()[1]
178+
batchsize = trans.size()[0]
179+
I = torch.eye(d)[None, :, :]
180+
if trans.is_cuda:
181+
I = I.cuda()
182+
loss = torch.mean(torch.norm(torch.bmm(trans, trans.transpose(2,1) - I), dim=(1,2)))
183+
return loss
125184

126185
if __name__ == '__main__':
127186
sim_data = Variable(torch.rand(32,3,2500))
128187
trans = STN3d()
129188
out = trans(sim_data)
130189
print('stn', out.size())
131-
190+
print('loss', feature_transform_reguliarzer(out))
191+
192+
sim_data_64d = Variable(torch.rand(32, 64, 2500))
193+
trans = STNkd(k=64)
194+
out = trans(sim_data_64d)
195+
print('stn64d', out.size())
196+
print('loss', feature_transform_reguliarzer(out))
197+
132198
pointfeat = PointNetfeat(global_feat=True)
133-
out, _ = pointfeat(sim_data)
199+
out, _, _ = pointfeat(sim_data)
134200
print('global feat', out.size())
135201

136202
pointfeat = PointNetfeat(global_feat=False)
137-
out, _ = pointfeat(sim_data)
203+
out, _, _ = pointfeat(sim_data)
138204
print('point feat', out.size())
139205

140206
cls = PointNetCls(k = 5)
141-
out, _ = cls(sim_data)
207+
out, _, _ = cls(sim_data)
142208
print('class', out.size())
143209

144210
seg = PointNetDenseCls(k = 3)
145-
out, _ = seg(sim_data)
211+
out, _, _ = seg(sim_data)
146212
print('seg', out.size())

utils/show_cls.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,10 @@
2222

2323
test_dataset = ShapeNetDataset(
2424
root='shapenetcore_partanno_segmentation_benchmark_v0',
25-
train=False,
25+
split='test',
2626
classification=True,
27-
npoints=opt.num_points)
27+
npoints=opt.num_points,
28+
data_augmentation=False)
2829

2930
testdataloader = torch.utils.data.DataLoader(
3031
test_dataset, batch_size=32, shuffle=True)
@@ -40,7 +41,7 @@
4041
points, target = Variable(points), Variable(target[:, 0])
4142
points = points.transpose(2, 1)
4243
points, target = points.cuda(), target.cuda()
43-
pred, _ = classifier(points)
44+
pred, _, _ = classifier(points)
4445
loss = F.nll_loss(pred, target)
4546

4647
pred_choice = pred.data.max(1)[1]

utils/show_seg.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@
2626
d = ShapeNetDataset(
2727
root=opt.dataset,
2828
class_choice=[opt.class_choice],
29-
train=False)
29+
split='test',
30+
data_augmentation=False)
3031

3132
idx = opt.idx
3233

@@ -47,7 +48,7 @@
4748
point = point.transpose(1, 0).contiguous()
4849

4950
point = Variable(point.view(1, point.size()[0], point.size()[1]))
50-
pred, _ = classifier(point)
51+
pred, _, _ = classifier(point)
5152
pred_choice = pred.data.max(2)[1]
5253
print(pred_choice)
5354

utils/train_classification.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.optim as optim
88
import torch.utils.data
99
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
10-
from pointnet.model import PointNetCls
10+
from pointnet.model import PointNetCls, feature_transform_reguliarzer
1111
import torch.nn.functional as F
1212
from tqdm import tqdm
1313

@@ -25,6 +25,7 @@
2525
parser.add_argument('--model', type=str, default='', help='model path')
2626
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
2727
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")
28+
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
2829

2930
opt = parser.parse_args()
3031
print(opt)
@@ -84,7 +85,7 @@
8485
except OSError:
8586
pass
8687

87-
classifier = PointNetCls(k=num_classes)
88+
classifier = PointNetCls(k=num_classes, feature_transform=opt.feature_transform)
8889

8990
if opt.model != '':
9091
classifier.load_state_dict(torch.load(opt.model))
@@ -105,8 +106,10 @@
105106
points, target = points.cuda(), target.cuda()
106107
optimizer.zero_grad()
107108
classifier = classifier.train()
108-
pred, _ = classifier(points)
109+
pred, trans, trans_feat = classifier(points)
109110
loss = F.nll_loss(pred, target)
111+
if opt.feature_transform:
112+
loss += feature_transform_reguliarzer(trans_feat) * 0.001
110113
loss.backward()
111114
optimizer.step()
112115
pred_choice = pred.data.max(1)[1]
@@ -120,7 +123,7 @@
120123
points = points.transpose(2, 1)
121124
points, target = points.cuda(), target.cuda()
122125
classifier = classifier.eval()
123-
pred, _ = classifier(points)
126+
pred, _, _ = classifier(points)
124127
loss = F.nll_loss(pred, target)
125128
pred_choice = pred.data.max(1)[1]
126129
correct = pred_choice.eq(target.data).cpu().sum()
@@ -136,7 +139,7 @@
136139
points = points.transpose(2, 1)
137140
points, target = points.cuda(), target.cuda()
138141
classifier = classifier.eval()
139-
pred, _ = classifier(points)
142+
pred, _, _ = classifier(points)
140143
pred_choice = pred.data.max(1)[1]
141144
correct = pred_choice.eq(target.data).cpu().sum()
142145
total_correct += correct.item()

utils/train_segmentation.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import torch.optim as optim
88
import torch.utils.data
99
from pointnet.dataset import ShapeNetDataset
10-
from pointnet.model import PointNetDenseCls
10+
from pointnet.model import PointNetDenseCls, feature_transform_reguliarzer
1111
import torch.nn.functional as F
1212
from tqdm import tqdm
1313
import numpy as np
@@ -24,7 +24,7 @@
2424
parser.add_argument('--model', type=str, default='', help='model path')
2525
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
2626
parser.add_argument('--class_choice', type=str, default='Chair', help="class_choice")
27-
27+
parser.add_argument('--feature_transform', action='store_true', help="use feature transform")
2828

2929
opt = parser.parse_args()
3030
print(opt)
@@ -66,7 +66,7 @@
6666

6767
blue = lambda x: '\033[94m' + x + '\033[0m'
6868

69-
classifier = PointNetDenseCls(k=num_classes)
69+
classifier = PointNetDenseCls(k=num_classes, feature_transform=opt.feature_transform)
7070

7171
if opt.model != '':
7272
classifier.load_state_dict(torch.load(opt.model))
@@ -85,11 +85,13 @@
8585
points, target = points.cuda(), target.cuda()
8686
optimizer.zero_grad()
8787
classifier = classifier.train()
88-
pred, _ = classifier(points)
88+
pred, trans, trans_feat = classifier(points)
8989
pred = pred.view(-1, num_classes)
9090
target = target.view(-1, 1)[:, 0] - 1
9191
#print(pred.size(), target.size())
9292
loss = F.nll_loss(pred, target)
93+
if opt.feature_transform:
94+
loss += feature_transform_reguliarzer(trans_feat) * 0.001
9395
loss.backward()
9496
optimizer.step()
9597
pred_choice = pred.data.max(1)[1]
@@ -102,7 +104,7 @@
102104
points = points.transpose(2, 1)
103105
points, target = points.cuda(), target.cuda()
104106
classifier = classifier.eval()
105-
pred, _ = classifier(points)
107+
pred, _, _ = classifier(points)
106108
pred = pred.view(-1, num_classes)
107109
target = target.view(-1, 1)[:, 0] - 1
108110
loss = F.nll_loss(pred, target)
@@ -119,7 +121,7 @@
119121
points = points.transpose(2, 1)
120122
points, target = points.cuda(), target.cuda()
121123
classifier = classifier.eval()
122-
pred, _ = classifier(points)
124+
pred, _, _ = classifier(points)
123125
pred_choice = pred.data.max(2)[1]
124126

125127
pred_np = pred_choice.cpu().data.numpy()

0 commit comments

Comments
 (0)