-
Notifications
You must be signed in to change notification settings - Fork 366
/
Copy pathfeature_lightglue_sift.py
99 lines (85 loc) · 3.97 KB
/
feature_lightglue_sift.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
"""
* This file is part of PYSLAM
*
* Copyright (C) 2016-present Luigi Freda <luigi dot freda at gmail dot com>
*
* PYSLAM is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* PYSLAM is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with PYSLAM. If not, see <http://www.gnu.org/licenses/>.
"""
import os
import cv2
import numpy as np
import torch
from threading import RLock
from utils_sys import Printer, import_from, is_opencv_version_greater_equal
import config
# Configure pySLAM to use the 'lightglue' library before importing from it
# (presumably this makes the lightglue package importable, e.g. by adding its
# path — TODO confirm against config.cfg.set_lib).
config.cfg.set_lib('lightglue')
# Fetch the SIFT extractor class from the 'lightglue' package by name.
SIFT = import_from('lightglue', 'SIFT')
# Verbosity flag; not referenced by the code visible in this module section.
kVerbose = True
def numpy_image_to_torch(image: np.ndarray) -> torch.Tensor:
    """Turn an image array into a channel-first float32 tensor divided by 255.

    Args:
        image: a 2-D (HxW) or 3-D (HxWxC) array.

    Returns:
        A float32 tensor shaped 1xHxW (for 2-D input) or CxHxW (for 3-D input)
        whose values are the input values divided by 255.

    Raises:
        ValueError: if ``image`` is neither 2-D nor 3-D.
    """
    rank = image.ndim
    if rank == 2:
        chw = image[None]  # prepend a singleton channel axis
    elif rank == 3:
        chw = image.transpose((2, 0, 1))  # move channels first: HxWxC -> CxHxW
    else:
        raise ValueError(f"Not an image: {image.shape}")
    scaled = chw / 255.0
    return torch.tensor(scaled, dtype=torch.float)
def convert_pts_to_keypoints(pts, scales, oris):
    """Build a list of cv2.KeyPoint from per-point coordinate, scale and angle data.

    Args:
        pts: an [Nx2] matrix of (x, y) point coordinates, or None.
        scales: per-point values used as each keypoint's ``size``.
        oris: per-point values passed unchanged as each keypoint's ``angle``
            (OpenCV interprets KeyPoint.angle in degrees, so callers should
            supply degrees).

    Returns:
        A list of cv2.KeyPoint with response=1.0 and octave=0; an empty list
        when ``pts`` is None.
    """
    if pts is None:
        return []
    keypoints = []
    for xy, size, angle in zip(pts, scales, oris):
        keypoints.append(
            cv2.KeyPoint(xy[0], xy[1], size=size, angle=angle, response=1.0, octave=0)
        )
    return keypoints
# interface for pySLAM
class LightGlueSIFTFeature2D:
    """pySLAM feature detector/descriptor backed by LightGlue's SIFT extractor.

    Exposes OpenCV-style method names (detect, compute, detectAndCompute,
    setMaxFeatures) so it can be used where a cv2.Feature2D-like object is
    expected. Keypoints returned by detect/detectAndCompute are cv2.KeyPoint.
    """

    def __init__(self, num_features=2000):
        print('Using LightGlueSIFTFeature2D')
        self.num_features = num_features
        # copy so that editing our config does not mutate SIFT.default_conf
        self.config = SIFT.default_conf.copy()
        self.config['max_num_keypoints'] = self.num_features
        # input tensors are moved to this device before extraction ('mps' is another option)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.SIFT = SIFT(conf=self.config)

    def setMaxFeatures(self, num_features):  # use the cv2 method name for extractors (see https://docs.opencv.org/4.x/db/d95/classcv_1_1ORB.html#aca471cb82c03b14d3e824e4dcccf90b7)
        """Set the maximum number of keypoints and rebuild the SIFT extractor with it."""
        self.num_features = num_features
        self.config['max_num_keypoints'] = self.num_features
        self.SIFT = SIFT(conf=self.config)

    def extract(self, image):
        """Run the SIFT extractor on ``image``.

        Args:
            image: numpy image accepted by numpy_image_to_torch (HxW or HxWxC).

        Returns:
            (kps, des, scales, oris) as numpy arrays, each taken at index 0 of
            the leading (presumably batch) axis of the extractor output.
            ``oris`` are as reported by LightGlue (radians — see
            _to_cv_keypoints).
        """
        tensor = numpy_image_to_torch(image)
        feats = self.SIFT.extract(tensor.to(self.device))
        kps = feats["keypoints"].cpu().numpy()[0]
        des = feats["descriptors"].cpu().numpy()[0]
        scales = feats["scales"].cpu().numpy()[0]
        oris = feats["oris"].cpu().numpy()[0]
        return kps, des, scales, oris

    def _to_cv_keypoints(self, kps, scales, oris):
        """Convert extractor output into cv2.KeyPoint, with angles in degrees."""
        # LightGlue's SIFT reports orientations in radians (its OpenCV backend
        # applies np.deg2rad), whereas cv2.KeyPoint.angle is in degrees within
        # [0, 360); convert and wrap so downstream angle-based logic is correct.
        angles_deg = np.mod(np.rad2deg(oris), 360.0)
        return convert_pts_to_keypoints(kps, scales, angles_deg)

    # extract keypoints
    def detect(self, img, mask=None):  # mask is accepted for cv2 API compatibility but ignored
        """Detect keypoints in ``img`` and return them as a list of cv2.KeyPoint.

        The extractor computes descriptors as well; they are discarded here.
        """
        kps, _des, scales, oris = self.extract(img)
        # return cv2.KeyPoint like detectAndCompute (and cv2.Feature2D.detect)
        return self._to_cv_keypoints(kps, scales, oris)

    def compute(self, img, kps, mask=None):
        """Return descriptors for ``img`` by re-running full extraction.

        The input ``kps`` are ignored: the returned descriptors belong to the
        keypoints detected anew on ``img``, which may differ from ``kps``.
        NOTE(review): unlike cv2.Feature2D.compute, which returns
        (keypoints, descriptors), this returns only the descriptors; callers
        unpacking two values will fail. Prefer detectAndCompute().
        """
        Printer.orange('WARNING: you are supposed to call detectAndCompute() for LIGHTGLUESIFT instead of compute()')
        Printer.orange('WARNING: LIGHTGLUESIFT is recomputing both kps and des on input frame', img.shape)
        _kps, des, _scales, _oris = self.extract(img)
        return des

    # compute both keypoints and descriptors
    def detectAndCompute(self, img, mask=None):  # mask is accepted for cv2 API compatibility but ignored
        """Detect keypoints and compute descriptors in one extractor pass.

        Returns:
            (keypoints, descriptors): a list of cv2.KeyPoint (angles in
            degrees) and the descriptor array produced by the extractor.
        """
        kps, des, scales, oris = self.extract(img)
        return self._to_cv_keypoints(kps, scales, oris), des