Skip to content

Use torch.max instead of torch.nn.MaxPool1d #18

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Oct 8, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import json
import codecs
import numpy as np
import progressbar
import sys
import torchvision.transforms as transforms
import argparse
Expand Down
32 changes: 14 additions & 18 deletions pointnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,11 @@


class STN3d(nn.Module):
def __init__(self, num_points = 2500):
def __init__(self):
super(STN3d, self).__init__()
self.num_points = num_points
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, 9)
Expand All @@ -43,7 +41,7 @@ def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)

x = F.relu(self.bn4(self.fc1(x)))
Expand All @@ -59,20 +57,19 @@ def forward(self, x):


class PointNetfeat(nn.Module):
def __init__(self, num_points = 2500, global_feat = True):
def __init__(self, global_feat = True):
super(PointNetfeat, self).__init__()
self.stn = STN3d(num_points = num_points)
self.stn = STN3d()
self.conv1 = torch.nn.Conv1d(3, 64, 1)
self.conv2 = torch.nn.Conv1d(64, 128, 1)
self.conv3 = torch.nn.Conv1d(128, 1024, 1)
self.bn1 = nn.BatchNorm1d(64)
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.mp1 = torch.nn.MaxPool1d(num_points)
self.num_points = num_points
self.global_feat = global_feat
def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2,1)
x = torch.bmm(x, trans)
Expand All @@ -81,19 +78,18 @@ def forward(self, x):
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
x = self.bn3(self.conv3(x))
x = self.mp1(x)
x = torch.max(x, 2, keepdim=True)[0]
x = x.view(-1, 1024)
if self.global_feat:
return x, trans
else:
x = x.view(-1, 1024, 1).repeat(1, 1, self.num_points)
x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
return torch.cat([x, pointfeat], 1), trans

class PointNetCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetCls, self).__init__()
self.num_points = num_points
self.feat = PointNetfeat(num_points, global_feat=True)
self.feat = PointNetfeat(global_feat=True)
self.fc1 = nn.Linear(1024, 512)
self.fc2 = nn.Linear(512, 256)
self.fc3 = nn.Linear(256, k)
Expand All @@ -105,14 +101,13 @@ def forward(self, x):
x = F.relu(self.bn1(self.fc1(x)))
x = F.relu(self.bn2(self.fc2(x)))
x = self.fc3(x)
return F.log_softmax(x, dim=-1), trans
return F.log_softmax(x, dim=0), trans

class PointNetDenseCls(nn.Module):
def __init__(self, num_points = 2500, k = 2):
def __init__(self, k = 2):
super(PointNetDenseCls, self).__init__()
self.num_points = num_points
self.k = k
self.feat = PointNetfeat(num_points, global_feat=False)
self.feat = PointNetfeat(global_feat=False)
self.conv1 = torch.nn.Conv1d(1088, 512, 1)
self.conv2 = torch.nn.Conv1d(512, 256, 1)
self.conv3 = torch.nn.Conv1d(256, 128, 1)
Expand All @@ -123,14 +118,15 @@ def __init__(self, num_points = 2500, k = 2):

def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
x, trans = self.feat(x)
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = F.relu(self.bn3(self.conv3(x)))
x = self.conv4(x)
x = x.transpose(2,1).contiguous()
x = F.log_softmax(x.view(-1,self.k), dim=-1)
x = x.view(batchsize, self.num_points, self.k)
x = x.view(batchsize, n_pts, self.k)
return x, trans


Expand Down