Skip to content

Commit

Permalink
Delete unused code and re-format.
Browse files Browse the repository at this point in the history
  • Loading branch information
delldu committed Dec 5, 2018
1 parent f6eeca1 commit 05f2da9
Show file tree
Hide file tree
Showing 7 changed files with 216 additions and 215 deletions.
22 changes: 9 additions & 13 deletions datasets.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,17 @@
from __future__ import print_function
import torch.utils.data as data
from PIL import Image
import os
import os.path
import errno
import torch
import json
import codecs
import numpy as np
import sys
import torchvision.transforms as transforms
import argparse
import json


class PartDataset(data.Dataset):
def __init__(self, root, npoints = 2500, classification = False, class_choice = None, train = True):
def __init__(self,
root,
npoints=2500,
classification=False,
class_choice=None,
train=True):
self.npoints = npoints
self.root = root
self.catfile = os.path.join(self.root, 'synsetoffset2category.txt')
Expand All @@ -28,8 +24,8 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
ls = line.strip().split()
self.cat[ls[0]] = ls[1]
#print(self.cat)
if not class_choice is None:
self.cat = {k:v for k,v in self.cat.items() if k in class_choice}
if not class_choice is None:
self.cat = {k: v for k, v in self.cat.items() if k in class_choice}

self.meta = {}
for item in self.cat:
Expand Down Expand Up @@ -59,7 +55,7 @@ def __init__(self, root, npoints = 2500, classification = False, class_choice =
print(self.classes)
self.num_seg_classes = 0
if not self.classification:
for i in range(len(self.datapath)//50):
for i in range(len(self.datapath) // 50):
l = len(np.unique(np.loadtxt(self.datapath[i][-1]).astype(np.uint8)))
if l > self.num_seg_classes:
self.num_seg_classes = l
Expand Down
18 changes: 5 additions & 13 deletions pointnet.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,10 @@
from __future__ import print_function
import argparse
import os
import random
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import pdb
import torch.nn.functional as F


Expand Down Expand Up @@ -67,13 +57,14 @@ def __init__(self, global_feat = True):
self.bn2 = nn.BatchNorm1d(128)
self.bn3 = nn.BatchNorm1d(1024)
self.global_feat = global_feat


def forward(self, x):
batchsize = x.size()[0]
n_pts = x.size()[2]
trans = self.stn(x)
x = x.transpose(2,1)
x = x.transpose(2, 1)
x = torch.bmm(x, trans)
x = x.transpose(2,1)
x = x.transpose(2, 1)
x = F.relu(self.bn1(self.conv1(x)))
pointfeat = x
x = F.relu(self.bn2(self.conv2(x)))
Expand All @@ -96,6 +87,7 @@ def __init__(self, k = 2):
self.bn1 = nn.BatchNorm1d(512)
self.bn2 = nn.BatchNorm1d(256)
self.relu = nn.ReLU()

def forward(self, x):
x, trans = self.feat(x)
x = F.relu(self.bn1(self.fc1(x)))
Expand Down
194 changes: 100 additions & 94 deletions show3d_balls.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,142 +2,148 @@
import ctypes as ct
import cv2
import sys
showsz=800
mousex,mousey=0.5,0.5
zoom=1.0
changed=True
showsz = 800
mousex, mousey = 0.5, 0.5
zoom = 1.0
changed = True

def onmouse(*args):
global mousex,mousey,changed
y=args[1]
x=args[2]
mousex=x/float(showsz)
mousey=y/float(showsz)
changed=True
global mousex, mousey, changed
y = args[1]
x = args[2]
mousex = x / float(showsz)
mousey = y / float(showsz)
changed = True

cv2.namedWindow('show3d')
cv2.moveWindow('show3d',0,0)
cv2.setMouseCallback('show3d',onmouse)
cv2.moveWindow('show3d', 0, 0)
cv2.setMouseCallback('show3d', onmouse)

dll=np.ctypeslib.load_library('render_balls_so','.')
dll = np.ctypeslib.load_library('render_balls_so', '.')

def showpoints(xyz,c_gt=None, c_pred = None ,waittime=0,showrot=False,magnifyBlue=0,freezerot=False,background=(0,0,0),normalizecolor=True,ballradius=10):
global showsz,mousex,mousey,zoom,changed
def showpoints(xyz,c_gt=None, c_pred = None, waittime=0,
showrot=False, magnifyBlue=0, freezerot=False, background=(0,0,0),
normalizecolor=True, ballradius=10):
global showsz, mousex, mousey, zoom, changed
xyz=xyz-xyz.mean(axis=0)
radius=((xyz**2).sum(axis=-1)**0.5).max()
xyz/=(radius*2.2)/showsz
if c_gt is None:
c0=np.zeros((len(xyz),),dtype='float32')+255
c1=np.zeros((len(xyz),),dtype='float32')+255
c2=np.zeros((len(xyz),),dtype='float32')+255
c0 = np.zeros((len(xyz), ), dtype='float32') + 255
c1 = np.zeros((len(xyz), ), dtype='float32') + 255
c2 = np.zeros((len(xyz), ), dtype='float32') + 255
else:
c0=c_gt[:,0]
c1=c_gt[:,1]
c2=c_gt[:,2]
c0 = c_gt[:, 0]
c1 = c_gt[:, 1]
c2 = c_gt[:, 2]


if normalizecolor:
c0/=(c0.max()+1e-14)/255.0
c1/=(c1.max()+1e-14)/255.0
c2/=(c2.max()+1e-14)/255.0
c0 /= (c0.max() + 1e-14) / 255.0
c1 /= (c1.max() + 1e-14) / 255.0
c2 /= (c2.max() + 1e-14) / 255.0


c0=np.require(c0,'float32','C')
c1=np.require(c1,'float32','C')
c2=np.require(c2,'float32','C')
c0 = np.require(c0, 'float32', 'C')
c1 = np.require(c1, 'float32', 'C')
c2 = np.require(c2, 'float32', 'C')

show=np.zeros((showsz,showsz,3),dtype='uint8')
show = np.zeros((showsz, showsz, 3), dtype='uint8')
def render():
rotmat=np.eye(3)
if not freezerot:
xangle=(mousey-0.5)*np.pi*1.2
else:
xangle=0
rotmat=rotmat.dot(np.array([
[1.0,0.0,0.0],
[0.0,np.cos(xangle),-np.sin(xangle)],
[0.0,np.sin(xangle),np.cos(xangle)],
rotmat = rotmat.dot(
np.array([
[1.0, 0.0, 0.0],
[0.0, np.cos(xangle), -np.sin(xangle)],
[0.0, np.sin(xangle), np.cos(xangle)],
]))
if not freezerot:
yangle=(mousex-0.5)*np.pi*1.2
yangle = (mousex - 0.5) * np.pi * 1.2
else:
yangle=0
rotmat=rotmat.dot(np.array([
[np.cos(yangle),0.0,-np.sin(yangle)],
[0.0,1.0,0.0],
[np.sin(yangle),0.0,np.cos(yangle)],
yangle = 0
rotmat = rotmat.dot(
np.array([
[np.cos(yangle), 0.0, -np.sin(yangle)],
[0.0, 1.0, 0.0],
[np.sin(yangle), 0.0, np.cos(yangle)],
]))
rotmat*=zoom
nxyz=xyz.dot(rotmat)+[showsz/2,showsz/2,0]
rotmat *= zoom
nxyz = xyz.dot(rotmat) + [showsz / 2, showsz / 2, 0]

ixyz=nxyz.astype('int32')
show[:]=background
ixyz = nxyz.astype('int32')
show[:] = background
dll.render_ball(
ct.c_int(show.shape[0]),
ct.c_int(show.shape[1]),
show.ctypes.data_as(ct.c_void_p),
ct.c_int(ixyz.shape[0]),
ixyz.ctypes.data_as(ct.c_void_p),
c0.ctypes.data_as(ct.c_void_p),
c1.ctypes.data_as(ct.c_void_p),
c2.ctypes.data_as(ct.c_void_p),
ct.c_int(ballradius)
)
ct.c_int(show.shape[0]), ct.c_int(show.shape[1]),
show.ctypes.data_as(ct.c_void_p), ct.c_int(ixyz.shape[0]),
ixyz.ctypes.data_as(ct.c_void_p), c0.ctypes.data_as(ct.c_void_p),
c1.ctypes.data_as(ct.c_void_p), c2.ctypes.data_as(ct.c_void_p),
ct.c_int(ballradius))

if magnifyBlue>0:
show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=0))
if magnifyBlue>=2:
show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=0))
show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],1,axis=1))
if magnifyBlue>=2:
show[:,:,0]=np.maximum(show[:,:,0],np.roll(show[:,:,0],-1,axis=1))
if magnifyBlue > 0:
show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
show[:, :, 0], 1, axis=0))
if magnifyBlue >= 2:
show[:, :, 0] = np.maximum(show[:, :, 0],
np.roll(show[:, :, 0], -1, axis=0))
show[:, :, 0] = np.maximum(show[:, :, 0], np.roll(
show[:, :, 0], 1, axis=1))
if magnifyBlue >= 2:
show[:, :, 0] = np.maximum(show[:, :, 0],
np.roll(show[:, :, 0], -1, axis=1))
if showrot:
cv2.putText(show,'xangle %d'%(int(xangle/np.pi*180)),(30,showsz-30),0,0.5,cv2.cv.CV_RGB(255,0,0))
cv2.putText(show,'yangle %d'%(int(yangle/np.pi*180)),(30,showsz-50),0,0.5,cv2.cv.CV_RGB(255,0,0))
cv2.putText(show,'zoom %d%%'%(int(zoom*100)),(30,showsz-70),0,0.5,cv2.cv.CV_RGB(255,0,0))
changed=True
cv2.putText(show, 'xangle %d' % (int(xangle / np.pi * 180)),
(30, showsz - 30), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0))
cv2.putText(show, 'yangle %d' % (int(yangle / np.pi * 180)),
(30, showsz - 50), 0, 0.5, cv2.cv.CV_RGB(255, 0, 0))
cv2.putText(show, 'zoom %d%%' % (int(zoom * 100)), (30, showsz - 70), 0,
0.5, cv2.cv.CV_RGB(255, 0, 0))
changed = True
while True:
if changed:
render()
changed=False
cv2.imshow('show3d',show)
if waittime==0:
cmd=cv2.waitKey(10)%256
changed = False
cv2.imshow('show3d', show)
if waittime == 0:
cmd = cv2.waitKey(10) % 256
else:
cmd=cv2.waitKey(waittime)%256
if cmd==ord('q'):
cmd = cv2.waitKey(waittime) % 256
if cmd == ord('q'):
break
elif cmd==ord('Q'):
elif cmd == ord('Q'):
sys.exit(0)

if cmd==ord('t') or cmd == ord('p'):
if cmd == ord('t') or cmd == ord('p'):
if cmd == ord('t'):
if c_gt is None:
c0=np.zeros((len(xyz),),dtype='float32')+255
c1=np.zeros((len(xyz),),dtype='float32')+255
c2=np.zeros((len(xyz),),dtype='float32')+255
c0 = np.zeros((len(xyz), ), dtype='float32') + 255
c1 = np.zeros((len(xyz), ), dtype='float32') + 255
c2 = np.zeros((len(xyz), ), dtype='float32') + 255
else:
c0=c_gt[:,0]
c1=c_gt[:,1]
c2=c_gt[:,2]
c0 = c_gt[:, 0]
c1 = c_gt[:, 1]
c2 = c_gt[:, 2]
else:
if c_pred is None:
c0=np.zeros((len(xyz),),dtype='float32')+255
c1=np.zeros((len(xyz),),dtype='float32')+255
c2=np.zeros((len(xyz),),dtype='float32')+255
c0 = np.zeros((len(xyz), ), dtype='float32') + 255
c1 = np.zeros((len(xyz), ), dtype='float32') + 255
c2 = np.zeros((len(xyz), ), dtype='float32') + 255
else:
c0=c_pred[:,0]
c1=c_pred[:,1]
c2=c_pred[:,2]
c0 = c_pred[:, 0]
c1 = c_pred[:, 1]
c2 = c_pred[:, 2]
if normalizecolor:
c0/=(c0.max()+1e-14)/255.0
c1/=(c1.max()+1e-14)/255.0
c2/=(c2.max()+1e-14)/255.0
c0=np.require(c0,'float32','C')
c1=np.require(c1,'float32','C')
c2=np.require(c2,'float32','C')
c0 /= (c0.max() + 1e-14) / 255.0
c1 /= (c1.max() + 1e-14) / 255.0
c2 /= (c2.max() + 1e-14) / 255.0
c0 = np.require(c0, 'float32', 'C')
c1 = np.require(c1, 'float32', 'C')
c2 = np.require(c2, 'float32', 'C')
changed = True



if cmd==ord('n'):
zoom*=1.1
changed=True
Expand All @@ -152,7 +158,7 @@ def render():
if waittime!=0:
break
return cmd
if __name__=='__main__':
np.random.seed(100)
showpoints(np.random.randn(2500,3))

if __name__ == '__main__':
np.random.seed(100)
showpoints(np.random.randn(2500, 3))
21 changes: 8 additions & 13 deletions show_cls.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
from __future__ import print_function
import argparse
import os
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.backends.cudnn as cudnn
import torch.optim as optim
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms
import torchvision.utils as vutils
from torch.autograd import Variable
from datasets import PartDataset
from pointnet import PointNetCls
import torch.nn.functional as F
import matplotlib.pyplot as plt


#showpoints(np.random.randn(2500,3), c1 = np.random.uniform(0,1,size = (2500)))
Expand All @@ -28,11 +18,16 @@


opt = parser.parse_args()
print (opt)
print(opt)

test_dataset = PartDataset(root = 'shapenetcore_partanno_segmentation_benchmark_v0' , train = False, classification = True, npoints = opt.num_points)
test_dataset = PartDataset(
root='shapenetcore_partanno_segmentation_benchmark_v0',
train=False,
classification=True,
npoints=opt.num_points)

testdataloader = torch.utils.data.DataLoader(test_dataset, batch_size=32, shuffle = True)
testdataloader = torch.utils.data.DataLoader(
test_dataset, batch_size=32, shuffle=True)

classifier = PointNetCls(k=len(test_dataset.classes))
classifier.cuda()
Expand Down
Loading

0 comments on commit 05f2da9

Please sign in to comment.