Skip to content

Commit

Permalink
add modelnet script
Browse files Browse the repository at this point in the history
  • Loading branch information
fxia22 committed Mar 5, 2019
1 parent 3c7e2cd commit 5787367
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 29 deletions.
40 changes: 40 additions & 0 deletions misc/modelnet_id.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
airplane 0
bathtub 1
bed 2
bench 3
bookshelf 4
bottle 5
bowl 6
car 7
chair 8
cone 9
cup 10
curtain 11
desk 12
door 13
dresser 14
flower_pot 15
glass_box 16
guitar 17
keyboard 18
lamp 19
laptop 20
mantel 21
monitor 22
night_stand 23
person 24
piano 25
plant 26
radio 27
range_hood 28
sink 29
sofa 30
stairs 31
stool 32
table 33
tent 34
toilet 35
tv_stand 36
vase 37
wardrobe 38
xbox 39
File renamed without changes.
99 changes: 84 additions & 15 deletions pointnet/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
from tqdm import tqdm
import json
from plyfile import PlyData, PlyElement

def get_segmentation_classes(root):
catfile = os.path.join(root, 'synsetoffset2category.txt')
Expand All @@ -27,7 +28,7 @@ def get_segmentation_classes(root):
token = (os.path.splitext(os.path.basename(fn))[0])
meta[item].append((os.path.join(dir_point, token + '.pts'), os.path.join(dir_seg, token + '.seg')))

with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'num_seg_classes.txt'), 'w') as f:
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'w') as f:
for item in cat:
datapath = []
num_seg_classes = 0
Expand All @@ -42,6 +43,16 @@ def get_segmentation_classes(root):
print("category {} num segmentation classes {}".format(item, num_seg_classes))
f.write("{}\t{}\n".format(item, num_seg_classes))

def gen_modelnet_id(root, out_path=None):
    """Generate a class-name -> integer-id mapping for a ModelNet split.

    Reads ``<root>/train.txt``, where each line looks like
    ``<class>/<model_file>`` (the class name is the path component before the
    first '/'), and writes one ``<name>\\t<id>`` line per class.

    Args:
        root: Dataset root directory containing ``train.txt``.
        out_path: Where to write the mapping. Defaults to the original
            hard-coded location ``../misc/modelnet_id.txt`` relative to this
            file, so existing callers are unaffected.
    """
    classes = []
    with open(os.path.join(root, 'train.txt'), 'r') as f:
        for line in f:
            classes.append(line.strip().split('/')[0])
    # np.unique both de-duplicates and sorts, so ids are assigned in
    # alphabetical order of class name.
    classes = np.unique(classes)
    if out_path is None:
        out_path = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '../misc/modelnet_id.txt')
    with open(out_path, 'w') as f:
        for i, name in enumerate(classes):
            f.write('{}\t{}\n'.format(name, i))

class ShapeNetDataset(data.Dataset):
def __init__(self,
root,
Expand Down Expand Up @@ -88,7 +99,7 @@ def __init__(self,

self.classes = dict(zip(sorted(self.cat), range(len(self.cat))))
print(self.classes)
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), 'num_seg_classes.txt'), 'r') as f:
with open(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../misc/num_seg_classes.txt'), 'r') as f:
for line in f:
ls = line.strip().split()
self.seg_classes[ls[0]] = int(ls[1])
Expand Down Expand Up @@ -129,18 +140,76 @@ def __getitem__(self, index):
def __len__(self):
return len(self.datapath)

class ModelNetDataset(data.Dataset):
    """ModelNet point-cloud dataset backed by ``.ply`` mesh files.

    ``<root>/<split>.txt`` lists one model per line as ``<class>/<file>``;
    the class-name -> id mapping is read from ``../misc/modelnet_id.txt``
    (see ``gen_modelnet_id``). Each sample is ``self.npoints`` vertices drawn
    at random, centered, scaled to the unit sphere, and (optionally)
    augmented with a random y-axis rotation plus Gaussian jitter.
    """

    def __init__(self,
                 root,
                 npoints=2500,
                 split='train',
                 data_augmentation=True):
        self.npoints = npoints
        self.root = root
        self.split = split
        self.data_augmentation = data_augmentation
        # One relative file name per sample, e.g. 'chair/chair_0001.ply'.
        split_file = os.path.join(root, '{}.txt'.format(self.split))
        with open(split_file, 'r') as f:
            self.fns = [line.strip() for line in f]

        # Class name -> integer id, loaded from the generated id file.
        id_file = os.path.join(
            os.path.dirname(os.path.realpath(__file__)),
            '../misc/modelnet_id.txt')
        self.cat = {}
        with open(id_file, 'r') as f:
            for line in f:
                fields = line.strip().split()
                self.cat[fields[0]] = int(fields[1])

        print(self.cat)
        self.classes = list(self.cat.keys())

    def __getitem__(self, index):
        rel_path = self.fns[index]
        # The directory component of the relative path is the class name.
        label = self.cat[rel_path.split('/')[0]]
        with open(os.path.join(self.root, rel_path), 'rb') as f:
            plydata = PlyData.read(f)
        vertex = plydata['vertex']
        coords = np.vstack([vertex['x'], vertex['y'], vertex['z']]).T

        # Sample a fixed-size point set; replace=True allows meshes with
        # fewer than npoints vertices.
        picked = np.random.choice(len(coords), self.npoints, replace=True)
        points = coords[picked, :]

        # Center on the centroid, then scale into the unit sphere.
        points = points - np.expand_dims(np.mean(points, axis=0), 0)
        radius = np.max(np.sqrt(np.sum(points ** 2, axis=1)), 0)
        points = points / radius

        if self.data_augmentation:
            # Random rotation about the vertical (y) axis, applied to the
            # x/z columns, followed by small Gaussian jitter.
            angle = np.random.uniform(0, np.pi * 2)
            rot = np.array([[np.cos(angle), -np.sin(angle)],
                            [np.sin(angle), np.cos(angle)]])
            points[:, [0, 2]] = points[:, [0, 2]].dot(rot)
            points += np.random.normal(0, 0.02, size=points.shape)

        points = torch.from_numpy(points.astype(np.float32))
        label = torch.from_numpy(np.array([label]).astype(np.int64))
        return points, label

    def __len__(self):
        return len(self.fns)

if __name__ == '__main__':
    # Smoke test: python dataset.py <shapenet|modelnet> <data-root>
    dataset = sys.argv[1]
    datapath = sys.argv[2]

    if dataset == 'shapenet':
        d = ShapeNetDataset(root=datapath, class_choice=['Chair'])
        print(len(d))
        ps, seg = d[0]
        print(ps.size(), ps.type(), seg.size(), seg.type())

        d = ShapeNetDataset(root=datapath, classification=True)
        print(len(d))
        ps, cls = d[0]
        print(ps.size(), ps.type(), cls.size(), cls.type())
        # get_segmentation_classes(datapath)
    elif dataset == 'modelnet':
        # Regenerate the id file first so ModelNetDataset can load it.
        gen_modelnet_id(datapath)
        d = ModelNetDataset(root=datapath)
        print(len(d))
        print(d[0])

44 changes: 30 additions & 14 deletions utils/train_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
from pointnet.dataset import ShapeNetDataset
from pointnet.dataset import ShapeNetDataset, ModelNetDataset
from pointnet.model import PointNetCls
import torch.nn.functional as F
from tqdm import tqdm
Expand All @@ -24,6 +24,7 @@
parser.add_argument('--outf', type=str, default='cls', help='output folder')
parser.add_argument('--model', type=str, default='', help='model path')
parser.add_argument('--dataset', type=str, required=True, help="dataset path")
parser.add_argument('--dataset_type', type=str, default='shapenet', help="dataset type shapenet|modelnet40")

opt = parser.parse_args()
print(opt)
Expand All @@ -35,26 +36,41 @@
random.seed(opt.manualSeed)
torch.manual_seed(opt.manualSeed)

# Build train/test datasets for the requested dataset type.
if opt.dataset_type == 'shapenet':
    dataset = ShapeNetDataset(
        root=opt.dataset,
        classification=True,
        npoints=opt.num_points)

    test_dataset = ShapeNetDataset(
        root=opt.dataset,
        classification=True,
        split='test',
        npoints=opt.num_points)
elif opt.dataset_type == 'modelnet40':
    dataset = ModelNetDataset(
        root=opt.dataset,
        npoints=opt.num_points)

    test_dataset = ModelNetDataset(
        root=opt.dataset,
        split='test',
        npoints=opt.num_points)
else:
    # raise SystemExit instead of the site-provided exit() helper, which is
    # not guaranteed to be available in every interpreter configuration.
    raise SystemExit('wrong dataset type')


# Wrap both datasets in shuffling DataLoaders with the configured batch
# size and worker count.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

testdataloader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=opt.batchSize,
    shuffle=True,
    num_workers=int(opt.workers))

print(len(dataset), len(test_dataset))
num_classes = len(dataset.classes)
Expand Down

0 comments on commit 5787367

Please sign in to comment.