Skip to content

Commit 98aca9f

Browse files
committed
super point nms
1 parent 32bcbae commit 98aca9f

File tree

1 file changed

+72
-14
lines changed

1 file changed

+72
-14
lines changed

examples/SuperPointApp.cpp

Lines changed: 72 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,33 @@
77

88
#include "SuperPoint.hpp"
99
#include "Utility.hpp"
10+
#include <algorithm>
11+
#include <iterator>
#include <numeric>
1012

1113
namespace
1214
{
1315
using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
1416

1517
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove = 4,
16-
float confidenceThresh = 0.015, bool alignCorners = true);
18+
float confidenceThresh = 0.015, bool alignCorners = true, int distThresh = 2);
1719

20+
/**
21+
* @brief detect SuperPoint key points from the network's inference output
22+
*
23+
* @return vector of detected key points
24+
*/
25+
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
26+
int borderRemove, float confidenceThresh);
27+
28+
/**
29+
* @brief estimate super point's keypoint descriptor
30+
*
31+
* @return descriptor Mat of shape [num key points x 256]
32+
*/
33+
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
34+
int width, bool alignCorners);
35+
36+
std::vector<int> nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2);
1837
} // namespace
1938

2039
int main(int argc, char* argv[])
@@ -66,18 +85,13 @@ int main(int argc, char* argv[])
6685
cv::drawMatches(images[0], results[0].first, images[1], results[1].first, goodMatches, matchesImage,
6786
cv::Scalar::all(-1), cv::Scalar::all(-1), std::vector<char>(),
6887
cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
69-
cv::imwrite("good_matches.jpg", matchesImage);
88+
cv::imwrite("super_point_good_matches.jpg", matchesImage);
7089

7190
return EXIT_SUCCESS;
7291
}
7392

7493
namespace
7594
{
76-
/**
77-
* @brief detect super point
78-
*
79-
* @return vector of detected key points
80-
*/
8195
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
8296
int borderRemove, float confidenceThresh)
8397
{
@@ -119,12 +133,6 @@ std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler:
119133

120134
return keyPoints;
121135
}
122-
123-
/**
124-
* @brief estimate super point's keypoint descriptor
125-
*
126-
* @return keypoint Mat of shape [num key point x 256]
127-
*/
128136
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
129137
int width, bool alignCorners)
130138
{
@@ -145,8 +153,50 @@ cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::K
145153
return buffer;
146154
}
147155

156+
std::vector<int> nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157+
{
158+
static const int TO_PROCESS = 1;
159+
static const int EMPTY_OR_SUPPRESSED = 0;
160+
static const int KEPT = -1;
161+
162+
std::vector<int> sortedIndices(keyPoints.size());
163+
std::iota(sortedIndices.begin(), sortedIndices.end(), 0);
164+
165+
// sort in descending order base on confidence
166+
std::stable_sort(sortedIndices.begin(), sortedIndices.end(),
167+
[&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response; });
168+
169+
cv::Mat grid = cv::Mat(height, width, CV_8S, TO_PROCESS);
170+
std::vector<int> keepIndices;
171+
172+
for (int idx : sortedIndices) {
173+
int x = keyPoints[idx].pt.x;
174+
int y = keyPoints[idx].pt.y;
175+
176+
if (grid.at<schar>(y, x) == TO_PROCESS) {
177+
for (int i = y - distThresh; i < y + distThresh; ++i) {
178+
if (i < 0 || i >= height) {
179+
continue;
180+
}
181+
182+
for (int j = x - distThresh; j < x + distThresh; ++j) {
183+
if (j < 0 || j >= width) {
184+
continue;
185+
}
186+
grid.at<int>(i, j) = EMPTY_OR_SUPPRESSED;
187+
}
188+
}
189+
190+
grid.at<int>(y, x) = KEPT;
191+
keepIndices.emplace_back(idx);
192+
}
193+
}
194+
195+
return keepIndices;
196+
}
197+
148198
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove,
149-
float confidenceThresh, bool alignCorners)
199+
float confidenceThresh, bool alignCorners, int distThresh)
150200
{
151201
int origW = inputImg.cols, origH = inputImg.rows;
152202
std::vector<float> originImageSize{static_cast<float>(origH), static_cast<float>(origW)};
@@ -161,6 +211,14 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
161211
cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F,
162212
inferenceOutput[1].first); // 1 x 256 x H/8 x W/8
163213

214+
std::vector<int> keepIndices = nmsFast(keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
215+
216+
std::vector<cv::KeyPoint> keepKeyPoints;
217+
keepKeyPoints.reserve(keepIndices.size());
218+
std::transform(keepIndices.begin(), keepIndices.end(), std::back_inserter(keepKeyPoints),
219+
[&keyPoints](int idx) { return keyPoints[idx]; });
220+
keyPoints = std::move(keepKeyPoints);
221+
164222
cv::Mat descriptors =
165223
getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
166224

0 commit comments

Comments
 (0)