
Commit

add geoerror modification
wyddmw committed May 5, 2021
1 parent 3b89739 commit adf3e52
Showing 9 changed files with 234 additions and 64 deletions.
File renamed without changes.
File renamed without changes.
Empty file added model/__init__.py
215 changes: 215 additions & 0 deletions model/dataset.py
@@ -0,0 +1,215 @@
from __future__ import print_function
import torch.utils.data as data
import os
import os.path
import torch
import numpy as np
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement

def get_segmentation_classes(root):
    catfile = os.path.join(root, 'synsetoffset2category.txt')
    cat = {}
    meta = {}

    with open(catfile, 'r') as f:
        for line in f:
            ls = line.strip().split()
            cat[ls[0]] = ls[1]

    for item in cat:
        dir_seg = os.path.join(root, cat[item], 'points_label')
        dir_point = os.path.join(root, cat[item], 'points')
        fns = sorted(os.listdir(dir_point))
        meta[item] = []
        for fn in fns:
            token = (os.path.splitext(os.path.basename(fn))[0])
            meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))

    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
        for item in cat:
            datapath = []
            num_seg_classes = 0
            for fn in meta[item]:
                datapath.append((item, fn[0], fn[1]))

            for i in tqdm(range(len(datapath))):
                l = len(np.unique(np.loadtxt(datapath[i][-1]).astype(np.uint8)))
                if l > num_seg_classes:
                    num_seg_classes = l

            print("category {} num segmentation classes {}".format(item, num_seg_classes))
            f.write("{}\t{}\n".format(item, num_seg_classes))

def gen_modelnet_id(root):
    classes = []
    with open(os.path.join(root, 'train.txt'), 'r') as f:
        for line in f:
            classes.append(line.strip().split('/')[0])
    classes = np.unique(classes)
    with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'w') as f:
        for i in range(len(classes)):
            f.write('{}\t{}\n'.format(classes[i], i))
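
# Illustrative call (the path is hypothetical): gen_modelnet_id('/path/to/modelnet40')
# writes ../misc/modelnet_id.txt with one "<class_name>\t<index>" line per class,
# which ModelNetDataset later reads back to map class names to integer labels.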

class ShapeNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 classification=False,
                 class_choice=None,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
        self.cat = {}
        self.data_augmentation = data_augmentation
        self.classification = classification
        self.seg_classes = {}

        with open(self.catfile, 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = ls[1]
        #print(self.cat)
        if class_choice is not None:
            self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

        self.id2cat = {v: k for k, v in self.cat.items()}

        self.meta = {}
        splitfile = os.path.join(self.root, 'train_test_split', 'shuffled_{}_file_list.json'.format(split))
        #from IPython import embed; embed()
        filelist = json.load(open(splitfile, 'r'))
        for item in self.cat:
            self.meta[item] = []

        for file in filelist:
            _, category, uuid = file.split('/')
            if category in self.cat.values():
                self.meta[self.id2cat[category]].append((os.path.join(self.root, category, 'points', uuid+'.pts'),
                                                         os.path.join(self.root, category, 'points_label', uuid+'.seg')))

        self.datapath = []
        for item in self.cat:
            for fn in self.meta[item]:
                self.datapath.append((item, fn[0], fn[1]))

        self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
        print(self.classes)
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.seg_classes[ls[0]] = int(ls[1])
        self.num_seg_classes = self.seg_classes[list(self.cat.keys())[0]]
        print(self.seg_classes, self.num_seg_classes)

    def __getitem__(self, index):
        fn = self.datapath[index]
        cls = self.classes[self.datapath[index][0]]
        point_set = np.loadtxt(fn[1]).astype(np.float32)
        seg = np.loadtxt(fn[2]).astype(np.int64)
        #print(point_set.shape, seg.shape)

        choice = np.random.choice(len(seg), self.npoints, replace=True)
        # resample
        point_set = point_set[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale to the unit sphere

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation about the vertical (y) axis
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        seg = seg[choice]
        point_set = torch.from_numpy(point_set)
        seg = torch.from_numpy(seg)
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))

        if self.classification:
            return point_set, cls
        else:
            return point_set, seg

    def __len__(self):
        return len(self.datapath)

class ModelNetDataset(data.Dataset):
    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        self.fns = []
        with open(os.path.join(root, '{}.txt'.format(self.split)), 'r') as f:
            for line in f:
                self.fns.append(line.strip())

        self.cat = {}
        with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/modelnet_id.txt'), 'r') as f:
            for line in f:
                ls = line.strip().split()
                self.cat[ls[0]] = int(ls[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    def __getitem__(self, index):
        fn = self.fns[index]
        cls = self.cat[fn.split('/')[0]]
        with open(os.path.join(self.root, fn), 'rb') as f:
            plydata = PlyData.read(f)
        pts = np.vstack([plydata['vertex']['x'], plydata['vertex']['y'], plydata['vertex']['z']]).T
        choice = np.random.choice(len(pts), self.npoints, replace=True)
        point_set = pts[choice, :]

        point_set = point_set - np.expand_dims(np.mean(point_set, axis=0), 0)  # center
        dist = np.max(np.sqrt(np.sum(point_set ** 2, axis=1)), 0)
        point_set = point_set / dist  # scale

        if self.data_augmentation:
            theta = np.random.uniform(0, np.pi * 2)
            rotation_matrix = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
            point_set[:, [0, 2]] = point_set[:, [0, 2]].dot(rotation_matrix)  # random rotation
            point_set += np.random.normal(0, 0.02, size=point_set.shape)  # random jitter

        point_set = torch.from_numpy(point_set.astype(np.float32))
        cls = torch.from_numpy(np.array([cls]).astype(np.int64))
        return point_set, cls

    def __len__(self):
        return len(self.fns)

if __name__ == '__main__':
    dataset = sys.argv[1]
    datapath = sys.argv[2]

    if dataset == 'shapenet':
        d = ShapeNetDataset(root=datapath, class_choice=['Chair'])
        print(len(d))
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(), seg.type())

        d = ShapeNetDataset(root=datapath, classification=True)
        print(len(d))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(), cls.type())
        # get_segmentation_classes(datapath)

    if dataset == 'modelnet':
        gen_modelnet_id(datapath)
        d = ModelNetDataset(root=datapath)
        print(len(d))
        print(d[0])
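
# --- Illustrative usage, not part of this commit ---
# A minimal sketch of wrapping ShapeNetDataset in a DataLoader for classification;
# the root path is hypothetical, and the transpose gives the channels-first
# layout the model's Conv1d layers expect.
#
#     dataset = ShapeNetDataset(root='/path/to/shapenetcore_partanno_segmentation_benchmark_v0',
#                               classification=True, npoints=2500)
#     loader = torch.utils.data.DataLoader(dataset, batch_size=32, shuffle=True, num_workers=4)
#     points, cls = next(iter(loader))    # points: (32, 2500, 3), cls: (32, 1)
#     points = points.transpose(2, 1)     # -> (32, 3, 2500) for the Conv1d stack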

72 changes: 12 additions & 60 deletions pointnet/model.py → model/model.py
@@ -88,49 +88,17 @@ class PointNetfeat(nn.Module):
    def __init__(self, global_feat=True, feature_transform=False):
        super(PointNetfeat, self).__init__()
        self.stn = STN3d()
        self.conv1 = torch.nn.Conv1d(3, 64, 1)  # points: (N, 3, n_pts) -> (N, 64, n_pts)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)  # -> (N, 128, n_pts)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)  # -> (N, 1024, n_pts)
        self.conv1 = torch.nn.Conv1d(3, 64, 1)
        self.conv2 = torch.nn.Conv1d(64, 128, 1)
        self.conv3 = torch.nn.Conv1d(128, 1024, 1)
        self.bn1 = nn.BatchNorm1d(64)
        self.bn2 = nn.BatchNorm1d(128)
        self.bn3 = nn.BatchNorm1d(1024)
        self.global_feat = global_feat
        self.feature_transform = feature_transform
        if self.feature_transform:
            self.fstn = STNkd(k=64)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
        x = x.transpose(2, 1)
        x = torch.bmm(x, trans)
        x = x.transpose(2, 1)
        x = F.relu(self.bn1(self.conv1(x)))

        if self.feature_transform:
            trans_feat = self.fstn(x)
            x = x.transpose(2, 1)
            x = torch.bmm(x, trans_feat)
            x = x.transpose(2, 1)
        else:
            trans_feat = None

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # (N, 1024, n_pts)
        x = torch.max(x, 2, keepdim=True)[0]  # max pooling -> (N, 1024) global feature
        x = x.view(-1, 1024)  # batch_size x 1024
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
            self.fstn = STNkd(k=64)

class PointNetfeatureGeoError(PointNetfeat):
    def __init__(self, global_feat=True, feature_transform=False, feature_encode=200):
        super(PointNetfeatureGeoError, self).__init__(global_feat, feature_transform)
        self.conv1 = torch.nn.Conv1d(3 + feature_encode, 64, 1)

    def forward(self, x):
        n_pts = x.size()[2]
        trans = self.stn(x)
@@ -149,52 +117,36 @@ def forward(self, x):

        pointfeat = x
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.bn3(self.conv3(x))  # (N, 1024, n_pts)
        x = torch.max(x, 2, keepdim=True)[0]  # max pooling -> (N, 1024) global feature
        x = x.view(-1, 1024)  # batch_size x 1024
        x = self.bn3(self.conv3(x))
        x = torch.max(x, 2, keepdim=True)[0]
        x = x.view(-1, 1024)
        if self.global_feat:
            return x, trans, trans_feat
        else:
            x = x.view(-1, 1024, 1).repeat(1, 1, n_pts)
            return torch.cat([x, pointfeat], 1), trans, trans_feat
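
# Illustrative sketch, not part of this commit: assembling the widened input for
# PointNetfeatureGeoError. The geo-error encoding is assumed to be a per-point
# feature vector of length feature_encode (200 by default); the name geo_err
# below is hypothetical.
#
#     B, N, feature_encode = 4, 2500, 200
#     xyz = torch.randn(B, 3, N)                    # raw coordinates, channels first
#     geo_err = torch.randn(B, feature_encode, N)   # hypothetical per-point encoding
#     x = torch.cat([xyz, geo_err], dim=1)          # (B, 3 + feature_encode, N), matches conv1's in_channels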


class PointNetCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        # the constructor defines the network structure
        super(PointNetCls, self).__init__()
        self.feature_transform = feature_transform
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)  # global feature of the input cloud: (N, 1024), N = batch size
        self.fc1 = nn.Linear(1024, 512)  # fully connected
        self.feat = PointNetfeat(global_feat=True, feature_transform=feature_transform)
        self.fc1 = nn.Linear(1024, 512)
        self.fc2 = nn.Linear(512, 256)
        self.fc3 = nn.Linear(256, k)  # logits over k classes, e.g. shape (4, 21)
        self.dropout = nn.Dropout(p=0.3)  # randomly zero activations during training
        self.fc3 = nn.Linear(256, k)
        self.dropout = nn.Dropout(p=0.3)
        self.bn1 = nn.BatchNorm1d(512)
        self.bn2 = nn.BatchNorm1d(256)
        self.relu = nn.ReLU()

    def forward(self, x):
        # forward pass
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat  # log-probabilities over the k classes (21 in this setup)
        return F.log_softmax(x, dim=1), trans, trans_feat


class PointNetClsGeoError(PointNetCls):
    def __init__(self, k=21, feature_transform=False):
        super(PointNetClsGeoError, self).__init__(k, feature_transform)
        self.feat = PointNetfeatureGeoError(global_feat=True, feature_transform=feature_transform)

    def forward(self, x):
        x, trans, trans_feat = self.feat(x)
        x = F.relu(self.bn1(self.fc1(x)))
        x = F.relu(self.bn2(self.dropout(self.fc2(x))))
        x = self.fc3(x)
        return F.log_softmax(x, dim=1), trans, trans_feat  # log-probabilities over the k classes (21 here)
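
# Illustrative sketch, not part of this commit: one training step with the
# geo-error classifier, assuming the inherited forward path accepts the widened
# (3 + feature_encode)-channel input (the spatial transformer's handling of the
# extra channels depends on code collapsed out of this diff).
#
#     model = PointNetClsGeoError(k=21).cuda()
#     points = torch.randn(8, 3 + 200, 2500).cuda()   # (batch, channels, n_pts)
#     target = torch.randint(0, 21, (8,)).cuda()
#     pred, trans, trans_feat = model(points)         # pred: (8, 21) log-probabilities
#     loss = F.nll_loss(pred, target)                 # NLL pairs with log_softmax
#     loss.backward()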


class PointNetDenseCls(nn.Module):
    def __init__(self, k=2, feature_transform=False):
        super(PointNetDenseCls, self).__init__()
1 change: 1 addition & 0 deletions train.sh
@@ -0,0 +1 @@
python train_classification.py --dataset /home/spyder/hazel/shapenet_dataset/shapenetcore_partanno_segmentation_benchmark_v0/ --nepoch=5 --dataset_type shapenet
8 changes: 5 additions & 3 deletions utils/train_classification.py → train_classification.py
@@ -6,8 +6,8 @@
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet.model import PointNetCls, feature_transform_regularizer
from model.dataset import ShapeNetDataset, ModelNetDataset
from model.model import PointNetCls, feature_transform_regularizer
import torch.nn.functional as F
from tqdm import tqdm

@@ -105,9 +105,11 @@
        target = target[:, 0]
        points = points.transpose(2, 1)
        points, target = points.cuda(), target.cuda()
        print("target shape is ", target.shape)
        optimizer.zero_grad()
        classifier = classifier.train()
        pred, trans, trans_feat = classifier(points)  # run the classifier's forward pass
        print("pred shape is ", pred.shape)
        loss = F.nll_loss(pred, target)
        if opt.feature_transform:
            loss += feature_transform_regularizer(trans_feat) * 0.001
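        # Note (editorial, not in this commit): feature_transform_regularizer in
        # model/model.py penalizes ||I - A A^T||, nudging the predicted 64x64
        # feature transform toward an orthogonal matrix.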
@@ -146,4 +148,4 @@
    total_correct += correct.item()
    total_testset += points.size()[0]

print("final accuracy {}".format(total_correct / float(total_testset)))
print("final accuracy {}".format(total_correct / float(total_testset)))
Empty file added utils/__init__.py
2 changes: 1 addition & 1 deletion utils/train_classification_geoerror.py
@@ -146,4 +146,4 @@
    total_correct += correct.item()
    total_testset += points.size()[0]

print("final accuracy {}".format(total_correct / float(total_testset)))
print("final accuracy {}".format(total_correct / float(total_testset)))
