@@ -15,23 +15,6 @@ using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
15
15
KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove = 4 ,
16
16
float confidenceThresh = 0.015 , bool alignCorners = true , int distThresh = 2 );
17
17
18
- /* *
19
- * @brief detect super point
20
- *
21
- * @return vector of detected key points
22
- */
23
- std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
24
- int borderRemove, float confidenceThresh);
25
-
26
- /* *
27
- * @brief estimate super point's keypoint descriptor
28
- *
29
- * @return keypoint Mat of shape [num key point x 256]
30
- */
31
- cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
32
- int width, bool alignCorners);
33
-
34
- std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh = 2 );
35
18
} // namespace
36
19
37
20
int main (int argc, char * argv[])
@@ -92,109 +75,6 @@ int main(int argc, char* argv[])
92
75
93
76
namespace
94
77
{
95
- std::vector<cv::KeyPoint> getKeyPoints (const std::vector<Ort::OrtSessionHandler::DataOutputType>& inferenceOutput,
96
- int borderRemove, float confidenceThresh)
97
- {
98
- std::vector<int > detectorShape (inferenceOutput[0 ].second .begin () + 1 , inferenceOutput[0 ].second .end ());
99
-
100
- cv::Mat detectorMat (detectorShape.size (), detectorShape.data (), CV_32F,
101
- inferenceOutput[0 ].first ); // 65 x H/8 x W/8
102
- cv::Mat buffer;
103
-
104
- transposeNDWrapper (detectorMat, {1 , 2 , 0 }, buffer);
105
- buffer.copyTo (detectorMat); // H/8 x W/8 x 65
106
-
107
- for (int i = 0 ; i < detectorShape[1 ]; ++i) {
108
- for (int j = 0 ; j < detectorShape[2 ]; ++j) {
109
- Ort::softmax (detectorMat.ptr <float >(i, j), detectorShape[0 ]);
110
- }
111
- }
112
- detectorMat = detectorMat ({cv::Range::all (), cv::Range::all (), cv::Range (0 , detectorShape[0 ] - 1 )})
113
- .clone (); // H/8 x W/8 x 64
114
- detectorMat = detectorMat.reshape (1 , {detectorShape[1 ], detectorShape[2 ], 8 , 8 }); // H/8 x W/8 x 8 x 8
115
- transposeNDWrapper (detectorMat, {0 , 2 , 1 , 3 }, buffer);
116
- buffer.copyTo (detectorMat); // H/8 x 8 x W/8 x 8
117
-
118
- detectorMat = detectorMat.reshape (1 , {detectorShape[1 ] * 8 , detectorShape[2 ] * 8 }); // H x W
119
-
120
- std::vector<cv::KeyPoint> keyPoints;
121
- for (int i = borderRemove; i < detectorMat.rows - borderRemove; ++i) {
122
- auto rowPtr = detectorMat.ptr <float >(i);
123
- for (int j = borderRemove; j < detectorMat.cols - borderRemove; ++j) {
124
- if (rowPtr[j] > confidenceThresh) {
125
- cv::KeyPoint keyPoint;
126
- keyPoint.pt .x = j;
127
- keyPoint.pt .y = i;
128
- keyPoint.response = rowPtr[j];
129
- keyPoints.emplace_back (keyPoint);
130
- }
131
- }
132
- }
133
-
134
- return keyPoints;
135
- }
136
- cv::Mat getDescriptors (const cv::Mat& coarseDescriptors, const std::vector<cv::KeyPoint>& keyPoints, int height,
137
- int width, bool alignCorners)
138
- {
139
- cv::Mat keyPointMat (keyPoints.size (), 2 , CV_32F);
140
-
141
- for (int i = 0 ; i < keyPoints.size (); ++i) {
142
- auto rowPtr = keyPointMat.ptr <float >(i);
143
- rowPtr[0 ] = 2 * keyPoints[i].pt .y / (height - 1 ) - 1 ;
144
- rowPtr[1 ] = 2 * keyPoints[i].pt .x / (width - 1 ) - 1 ;
145
- }
146
- keyPointMat = keyPointMat.reshape (1 , {1 , 1 , static_cast <int >(keyPoints.size ()), 2 });
147
- cv::Mat descriptors = bilinearGridSample (coarseDescriptors, keyPointMat, alignCorners);
148
- descriptors = descriptors.reshape (1 , {coarseDescriptors.size [1 ], static_cast <int >(keyPoints.size ())});
149
-
150
- cv::Mat buffer;
151
- transposeNDWrapper (descriptors, {1 , 0 }, buffer);
152
-
153
- return buffer;
154
- }
155
-
156
- std::vector<int > nmsFast (const std::vector<cv::KeyPoint>& keyPoints, int height, int width, int distThresh)
157
- {
158
- static const int TO_PROCESS = 1 ;
159
- static const int EMPTY_OR_SUPPRESSED = 0 ;
160
- static const int KEPT = -1 ;
161
-
162
- std::vector<int > sortedIndices (keyPoints.size ());
163
- std::iota (sortedIndices.begin (), sortedIndices.end (), 0 );
164
-
165
- // sort in descending order base on confidence
166
- std::stable_sort (sortedIndices.begin (), sortedIndices.end (),
167
- [&keyPoints](int lidx, int ridx) { return keyPoints[lidx].response > keyPoints[ridx].response ; });
168
-
169
- cv::Mat grid = cv::Mat (height, width, CV_8S, TO_PROCESS);
170
- std::vector<int > keepIndices;
171
-
172
- for (int idx : sortedIndices) {
173
- int x = keyPoints[idx].pt .x ;
174
- int y = keyPoints[idx].pt .y ;
175
-
176
- if (grid.at <schar>(y, x) == TO_PROCESS) {
177
- for (int i = y - distThresh; i < y + distThresh; ++i) {
178
- if (i < 0 || i >= height) {
179
- continue ;
180
- }
181
-
182
- for (int j = x - distThresh; j < x + distThresh; ++j) {
183
- if (j < 0 || j >= width) {
184
- continue ;
185
- }
186
- grid.at <int >(i, j) = EMPTY_OR_SUPPRESSED;
187
- }
188
- }
189
-
190
- grid.at <int >(y, x) = KEPT;
191
- keepIndices.emplace_back (idx);
192
- }
193
- }
194
-
195
- return keepIndices;
196
- }
197
-
198
78
KeyPointAndDesc processOneFrame (const Ort::SuperPoint& osh, const cv::Mat& inputImg, float * dst, int borderRemove,
199
79
float confidenceThresh, bool alignCorners, int distThresh)
200
80
{
@@ -204,22 +84,22 @@ KeyPointAndDesc processOneFrame(const Ort::SuperPoint& osh, const cv::Mat& input
204
84
osh.preprocess (dst, scaledImg.data , Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_CHANNEL);
205
85
auto inferenceOutput = osh ({dst});
206
86
207
- std::vector<cv::KeyPoint> keyPoints = getKeyPoints (inferenceOutput, borderRemove, confidenceThresh);
87
+ std::vector<cv::KeyPoint> keyPoints = osh. getKeyPoints (inferenceOutput, borderRemove, confidenceThresh);
208
88
209
89
std::vector<int > descriptorShape (inferenceOutput[1 ].second .begin (), inferenceOutput[1 ].second .end ());
210
90
cv::Mat coarseDescriptorMat (descriptorShape.size (), descriptorShape.data (), CV_32F,
211
91
inferenceOutput[1 ].first ); // 1 x 256 x H/8 x W/8
212
92
213
- std::vector<int > keepIndices = nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
93
+ std::vector<int > keepIndices = osh. nmsFast (keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
214
94
215
95
std::vector<cv::KeyPoint> keepKeyPoints;
216
96
keepKeyPoints.reserve (keepIndices.size ());
217
97
std::transform (keepIndices.begin (), keepIndices.end (), std::back_inserter (keepKeyPoints),
218
98
[&keyPoints](int idx) { return keyPoints[idx]; });
219
99
keyPoints = std::move (keepKeyPoints);
220
100
221
- cv::Mat descriptors =
222
- getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, alignCorners);
101
+ cv::Mat descriptors = osh. getDescriptors (coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H,
102
+ Ort::SuperPoint::IMG_W, alignCorners);
223
103
224
104
for (auto & keyPoint : keyPoints) {
225
105
keyPoint.pt .x *= static_cast <float >(origW) / Ort::SuperPoint::IMG_W;
0 commit comments