Skip to content

Commit 2ec315f

Browse files
authored
Merge pull request #18 from JadenTravnik/variable_number_of_points
Use torch.max instead of torch.nn.MaxPool1d
2 parents 3bfb8fd + 5fd9592 commit 2ec315f

File tree

2 files changed

+14
-19
lines changed

2 files changed

+14
-19
lines changed

datasets.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@
88
import json
99
import codecs
1010
import numpy as np
11-
import progressbar
1211
import sys
1312
import torchvision.transforms as transforms
1413
import argparse

pointnet.py

Lines changed: 14 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,11 @@
1919

2020

2121
class STN3d(nn.Module):
22-
def __init__(self, num_points = 2500):
22+
def __init__(self):
2323
super(STN3d, self).__init__()
24-
self.num_points = num_points
2524
self.conv1 = torch.nn.Conv1d(3, 64, 1)
2625
self.conv2 = torch.nn.Conv1d(64, 128, 1)
2726
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
28-
self.mp1 = torch.nn.MaxPool1d(num_points)
2927
self.fc1 = nn.Linear(1024, 512)
3028
self.fc2 = nn.Linear(512, 256)
3129
self.fc3 = nn.Linear(256, 9)
@@ -43,7 +41,7 @@ def forward(self, x):
4341
x = F.relu(self.bn1(self.conv1(x)))
4442
x = F.relu(self.bn2(self.conv2(x)))
4543
x = F.relu(self.bn3(self.conv3(x)))
46-
x = self.mp1(x)
44+
x = torch.max(x, 2, keepdim=True)[0]
4745
x = x.view(-1, 1024)
4846

4947
x = F.relu(self.bn4(self.fc1(x)))
@@ -59,20 +57,19 @@ def forward(self, x):
5957

6058

6159
class PointNetfeat(nn.Module):
62-
def __init__(self, num_points = 2500, global_feat = True):
60+
def __init__(self, global_feat = True):
6361
super(PointNetfeat, self).__init__()
64-
self.stn = STN3d(num_points = num_points)
62+
self.stn = STN3d()
6563
self.conv1 = torch.nn.Conv1d(3, 64, 1)
6664
self.conv2 = torch.nn.Conv1d(64, 128, 1)
6765
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
6866
self.bn1 = nn.BatchNorm1d(64)
6967
self.bn2 = nn.BatchNorm1d(128)
7068
self.bn3 = nn.BatchNorm1d(1024)
71-
self.mp1 = torch.nn.MaxPool1d(num_points)
72-
self.num_points = num_points
7369
self.global_feat = global_feat
7470
def forward(self, x):
7571
batchsize = x.size()[0]
72+
n_pts = x.size()[2]
7673
trans = self.stn(x)
7774
x = x.transpose(2,1)
7875
x = torch.bmm(x, trans)
@@ -81,19 +78,18 @@ def forward(self, x):
8178
pointfeat = x
8279
x = F.relu(self.bn2(self.conv2(x)))
8380
x = self.bn3(self.conv3(x))
84-
x = self.mp1(x)
81+
x = torch.max(x, 2, keepdim=True)[0]
8582
x = x.view(-1, 1024)
8683
if self.global_feat:
8784
return x, trans
8885
else:
89-
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
86+
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
9087
return torch.cat([x, pointfeat], 1), trans
9188

9289
class PointNetCls(nn.Module):
93-
def __init__(self, num_points = 2500, k = 2):
90+
def __init__(self, k = 2):
9491
super(PointNetCls, self).__init__()
95-
self.num_points = num_points
96-
self.feat = PointNetfeat(num_points, global_feat=True)
92+
self.feat = PointNetfeat(global_feat=True)
9793
self.fc1 = nn.Linear(1024, 512)
9894
self.fc2 = nn.Linear(512, 256)
9995
self.fc3 = nn.Linear(256, k)
@@ -105,14 +101,13 @@ def forward(self, x):
105101
x = F.relu(self.bn1(self.fc1(x)))
106102
x = F.relu(self.bn2(self.fc2(x)))
107103
x = self.fc3(x)
108-
return F.log_softmax(x, dim=-1), trans
104+
return F.log_softmax(x, dim=0), trans
109105

110106
class PointNetDenseCls(nn.Module):
111-
def __init__(self, num_points = 2500, k = 2):
107+
def __init__(self, k = 2):
112108
super(PointNetDenseCls, self).__init__()
113-
self.num_points = num_points
114109
self.k = k
115-
self.feat = PointNetfeat(num_points, global_feat=False)
110+
self.feat = PointNetfeat(global_feat=False)
116111
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
117112
self.conv2 = torch.nn.Conv1d(512, 256, 1)
118113
self.conv3 = torch.nn.Conv1d(256, 128, 1)
@@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):
123118

124119
def forward(self, x):
125120
batchsize = x.size()[0]
121+
n_pts = x.size()[2]
126122
x, trans = self.feat(x)
127123
x = F.relu(self.bn1(self.conv1(x)))
128124
x = F.relu(self.bn2(self.conv2(x)))
129125
x = F.relu(self.bn3(self.conv3(x)))
130126
x = self.conv4(x)
131127
x = x.transpose(2,1).contiguous()
132128
x = F.log_softmax(x.view(-1,self.k), dim=-1)
133-
x = x.view(batchsize, self.num_points, self.k)
129+
x = x.view(batchsize, n_pts, self.k)
134130
return x, trans
135131

136132

0 commit comments

Comments
 (0)