Skip to content

Commit

Permalink
Use torch.max instead of torch.nn.MaxPool1d
Browse files Browse the repository at this point in the history
According to the original paper, there should be no restriction on the number of points within a point cloud. This PR updates the pointnet to allow for variable number of points within a pointcloud or from mini-batch to mini-batch. If one has a dataset that has different numbers of points within each point cloud, one way of using this is to upsample (similar to how images are padded with 0s when they are different sizes). However, one wants to minimize the number of added points because unlike image data, adding 0s in a point cloud changes the structure of the data. Instead, one should duplicate the fewest number of points such that each sample in a mini-batch has the same number of points but each mini-batch may have a different number of points per sample. To do this, one should sort their dataset by the number of points within each point cloud, then group the point clouds into sizes of the desired mini-batch. For example one mini-batch may have point clouds of sizes [901, 905, 905, ..., 945]. In order to upsample all pointclouds P in mini-batch_j, one randomly duplicates K points from a point cloud P_i with N points where K is the difference between the current point cloud size and the maximum point cloud size in mini-batch_j (K = max(P_i.size() for P_i in mini-batch_j) - N ). For example, the previous example mini-batch will now have sizes [945, 945, 945, ..., 945] because the first point cloud with size 901 had (945 - 901 = 44) points duplicated to it etc. Now that each mini-batch has the same number of points, one can train their network by randomly sampling these mini-batches.
  • Loading branch information
Jaden Travnik committed Jun 27, 2018
1 parent 75dd61c commit 5fd9592
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
1 change: 0 additions & 1 deletion datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json
import codecs
import numpy as np
import progressbar
import sys
import torchvision.transforms as transforms
import argparse
Expand Down
32 changes: 14 additions & 18 deletions pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@


class STN3d(nn.Module):
def __init__(self, num_points = 2500):
def __init__(self):
super(STN3d, self).__init__()
self.num_points = num_points
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
Expand All @@ -43,7 +41,7 @@ def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)

x = F.relu(self.bn4(self.fc1(x)))
Expand All @@ -59,20 +57,19 @@ def forward(self, x):


class PointNetfeat(nn.Module):
def __init__(self, num_points = 2500, global_feat = True):
def __init__(self, global_feat = True):
super(PointNetfeat, self).__init__()
self.stn = STN3d(num_points = num_points)
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.num_points = num_points
self.global_feat = global_feat
def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2,1)
x = torch.bmm(x, trans)
Expand All @@ -81,19 +78,18 @@ def forward(self, x):
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans
else:
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans

class PointNetCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetCls, self).__init__()
self.num_points = num_points
self.feat = PointNetfeat(num_points, global_feat=True)
self.feat = PointNetfeat(global_feat=True)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k)
Expand All @@ -105,14 +101,13 @@ def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return F.log_softmax(x, dim=-1), trans
return F.log_softmax(x, dim=0), trans

class PointNetDenseCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetDenseCls, self).__init__()
self.num_points = num_points
self.k = k
self.feat = PointNetfeat(num_points, global_feat=False)
self.feat = PointNetfeat(global_feat=False)
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
self.conv2 = torch.nn.Conv1d(512, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 128, 1)
Expand All @@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):

def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
x, trans = self.feat(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = x.transpose(2,1).contiguous()
x = F.log_softmax(x.view(-1,self.k), dim=-1)
x = x.view(batchsize, self.num_points, self.k)
x = x.view(batchsize, n_pts, self.k)
return x, trans


Expand Down

0 comments on commit 5fd9592

Please sign in to comment.