7
7
8
8
#include " SuperPoint.hpp"
9
9
#include " Utility.hpp"
10
+ #include < algorithm>
11
+ #include < iterator>
10
12
11
13
namespace
12
14
{
13
15
using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
14
16
15
17
KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove = 4 ,
16
- float confidenceThresh = 0.015 , bool alignCorners = true );
18
+ float confidenceThresh = 0.015 , bool alignCorners = true , int distThresh = 2 );
17
19
20
+ /* *
21
+ * @brief detect super point
22
+ *
23
+ * @return vector of detected key points
24
+ */
25
+ std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
26
+ int borderRemove, float confidenceThresh);
27
+
28
+ /* *
29
+ * @brief estimate super point's keypoint descriptor
30
+ *
31
+ * @return keypoint Mat of shape [num key point x 256]
32
+ */
33
+ cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
34
+ int width, bool alignCorners);
35
+
36
+ std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2 );
18
37
} // namespace
19
38
20
39
int main (int argc, char * argv[])
@@ -66,18 +85,13 @@ int main(int argc, char* argv[])
66
85
cv::drawMatches (images[0 ], results[0 ].first , images[1 ], results[1 ].first , goodMatches, matchesImage,
67
86
cv::Scalar::all (-1 ), cv::Scalar::all (-1 ), std::vector<char >(),
68
87
cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
69
- cv::imwrite (" good_matches .jpg" , matchesImage);
88
+ cv::imwrite (" super_point_good_matches .jpg" , matchesImage);
70
89
71
90
return EXIT_SUCCESS;
72
91
}
73
92
74
93
namespace
75
94
{
76
- /* *
77
- * @brief detect super point
78
- *
79
- * @return vector of detected key points
80
- */
81
95
std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
82
96
int borderRemove, float confidenceThresh)
83
97
{
@@ -119,12 +133,6 @@ std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler:
119
133
120
134
return keyPoints;
121
135
}
122
-
123
- /* *
124
- * @brief estimate super point's keypoint descriptor
125
- *
126
- * @return keypoint Mat of shape [num key point x 256]
127
- */
128
136
cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
129
137
int width, bool alignCorners)
130
138
{
@@ -145,8 +153,50 @@ cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::K
145
153
return buffer;
146
154
}
147
155
156
+ std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157
+ {
158
+ static const int TO_PROCESS = 1 ;
159
+ static const int EMPTY_OR_SUPPRESSED = 0 ;
160
+ static const int KEPT = -1 ;
161
+
162
+ std::vector<int > sortedIndices (keyPoints.size ());
163
+ std::iota (sortedIndices.begin (), sortedIndices.end (), 0 );
164
+
165
+ // sort in descending order base on confidence
166
+ std::stable_sort (sortedIndices.begin (), sortedIndices.end (),
167
+ [&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response ; });
168
+
169
+ cv::Mat grid = cv::Mat (height, width, CV_8S, TO_PROCESS);
170
+ std::vector<int > keepIndices;
171
+
172
+ for (int idx : sortedIndices) {
173
+ int x = keyPoints[idx].pt .x ;
174
+ int y = keyPoints[idx].pt .y ;
175
+
176
+ if (grid.at <schar>(y, x) == TO_PROCESS) {
177
+ for (int i = y - distThresh; i < y + distThresh; ++i) {
178
+ if (i < 0 || i >= height) {
179
+ continue ;
180
+ }
181
+
182
+ for (int j = x - distThresh; j < x + distThresh; ++j) {
183
+ if (j < 0 || j >= width) {
184
+ continue ;
185
+ }
186
+ grid.at <int >(i, j) = EMPTY_OR_SUPPRESSED;
187
+ }
188
+ }
189
+
190
+ grid.at <int >(y, x) = KEPT;
191
+ keepIndices.emplace_back (idx);
192
+ }
193
+ }
194
+
195
+ return keepIndices;
196
+ }
197
+
148
198
KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove,
149
- float confidenceThresh, bool alignCorners)
199
+ float confidenceThresh, bool alignCorners, int distThresh )
150
200
{
151
201
int origW = inputImg.cols , origH = inputImg.rows ;
152
202
std::vector<float > originImageSize{static_cast <float >(origH), static_cast <float >(origW)};
@@ -161,6 +211,14 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
161
211
cv::Mat coarseDescriptorMat (descriptorShape.size (), descriptorShape.data (), CV_32F,
162
212
inferenceOutput[1 ].first ); // 1 x 256 x H/8 x W/8
163
213
214
+ std::vector<int > keepIndices = nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
215
+
216
+ std::vector<cv::KeyPoint> keepKeyPoints;
217
+ keepKeyPoints.reserve (keepIndices.size ());
218
+ std::transform (keepIndices.begin (), keepIndices.end (), std::back_inserter (keepKeyPoints),
219
+ [&keyPoints](int idx) { return keyPoints[idx]; });
220
+ keyPoints = std::move (keepKeyPoints);
221
+
164
222
cv::Mat descriptors =
165
223
getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
166
224
0 commit comments