-
Notifications
You must be signed in to change notification settings - Fork 353
/
feature_superpoint.py
128 lines (107 loc) · 4.83 KB
/
feature_superpoint.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
* This file is part of PYSLAM
*
* Copyright (C) 2016-present Luigi Freda <luigi dot freda at gmail dot com>
*
* PYSLAM is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PYSLAM is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with PYSLAM. If not, see <http://www.gnu.org/licenses/>.
"""
import sys
import os
import cv2
import torch
import time
import platform
import config
config.cfg.set_lib('superpoint')
from demo_superpoint import SuperPointFrontend
from threading import RLock
from utils_sys import Printer, is_opencv_version_greater_equal
# When True, SuperPointFeature2D.detectAndCompute() prints the number of detected
# features and the frame resolution on every call.
kVerbose = True
class SuperPointOptions:
    """Configuration forwarded to ``SuperPointFrontend`` by ``SuperPointFeature2D``.

    Args:
        do_cuda: request GPU inference. CUDA is actually enabled only if
            ``torch.cuda.is_available()`` is also true.

    Attributes:
        weights_path: path of the pre-trained SuperPoint weights file.
        nms_dist: forwarded as ``nms_dist`` (presumably the non-maximum
            suppression radius in pixels -- confirm in demo_superpoint).
        conf_thresh: forwarded as ``conf_thresh`` (detector confidence threshold).
        nn_thresh: forwarded as ``nn_thresh`` (presumably the descriptor
            nearest-neighbour matching threshold -- confirm in demo_superpoint).
        cuda: whether the network should run on the GPU.
    """

    def __init__(self, do_cuda=True):
        # default options from demo_superpoints
        self.weights_path = os.path.join(config.cfg.root_folder, 'thirdparty', 'superpoint', 'superpoint_v1.pth')
        print(f'SuperPoint weights: {self.weights_path}')
        self.nms_dist = 4
        self.conf_thresh = 0.015
        self.nn_thresh = 0.7
        use_cuda = torch.cuda.is_available() and do_cuda
        device = torch.device('cuda' if use_cuda else 'cpu')
        print('SuperPoint using ', device)
        self.cuda = use_cuda

    def __repr__(self):
        # Gives SuperPointFeature2D's print(self.opts) something useful to show
        # instead of the default "<SuperPointOptions object at 0x...>".
        return (f'{type(self).__name__}(weights_path={self.weights_path!r}, '
                f'nms_dist={self.nms_dist}, conf_thresh={self.conf_thresh}, '
                f'nn_thresh={self.nn_thresh}, cuda={self.cuda})')
# convert matrix of pts into list of keypoints
# N.B.: SuperPoint returns pts as a 3xN numpy array with corners [x_i, y_i, confidence_i]^T;
#       this function expects the transpose (Nx3), i.e. one [x, y, confidence] row per point.
def convert_superpts_to_keypoints(pts, size=1):
    """Convert SuperPoint detections into a list of ``cv2.KeyPoint``.

    Args:
        pts: iterable of rows ``[x, y, confidence]`` (e.g. an Nx3 array), or None.
        size: keypoint diameter assigned to every generated ``cv2.KeyPoint``.

    Returns:
        list of ``cv2.KeyPoint`` whose ``response`` is the row's confidence;
        an empty list when ``pts`` is None.
    """
    if pts is None:
        return []
    # The keyword names accepted by the cv2.KeyPoint constructor differ before and
    # after OpenCV 4.5.3 (``size``/``response`` vs ``_size``/``_response``).
    # Coordinates/confidences are cast to Python float so numpy scalar types are
    # always accepted by the binding.
    if is_opencv_version_greater_equal(4, 5, 3):
        return [cv2.KeyPoint(float(p[0]), float(p[1]), size=size, response=float(p[2])) for p in pts]
    return [cv2.KeyPoint(float(p[0]), float(p[1]), _size=size, _response=float(p[2])) for p in pts]
def transpose_des(des):
    """Return ``des.T``, passing None through unchanged.

    Used on the descriptor matrix produced by SuperPointFrontend.run (presumably
    one descriptor per column, transposed here to one per row -- confirm).
    """
    return None if des is None else des.T
# interface for pySLAM
class SuperPointFeature2D:
    """SuperPoint keypoint detector/descriptor exposed through a cv2.Feature2D-like API.

    The results of the most recent extraction are cached on the instance
    (``frame``, ``pts``, ``kps``, ``des``, ``heatmap``), so ``compute`` can reuse
    them when called with the very same frame object. Every public method runs
    under ``self.lock``, a re-entrant lock, so the methods may call one another.
    """

    def __init__(self, do_cuda=True):
        """Load the pre-trained SuperPoint network.

        Args:
            do_cuda: request GPU inference; forced to False on macOS ('Darwin').
        """
        if platform.system() == 'Darwin':
            do_cuda = False
        self.lock = RLock()
        self.opts = SuperPointOptions(do_cuda)
        print(self.opts)
        print('SuperPointFeature2D')
        print('==> Loading pre-trained network.')
        # This class runs the SuperPoint network and processes its outputs.
        self.fe = SuperPointFrontend(weights_path=self.opts.weights_path,
                                     nms_dist=self.opts.nms_dist,
                                     conf_thresh=self.opts.conf_thresh,
                                     nn_thresh=self.opts.nn_thresh,
                                     cuda=self.opts.cuda)
        print('==> Successfully loaded pre-trained network.')
        self.pts = []
        self.kps = []
        self.des = []
        self.heatmap = []
        self.frame = None       # frame object the cached results above were computed from
        self.frameFloat = None  # float32 version of self.frame scaled by 1/255
        self.keypoint_size = 20  # just a representative size for visualization and in order to convert extracted points to cv2.KeyPoint

    # compute both keypoints and descriptors
    def detectAndCompute(self, frame, mask=None):  # mask is a fake input
        """Run SuperPoint on ``frame`` and return ``(keypoints, descriptors)``.

        Args:
            frame: input image (numpy array). It is cast to float32 and scaled by
                1/255 before inference (presumably an 8-bit grayscale image --
                confirm against SuperPointFrontend.run).
            mask: ignored; accepted only for cv2.Feature2D API compatibility.

        Returns:
            (list of cv2.KeyPoint, transposed descriptor matrix or None).
        """
        with self.lock:
            frame_float = frame.astype('float32') / 255.
            pts, des, heatmap = self.fe.run(frame_float)
            # N.B.: pts are - 3xN numpy array with corners [x_i, y_i, confidence_i]^T.
            kps = convert_superpts_to_keypoints(pts.T, size=self.keypoint_size)
            # Commit the cached state only after the network and the conversion have
            # both succeeded. Updating self.frame earlier would make a failed run look
            # like a success, and a later compute(frame) would return the previous
            # frame's keypoints/descriptors.
            self.frame = frame
            self.frameFloat = frame_float
            self.pts, self.des, self.heatmap = pts, des, heatmap
            self.kps = kps
            if kVerbose:
                print('detector: SUPERPOINT, #features: ', len(self.kps), ', frame res: ', frame.shape[0:2])
            return self.kps, transpose_des(self.des)

    # detect keypoints (always re-runs the full extraction)
    def detect(self, frame, mask=None):  # mask is a fake input
        """Return the keypoints detected on ``frame``.

        Always re-runs ``detectAndCompute`` (which also refreshes the cached
        descriptors), even if ``frame`` is the last processed frame.
        ``mask`` is ignored.
        """
        with self.lock:
            # NOTE(review): the cached-frame shortcut (skip when frame is self.frame)
            # is disabled here; confirm whether that is intended.
            self.detectAndCompute(frame)
            return self.kps

    # return descriptors, reusing cached results when frame is the last processed one
    def compute(self, frame, kps=None, mask=None):  # kps is a fake input, mask is a fake input
        """Return ``(keypoints, descriptors)`` for ``frame``.

        If ``frame`` is the same object as the last successfully processed frame,
        the cached results are returned. Otherwise a warning is printed and both
        keypoints and descriptors are recomputed. ``kps`` and ``mask`` are ignored:
        the returned keypoints are always SuperPoint's own detections.
        """
        with self.lock:
            if self.frame is not frame:
                Printer.orange('WARNING: SUPERPOINT is recomputing both kps and des on last input frame', frame.shape)
                self.detectAndCompute(frame)
            return self.kps, transpose_des(self.des)