Skip to content

Commit 7231be9

Browse files
committed
simplified superpoint
1 parent 89cf774 commit 7231be9

File tree

3 files changed

+131
-125
lines changed

3 files changed

+131
-125
lines changed

examples/SuperPoint.cpp

Lines changed: 107 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
*/
77

88
#include "SuperPoint.hpp"
9+
#include "Utility.hpp"
910

1011
namespace Ort
1112
{
@@ -21,5 +22,110 @@ void SuperPoint::preprocess(float* dst, const unsigned char* src, const int64_t
2122
}
2223
}
2324
}
25+
26+
std::vector<int> SuperPoint::nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width,
27+
int distThresh) const
28+
{
29+
static const int TO_PROCESS = 1;
30+
static const int EMPTY_OR_SUPPRESSED = 0;
31+
static const int KEPT = -1;
32+
33+
std::vector<int> sortedIndices(keyPoints.size());
34+
std::iota(sortedIndices.begin(), sortedIndices.end(), 0);
35+
36+
// sort in descending order base on confidence
37+
std::stable_sort(sortedIndices.begin(), sortedIndices.end(),
38+
[&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response; });
39+
40+
cv::Mat grid = cv::Mat(height, width, CV_8S, TO_PROCESS);
41+
std::vector<int> keepIndices;
42+
43+
for (int idx : sortedIndices) {
44+
int x = keyPoints[idx].pt.x;
45+
int y = keyPoints[idx].pt.y;
46+
47+
if (grid.at<schar>(y, x) == TO_PROCESS) {
48+
for (int i = y - distThresh; i < y + distThresh; ++i) {
49+
if (i < 0 || i >= height) {
50+
continue;
51+
}
52+
53+
for (int j = x - distThresh; j < x + distThresh; ++j) {
54+
if (j < 0 || j >= width) {
55+
continue;
56+
}
57+
grid.at<int>(i, j) = EMPTY_OR_SUPPRESSED;
58+
}
59+
}
60+
61+
grid.at<int>(y, x) = KEPT;
62+
keepIndices.emplace_back(idx);
63+
}
64+
}
65+
66+
return keepIndices;
67+
}
68+
69+
std::vector<cv::KeyPoint>
70+
SuperPoint::getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput, int borderRemove,
71+
float confidenceThresh) const
72+
{
73+
std::vector<int> detectorShape(inferenceOutput[0].second.begin() + 1, inferenceOutput[0].second.end());
74+
75+
cv::Mat detectorMat(detectorShape.size(), detectorShape.data(), CV_32F,
76+
inferenceOutput[0].first); // 65 x H/8 x W/8
77+
cv::Mat buffer;
78+
79+
transposeNDWrapper(detectorMat, {1, 2, 0}, buffer);
80+
buffer.copyTo(detectorMat); // H/8 x W/8 x 65
81+
82+
for (int i = 0; i < detectorShape[1]; ++i) {
83+
for (int j = 0; j < detectorShape[2]; ++j) {
84+
Ort::softmax(detectorMat.ptr<float>(i, j), detectorShape[0]);
85+
}
86+
}
87+
detectorMat = detectorMat({cv::Range::all(), cv::Range::all(), cv::Range(0, detectorShape[0] - 1)})
88+
.clone(); // H/8 x W/8 x 64
89+
detectorMat = detectorMat.reshape(1, {detectorShape[1], detectorShape[2], 8, 8}); // H/8 x W/8 x 8 x 8
90+
transposeNDWrapper(detectorMat, {0, 2, 1, 3}, buffer);
91+
buffer.copyTo(detectorMat); // H/8 x 8 x W/8 x 8
92+
93+
detectorMat = detectorMat.reshape(1, {detectorShape[1] * 8, detectorShape[2] * 8}); // H x W
94+
95+
std::vector<cv::KeyPoint> keyPoints;
96+
for (int i = borderRemove; i < detectorMat.rows - borderRemove; ++i) {
97+
auto rowPtr = detectorMat.ptr<float>(i);
98+
for (int j = borderRemove; j < detectorMat.cols - borderRemove; ++j) {
99+
if (rowPtr[j] > confidenceThresh) {
100+
cv::KeyPoint keyPoint;
101+
keyPoint.pt.x = j;
102+
keyPoint.pt.y = i;
103+
keyPoint.response = rowPtr[j];
104+
keyPoints.emplace_back(keyPoint);
105+
}
106+
}
107+
}
108+
109+
return keyPoints;
110+
}
111+
112+
cv::Mat SuperPoint::getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints,
113+
int height, int width, bool alignCorners) const
114+
{
115+
cv::Mat keyPointMat(keyPoints.size(), 2, CV_32F);
116+
117+
for (int i = 0; i < keyPoints.size(); ++i) {
118+
auto rowPtr = keyPointMat.ptr<float>(i);
119+
rowPtr[0] = 2 * keyPoints[i].pt.y / (height - 1) - 1;
120+
rowPtr[1] = 2 * keyPoints[i].pt.x / (width - 1) - 1;
121+
}
122+
keyPointMat = keyPointMat.reshape(1, {1, 1, static_cast<int>(keyPoints.size()), 2});
123+
cv::Mat descriptors = bilinearGridSample(coarseDescriptors, keyPointMat, alignCorners);
124+
descriptors = descriptors.reshape(1, {coarseDescriptors.size[1], static_cast<int>(keyPoints.size())});
125+
126+
cv::Mat buffer;
127+
transposeNDWrapper(descriptors, {1, 0}, buffer);
128+
129+
return buffer;
130+
}
24131
} // namespace Ort
25-
// namespace Ort

examples/SuperPoint.hpp

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77

88
#pragma once
99

10+
#include <opencv2/opencv.hpp>
1011
#include <ort_utility/ort_utility.hpp>
1112

1213
namespace Ort
@@ -25,5 +26,24 @@ class SuperPoint : public OrtSessionHandler
2526
const int64_t targetImgWidth, //
2627
const int64_t targetImgHeight, //
2728
const int numChannels) const;
29+
30+
std::vector<int> nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width,
31+
int distThresh = 2) const;
32+
33+
/**
34+
* @brief detect super point
35+
*
36+
* @return vector of detected key points
37+
*/
38+
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
39+
int borderRemove, float confidenceThresh) const;
40+
41+
/**
42+
* @brief estimate super point's keypoint descriptor
43+
*
44+
* @return keypoint Mat of shape [num key point x 256]
45+
*/
46+
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
47+
int width, bool alignCorners) const;
2848
};
2949
} // namespace Ort

examples/SuperPointApp.cpp

Lines changed: 4 additions & 124 deletions
Original file line numberDiff line numberDiff line change
@@ -15,23 +15,6 @@ using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
1515
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove = 4,
1616
float confidenceThresh = 0.015, bool alignCorners = true, int distThresh = 2);
1717

18-
/**
19-
* @brief detect super point
20-
*
21-
* @return vector of detected key points
22-
*/
23-
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
24-
int borderRemove, float confidenceThresh);
25-
26-
/**
27-
* @brief estimate super point's keypoint descriptor
28-
*
29-
* @return keypoint Mat of shape [num key point x 256]
30-
*/
31-
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
32-
int width, bool alignCorners);
33-
34-
std::vector<int> nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2);
3518
} // namespace
3619

3720
int main(int argc, char* argv[])
@@ -92,109 +75,6 @@ int main(int argc, char* argv[])
9275

9376
namespace
9477
{
95-
std::vector<cv::KeyPoint> getKeyPoints(const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
96-
int borderRemove, float confidenceThresh)
97-
{
98-
std::vector<int> detectorShape(inferenceOutput[0].second.begin() + 1, inferenceOutput[0].second.end());
99-
100-
cv::Mat detectorMat(detectorShape.size(), detectorShape.data(), CV_32F,
101-
inferenceOutput[0].first); // 65 x H/8 x W/8
102-
cv::Mat buffer;
103-
104-
transposeNDWrapper(detectorMat, {1, 2, 0}, buffer);
105-
buffer.copyTo(detectorMat); // H/8 x W/8 x 65
106-
107-
for (int i = 0; i < detectorShape[1]; ++i) {
108-
for (int j = 0; j < detectorShape[2]; ++j) {
109-
Ort::softmax(detectorMat.ptr<float>(i, j), detectorShape[0]);
110-
}
111-
}
112-
detectorMat = detectorMat({cv::Range::all(), cv::Range::all(), cv::Range(0, detectorShape[0] - 1)})
113-
.clone(); // H/8 x W/8 x 64
114-
detectorMat = detectorMat.reshape(1, {detectorShape[1], detectorShape[2], 8, 8}); // H/8 x W/8 x 8 x 8
115-
transposeNDWrapper(detectorMat, {0, 2, 1, 3}, buffer);
116-
buffer.copyTo(detectorMat); // H/8 x 8 x W/8 x 8
117-
118-
detectorMat = detectorMat.reshape(1, {detectorShape[1] * 8, detectorShape[2] * 8}); // H x W
119-
120-
std::vector<cv::KeyPoint> keyPoints;
121-
for (int i = borderRemove; i < detectorMat.rows - borderRemove; ++i) {
122-
auto rowPtr = detectorMat.ptr<float>(i);
123-
for (int j = borderRemove; j < detectorMat.cols - borderRemove; ++j) {
124-
if (rowPtr[j] > confidenceThresh) {
125-
cv::KeyPoint keyPoint;
126-
keyPoint.pt.x = j;
127-
keyPoint.pt.y = i;
128-
keyPoint.response = rowPtr[j];
129-
keyPoints.emplace_back(keyPoint);
130-
}
131-
}
132-
}
133-
134-
return keyPoints;
135-
}
136-
cv::Mat getDescriptors(const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
137-
int width, bool alignCorners)
138-
{
139-
cv::Mat keyPointMat(keyPoints.size(), 2, CV_32F);
140-
141-
for (int i = 0; i < keyPoints.size(); ++i) {
142-
auto rowPtr = keyPointMat.ptr<float>(i);
143-
rowPtr[0] = 2 * keyPoints[i].pt.y / (height - 1) - 1;
144-
rowPtr[1] = 2 * keyPoints[i].pt.x / (width - 1) - 1;
145-
}
146-
keyPointMat = keyPointMat.reshape(1, {1, 1, static_cast<int>(keyPoints.size()), 2});
147-
cv::Mat descriptors = bilinearGridSample(coarseDescriptors, keyPointMat, alignCorners);
148-
descriptors = descriptors.reshape(1, {coarseDescriptors.size[1], static_cast<int>(keyPoints.size())});
149-
150-
cv::Mat buffer;
151-
transposeNDWrapper(descriptors, {1, 0}, buffer);
152-
153-
return buffer;
154-
}
155-
156-
std::vector<int> nmsFast(const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157-
{
158-
static const int TO_PROCESS = 1;
159-
static const int EMPTY_OR_SUPPRESSED = 0;
160-
static const int KEPT = -1;
161-
162-
std::vector<int> sortedIndices(keyPoints.size());
163-
std::iota(sortedIndices.begin(), sortedIndices.end(), 0);
164-
165-
// sort in descending order base on confidence
166-
std::stable_sort(sortedIndices.begin(), sortedIndices.end(),
167-
[&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response; });
168-
169-
cv::Mat grid = cv::Mat(height, width, CV_8S, TO_PROCESS);
170-
std::vector<int> keepIndices;
171-
172-
for (int idx : sortedIndices) {
173-
int x = keyPoints[idx].pt.x;
174-
int y = keyPoints[idx].pt.y;
175-
176-
if (grid.at<schar>(y, x) == TO_PROCESS) {
177-
for (int i = y - distThresh; i < y + distThresh; ++i) {
178-
if (i < 0 || i >= height) {
179-
continue;
180-
}
181-
182-
for (int j = x - distThresh; j < x + distThresh; ++j) {
183-
if (j < 0 || j >= width) {
184-
continue;
185-
}
186-
grid.at<int>(i, j) = EMPTY_OR_SUPPRESSED;
187-
}
188-
}
189-
190-
grid.at<int>(y, x) = KEPT;
191-
keepIndices.emplace_back(idx);
192-
}
193-
}
194-
195-
return keepIndices;
196-
}
197-
19878
KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& inputImg, float* dst, int borderRemove,
19979
float confidenceThresh, bool alignCorners, int distThresh)
20080
{
@@ -204,22 +84,22 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
20484
osh.preprocess(dst, scaledImg.data, Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_CHANNEL);
20585
auto inferenceOutput = osh({dst});
20686

207-
std::vector<cv::KeyPoint> keyPoints = getKeyPoints(inferenceOutput, borderRemove, confidenceThresh);
87+
std::vector<cv::KeyPoint> keyPoints = osh.getKeyPoints(inferenceOutput, borderRemove, confidenceThresh);
20888

20989
std::vector<int> descriptorShape(inferenceOutput[1].second.begin(), inferenceOutput[1].second.end());
21090
cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F,
21191
inferenceOutput[1].first); // 1 x 256 x H/8 x W/8
21292

213-
std::vector<int> keepIndices = nmsFast(keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
93+
std::vector<int> keepIndices = osh.nmsFast(keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
21494

21595
std::vector<cv::KeyPoint> keepKeyPoints;
21696
keepKeyPoints.reserve(keepIndices.size());
21797
std::transform(keepIndices.begin(), keepIndices.end(), std::back_inserter(keepKeyPoints),
21898
[&keyPoints](int idx) { return keyPoints[idx]; });
21999
keyPoints = std::move(keepKeyPoints);
220100

221-
cv::Mat descriptors =
222-
getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
101+
cv::Mat descriptors = osh.getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H,
102+
Ort::SuperPoint::IMG_W, alignCorners);
223103

224104
for (auto& keyPoint : keyPoints) {
225105
keyPoint.pt.x *= static_cast<float>(origW) / Ort::SuperPoint::IMG_W;

0 commit comments

Comments
 (0)