
Commit

init commit
FanLu97 committed Oct 9, 2020
1 parent 70fd8a4 commit 86ef3bf
Showing 26 changed files with 261,706 additions and 0 deletions.
136 changes: 136 additions & 0 deletions data/kittiloader.py
import torch.utils.data as data

import os
import glob
import numpy as np

import torch
import torchvision

def read_pc(filename):
    '''
    Read a point cloud into an ndarray from a TXT file.
    The TXT file contains a point cloud pre-processed with PCL (https://pointclouds.org/).
    Each line of the TXT file is: [x y z intensity nx ny nz curvature]
    '''
    scan = np.loadtxt(filename, dtype=np.float32)
    pc = scan[:, 0:3]
    sn = scan[:, 4:8]
    return pc, sn

def read_calib(filename):
    '''
    Read the camera-to-velodyne transformation matrix from a calibration file.
    Output:
        Tr: 4x4 transformation matrix
    '''
    with open(filename) as f:
        lines = f.readlines()
    Tr_line = lines[-1]
    Tr_words = Tr_line.split(' ')
    calib_list = Tr_words[1:]
    for i in range(len(calib_list)):
        calib_list[i] = float(calib_list[i])
    calib_array = np.array(calib_list).astype(np.float32)
    Tr = np.zeros((4, 4), dtype=np.float32)
    Tr[0, :] = calib_array[0:4]
    Tr[1, :] = calib_array[4:8]
    Tr[2, :] = calib_array[8:12]
    Tr[3, 3] = 1.0
    return Tr

def read_pose(filename, Tr):
    '''
    Read vehicle poses from the pose file and transform them into the velodyne frame.
    Input:
        Tr: camera-to-velodyne transformation matrix
    '''
    Tr_inv = np.linalg.inv(Tr)
    Tlist = []
    poses = np.loadtxt(filename, dtype=np.float32)
    for i in range(poses.shape[0]):
        one_pose = poses[i, :]
        Tcam = np.zeros((4, 4), dtype=np.float32)
        Tcam[0, :] = one_pose[0:4]
        Tcam[1, :] = one_pose[4:8]
        Tcam[2, :] = one_pose[8:12]
        Tcam[3, 3] = 1.0
        Tvelo = Tr_inv.dot(Tcam).dot(Tr)
        Tlist.append(Tvelo)
    return Tlist

def points_sample(points, sn, npoints):
    '''
    Randomly sample npoints points (with replacement).
    '''
    size = points.shape[0]
    new_points = np.zeros((npoints, 3))
    new_sn = np.zeros((npoints, 4))
    for i in range(npoints):
        index = np.random.randint(size)
        new_points[i, :] = points[index, :]
        new_sn[i, :] = sn[index, :]
    return new_points, new_sn

def get_pointcloud(filename, npoints):
    '''
    Read a point cloud from file and randomly sample npoints points.
    '''
    pc, sn = read_pc(filename)
    pc, sn = points_sample(pc, sn, npoints)
    pc = torch.from_numpy(pc.astype(np.float32))
    sn = torch.from_numpy(sn.astype(np.float32))
    return pc, sn

class KittiDataset(data.Dataset):
    def __init__(self, root, seq, npoints):
        super(KittiDataset, self).__init__()
        self.velodyne_path = os.path.join(root, 'sequences', seq, 'velodyne_txt')
        self.velodyne_names = glob.glob(os.path.join(self.velodyne_path, '*.txt'))
        self.velodyne_names = sorted(self.velodyne_names)
        self.poses_path = os.path.join(root, 'poses', seq+'.txt')
        self.calib_path = os.path.join(root, 'sequences', seq, 'calib.txt')
        self.npoints = npoints
        Tr = read_calib(self.calib_path)
        self.Tlist = read_pose(self.poses_path, Tr)
        self.dataset = self.make_dataset()

    def make_dataset(self):
        # Pair every frame i with frame i+10 (or i-10 near the end of the sequence)
        max_ind = len(self.velodyne_names)
        dataset = []
        bias = 10
        for i in range(max_ind):
            src_idx = i
            if i + bias >= max_ind:
                dst_idx = i - bias
            else:
                dst_idx = i + bias
            dataset.append([src_idx, dst_idx])
        return dataset

    def __getitem__(self, index):
        src_idx, dst_idx = self.dataset[index]
        src_file_name = self.velodyne_names[src_idx]
        dst_file_name = self.velodyne_names[dst_idx]
        src_pc, src_sn = get_pointcloud(src_file_name, self.npoints)
        dst_pc, dst_sn = get_pointcloud(dst_file_name, self.npoints)
        src_T = self.Tlist[src_idx]
        dst_T = self.Tlist[dst_idx]
        # Ground-truth relative pose from the source frame to the destination frame
        relaT = np.linalg.inv(dst_T).dot(src_T)
        relaT = torch.from_numpy(relaT.astype(np.float32))

        return src_pc, src_sn, dst_pc, dst_sn, relaT

    def __len__(self):
        return len(self.dataset)

if __name__ == '__main__':
    root = ''
    seq = '00'
    npoints = 16384
    trainset = KittiDataset(root, seq, npoints)
    print(len(trainset))
    src_pc, src_sn, dst_pc, dst_sn, relaT = trainset[0]
    print(src_pc.shape, src_sn.shape, dst_pc.shape, dst_sn.shape, relaT.shape)
    print(relaT)
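
A minimal sketch of how this dataset might be wrapped in a PyTorch DataLoader for training; the dataset root path, batch size, and worker count below are placeholders, not part of this commit:

# Hypothetical usage sketch; root path, batch size and worker count are assumed placeholders.
from torch.utils.data import DataLoader
from data.kittiloader import KittiDataset

trainset = KittiDataset(root='/path/to/kitti/odometry', seq='00', npoints=16384)
trainloader = DataLoader(trainset, batch_size=8, shuffle=True, num_workers=4)
for src_pc, src_sn, dst_pc, dst_sn, relaT in trainloader:
    # src_pc/dst_pc: (B, 16384, 3), src_sn/dst_sn: (B, 16384, 4), relaT: (B, 4, 4)
    break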
66 changes: 66 additions & 0 deletions demo.py
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import argparse
import os

from models.models import Detector, Descriptor, RSKDD
from data.kittiloader import get_pointcloud

def parse_args():
    parser = argparse.ArgumentParser('RSKDD-Net')
    parser.add_argument('--gpu', type=str, default='0')
    parser.add_argument('--model_path', type=str, default='./pretrain/rskdd.pth')
    parser.add_argument('--save_dir', type=str, default='./demo/results')
    parser.add_argument('--nsample', type=int, default=512)
    parser.add_argument('--npoints', type=int, default=16384)
    parser.add_argument('--k', type=int, default=128)
    parser.add_argument('--desc_dim', type=int, default=128)
    parser.add_argument('--dilation_ratio', type=float, default=2.0)
    parser.add_argument('--data_dir', type=str, default='./demo/pc')

    return parser.parse_args()

def demo(args):
    os.environ["CUDA_VISIBLE_DEVICES"] = args.gpu
    model = RSKDD(args)
    model = model.cuda()
    model.load_state_dict(torch.load(args.model_path))
    model.eval()

    file_names = os.listdir(args.data_dir)

    kp_save_dir = os.path.join(args.save_dir, "keypoints")
    desc_save_dir = os.path.join(args.save_dir, "desc")
    if not os.path.exists(kp_save_dir):
        os.makedirs(kp_save_dir)
    if not os.path.exists(desc_save_dir):
        os.makedirs(desc_save_dir)

    for file_name in file_names:
        file_path = os.path.join(args.data_dir, file_name)
        kp_save_path = os.path.join(kp_save_dir, file_name)
        desc_save_path = os.path.join(desc_save_dir, file_name)

        # Concatenate xyz with the surface-normal/curvature features: (1, npoints, 7)
        pc, sn = get_pointcloud(file_path, args.npoints)
        feature = torch.cat((pc, sn), dim=-1)
        feature = feature.unsqueeze(0)
        feature = feature.cuda()

        kp, sigmas, desc = model(feature)

        # Save keypoints as rows of [x y z sigma] and descriptors as one row per keypoint
        kp_sigmas = torch.cat((kp, sigmas.unsqueeze(1)), dim=1)
        kp_sigmas = kp_sigmas.squeeze().cpu().detach().numpy().transpose()
        desc = desc.squeeze().cpu().detach().numpy().transpose()

        print(file_name, "processed")

        np.savetxt(kp_save_path, kp_sigmas, fmt='%.04f')
        np.savetxt(desc_save_path, desc, fmt='%.04f')

    print("Done")

if __name__ == '__main__':
    args = parse_args()
    demo(args)
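
A short sketch of how the saved outputs could be read back in Python; the file name ford_01.txt is taken from the MATLAB demo below and serves only as an example:

# Hypothetical read-back of the demo outputs; the file name is only an example.
import numpy as np

kp_sigmas = np.loadtxt('./demo/results/keypoints/ford_01.txt')  # one row per keypoint: [x y z sigma]
desc = np.loadtxt('./demo/results/desc/ford_01.txt')            # one row per keypoint: descriptor of length desc_dim
print(kp_sigmas.shape, desc.shape)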
46 changes: 46 additions & 0 deletions demo/demo_reg/demo_reg.m
clear;
clc;

kp_dir = '../results/keypoints';
desc_dir = '../results/desc';
pc_dir = '../pc';

nkp = 256;

dataset = 'ford'; % kitti or ford

src_name = strcat(dataset, '_01.txt');
dst_name = strcat(dataset, '_02.txt');

src_kp_path = fullfile(kp_dir, src_name);
src_desc_path = fullfile(desc_dir, src_name);
src_pc_path = fullfile(pc_dir, src_name);

dst_kp_path = fullfile(kp_dir, dst_name);
dst_desc_path = fullfile(desc_dir, dst_name);
dst_pc_path = fullfile(pc_dir, dst_name);

src_pc = load(src_pc_path);
src_kp_sigmas = load(src_kp_path);
src_kp = src_kp_sigmas(:,1:3);
src_sigmas = src_kp_sigmas(:,4);
src_desc = load(src_desc_path);
[~, src_idx] = sort(src_sigmas); % sort by sigma and keep the nkp keypoints with the smallest sigma
src_kp = src_kp(src_idx,:);
src_desc = src_desc(src_idx,:);
src_kp = src_kp(1:nkp,:);
src_desc = src_desc(1:nkp,:);

dst_pc = load(dst_pc_path);
dst_kp_sigmas = load(dst_kp_path);
dst_kp = dst_kp_sigmas(:,1:3);
dst_sigmas = dst_kp_sigmas(:,4);
dst_desc = load(dst_desc_path);
[~, dst_idx] = sort(dst_sigmas); % sort by sigma and keep the nkp keypoints with the smallest sigma
dst_kp = dst_kp(dst_idx,:);
dst_desc = dst_desc(dst_idx,:);
dst_kp = dst_kp(1:nkp,:);
dst_desc = dst_desc(1:nkp,:);

[R, t, src_inliers, dst_inliers] = estimateRt(src_kp, src_desc, dst_kp, dst_desc);
plot_match(src_pc, src_kp, src_inliers, dst_pc, dst_kp, dst_inliers);
50 changes: 50 additions & 0 deletions demo/demo_reg/estimateRt.m
function [R, t, src_inliers, dst_inliers] = estimateRt(src_kp, src_desc, dst_kp, dst_desc)
% src_kp: source keypoints, Nx3
% src_desc: source descriptors, NxD
% dst_kp: target keypoints, Nx3
% dst_desc: target descriptors, NxD
nsample = 3;
max_iter = 10000;
min_inlier_size = 100;

nsize = length(src_kp);

% Nearest-neighbour matching in descriptor space
[src_corres_idx, src_corres_dists] = knnsearch(dst_desc, src_desc);
src_corres_kp = dst_kp(src_corres_idx,:);

iter = 0;
max_inlier_size = 0;

best_inlier_idx = 0;
dist_t = 1.0; % inlier threshold

N = 1;
p = 0.99;

% RANSAC loop with an adaptive number of iterations
while iter < max_iter && N > iter
    rand_idx = randi(nsize, nsample, 1);
    src_sample = src_kp(rand_idx,:);
    dst_sample = src_corres_kp(rand_idx,:);
    [R1, t1] = estimateRtSVD(src_sample, dst_sample);

    src_trans = (R1*src_kp' + t1)';

    resi = src_trans - src_corres_kp;
    resi = vecnorm(resi, 2, 2);
    inlier_idx = find(resi < dist_t);
    inlier_size = length(inlier_idx);
    if inlier_size > max_inlier_size
        max_inlier_size = inlier_size; % keep the largest inlier set found so far
        inlier_ratio = inlier_size/nsize;
        pNoOutliers = 1 - inlier_ratio^nsample;
        pNoOutliers = max(eps, pNoOutliers);
        pNoOutliers = min(1-eps, pNoOutliers);
        N = log(1-p)/log(pNoOutliers);
        N = max(N, 10);
        best_inlier_idx = inlier_idx;
    end
    iter = iter + 1;
end

% Re-estimate the transformation from all inliers of the best hypothesis
src_inliers = src_kp(best_inlier_idx,:);
dst_inliers = src_corres_kp(best_inlier_idx,:);
[R, t] = estimateRtSVD(src_inliers, dst_inliers);
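
For reference, the adaptive iteration bound updated inside the loop above is the standard RANSAC criterion: with confidence p = 0.99, sample size s = nsample = 3, and current inlier ratio w = inlier_size / nsize,

    N = log(1 - p) / log(1 - w^s),

clamped to at least 10 iterations, with 1 - w^s kept inside [eps, 1 - eps].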
15 changes: 15 additions & 0 deletions demo/demo_reg/estimateRtSVD.m
function [R, t] = estimateRtSVD(src_sample, dst_sample)
% Closed-form rigid transform (Kabsch/SVD) aligning src_sample to dst_sample
reflect = [1,0,0; 0,1,0; 0,0,-1];
src_sample_mean = mean(src_sample, 1);
src_sample_decenter = src_sample - src_sample_mean;
dst_sample_mean = mean(dst_sample, 1);
dst_sample_decenter = dst_sample - dst_sample_mean;
W = src_sample_decenter' * dst_sample_decenter;
[u, s, v] = svd(W);
R = v*u';
detR = det(R);
if detR < 0
    % Correct an improper rotation (reflection)
    v = v*reflect;
    R = v*u';
end
t = -R*src_sample_mean' + dst_sample_mean';
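
In equation form (a restatement of the code above, with p_i the rows of src_sample, q_i the rows of dst_sample, and p̄, q̄ their means):

    W = sum_i (p_i - p̄)(q_i - q̄)^T = U S V^T
    R = V U^T    (if det(R) < 0, negate the last column of V first)
    t = q̄ - R p̄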
39 changes: 39 additions & 0 deletions demo/demo_reg/plot_match.m
function [] = plot_match(src_pc, src_keypoints, src_inliers, dst_pc, dst_keypoints, dst_inliers)
% Visualize the two point clouds, their keypoints, and the inlier correspondences.
s1 = 2.0;   % point size for the clouds
s2 = 20.0;  % point size for the keypoints
gap = 1.0;
minz_src = min(src_pc(:,3));
maxz_src = max(src_pc(:,3));
minz_dst = min(dst_pc(:,3));
maxz_dst = max(dst_pc(:,3));
dist = abs(minz_dst - maxz_src);
% Shift the target cloud upwards so the two clouds do not overlap in the plot
dst_pc(:,3) = dst_pc(:,3) + gap + dist;
dst_keypoints(:,3) = dst_keypoints(:,3) + gap + dist;
dst_inliers(:,3) = dst_inliers(:,3) + gap + dist;

scatter3(src_pc(:,1), src_pc(:,2), src_pc(:,3), s1, src_pc(:,4));
hold on;
scatter3(dst_pc(:,1), dst_pc(:,2), dst_pc(:,3), s1, dst_pc(:,4));
hold on;
scatter3(src_keypoints(:,1), src_keypoints(:,2), src_keypoints(:,3), s2, 'r', 'filled');
hold on;
scatter3(dst_keypoints(:,1), dst_keypoints(:,2), dst_keypoints(:,3), s2, 'r', 'filled');
hold on;

% Draw a line between every inlier correspondence
inlier_size = size(src_inliers);
inlier_size = inlier_size(1);
point_1 = src_inliers(1,:);
point_2 = dst_inliers(1,:);
line = cat(1, point_1, point_2);
plot3(line(:,1), line(:,2), line(:,3), 'r*');
hold on;
for i = 1:inlier_size
    point_1 = src_inliers(i,:);
    point_2 = dst_inliers(i,:);
    line = cat(1, point_1, point_2);
    plot3(line(:,1), line(:,2), line(:,3), 'r');
    hold on;
end
grid off;
axis off;
hold off;