Skip to content

Commit

Permalink
Merge pull request #18 from JadenTravnik/variable_number_of_points
Browse files Browse the repository at this point in the history
Use torch.max instead of torch.nn.MaxPool1d
  • Loading branch information
fxia22 authored Oct 8, 2018
2 parents 3bfb8fd + 5fd9592 commit 2ec315f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 19 deletions.
1 change: 0 additions & 1 deletion datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json
import codecs
import numpy as np
import progressbar
import sys
import torchvision.transforms as transforms
import argparse
Expand Down
32 changes: 14 additions & 18 deletions pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@


class STN3d(nn.Module):
def __init__(self, num_points = 2500):
def __init__(self):
super(STN3d, self).__init__()
self.num_points = num_points
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
Expand All @@ -43,7 +41,7 @@ def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)

x = F.relu(self.bn4(self.fc1(x)))
Expand All @@ -59,20 +57,19 @@ def forward(self, x):


class PointNetfeat(nn.Module):
def __init__(self, num_points = 2500, global_feat = True):
def __init__(self, global_feat = True):
super(PointNetfeat, self).__init__()
self.stn = STN3d(num_points = num_points)
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.num_points = num_points
self.global_feat = global_feat
def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2,1)
x = torch.bmm(x, trans)
Expand All @@ -81,19 +78,18 @@ def forward(self, x):
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans
else:
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans

class PointNetCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetCls, self).__init__()
self.num_points = num_points
self.feat = PointNetfeat(num_points, global_feat=True)
self.feat = PointNetfeat(global_feat=True)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k)
Expand All @@ -105,14 +101,13 @@ def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return F.log_softmax(x, dim=-1), trans
return F.log_softmax(x, dim=0), trans

class PointNetDenseCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetDenseCls, self).__init__()
self.num_points = num_points
self.k = k
self.feat = PointNetfeat(num_points, global_feat=False)
self.feat = PointNetfeat(global_feat=False)
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
self.conv2 = torch.nn.Conv1d(512, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 128, 1)
Expand All @@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):

def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
x, trans = self.feat(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = x.transpose(2,1).contiguous()
x = F.log_softmax(x.view(-1,self.k), dim=-1)
x = x.view(batchsize, self.num_points, self.k)
x = x.view(batchsize, n_pts, self.k)
return x, trans


Expand Down

0 comments on commit 2ec315f

Please sign in to comment.