Skip to content

Commit a17ed70

Browse files
committed
superglue onnxruntime cpp
1 parent 7231be9 commit a17ed70

File tree

7 files changed

+240
-2
lines changed

7 files changed

+240
-2
lines changed

examples/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,3 +159,21 @@ target_include_directories(super_point
159159
PUBLIC
160160
${OpenCV_INCLUDE_DIRS}
161161
)
162+
163+
# ---------------------------------------------------------
164+
165+
add_executable(super_glue
166+
${CMAKE_CURRENT_LIST_DIR}/SuperPoint.cpp
167+
${CMAKE_CURRENT_LIST_DIR}/SuperGlueApp.cpp
168+
)
169+
170+
target_link_libraries(super_glue
171+
PUBLIC
172+
${PROJECT_NAME}
173+
${OpenCV_LIBS}
174+
)
175+
176+
target_include_directories(super_glue
177+
PUBLIC
178+
${OpenCV_INCLUDE_DIRS}
179+
)

examples/SuperGlueApp.cpp

Lines changed: 183 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,183 @@
1+
/**
2+
* @file SuperGlueApp.cpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#include "SuperPoint.hpp"
9+
#include "Utility.hpp"
10+
11+
namespace
12+
{
13+
using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>;
14+
15+
KeyPointAndDesc processOneFrameSuperPoint(const Ort::SuperPoint& superPointOsh, const cv::Mat& inputImg, float* dst,
16+
int borderRemove = 4, float confidenceThresh = 0.015,
17+
bool alignCorners = true, int distThresh = 2);
18+
19+
void normalizeDescriptors(cv::Mat& descriptors);
20+
} // namespace
21+
22+
int main(int argc, char* argv[])
23+
{
24+
if (argc != 5) {
25+
std::cerr
26+
<< "Usage: [apps] [path/to/onnx/super/point] [path/to/onnx/super/glue] [path/to/image1] [path/to/image2]"
27+
<< std::endl;
28+
return EXIT_FAILURE;
29+
}
30+
31+
const std::string ONNX_MODEL_PATH = argv[1];
32+
const std::string SUPERGLUE_ONNX_MODEL_PATH = argv[2];
33+
const std::vector<std::string> IMAGE_PATHS = {argv[3], argv[4]};
34+
35+
Ort::SuperPoint superPointOsh(ONNX_MODEL_PATH, 0,
36+
std::vector<std::vector<int64_t>>{{1, Ort::SuperPoint::IMG_CHANNEL,
37+
Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W}});
38+
39+
std::vector<cv::Mat> images;
40+
std::vector<cv::Mat> grays;
41+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images),
42+
[](const auto& imagePath) { return cv::imread(imagePath); });
43+
for (int i = 0; i < 2; ++i) {
44+
if (images[i].empty()) {
45+
throw std::runtime_error("failed to open " + IMAGE_PATHS[i]);
46+
}
47+
}
48+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays),
49+
[](const auto& imagePath) { return cv::imread(imagePath, 0); });
50+
51+
std::vector<float> dst(Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W);
52+
53+
std::vector<KeyPointAndDesc> superPointResults;
54+
std::transform(grays.begin(), grays.end(), std::back_inserter(superPointResults),
55+
[&superPointOsh, &dst](const auto& gray) {
56+
return processOneFrameSuperPoint(superPointOsh, gray, dst.data());
57+
});
58+
59+
for (auto& curKeyPointAndDesc : superPointResults) {
60+
normalizeDescriptors(curKeyPointAndDesc.second);
61+
}
62+
63+
// superglue
64+
static const int DUMMY_NUM_KEYPOINTS = 256;
65+
Ort::SuperPoint superGlueOsh(SUPERGLUE_ONNX_MODEL_PATH, 0,
66+
std::vector<std::vector<int64_t>>{
67+
{4},
68+
{1, DUMMY_NUM_KEYPOINTS},
69+
{1, DUMMY_NUM_KEYPOINTS, 2},
70+
{1, 256, DUMMY_NUM_KEYPOINTS},
71+
{4},
72+
{1, DUMMY_NUM_KEYPOINTS},
73+
{1, DUMMY_NUM_KEYPOINTS, 2},
74+
{1, 256, DUMMY_NUM_KEYPOINTS},
75+
});
76+
77+
int numKeypoints0 = superPointResults[0].first.size();
78+
int numKeypoints1 = superPointResults[1].first.size();
79+
std::vector<std::vector<int64_t>> inputShapes = {
80+
{4}, {1, numKeypoints0}, {1, numKeypoints0, 2}, {1, 256, numKeypoints0},
81+
{4}, {1, numKeypoints1}, {1, numKeypoints1, 2}, {1, 256, numKeypoints1},
82+
};
83+
superGlueOsh.updateInputShapes(inputShapes);
84+
85+
std::vector<std::vector<float>> imageShapes(2);
86+
std::vector<std::vector<float>> scores(2);
87+
std::vector<std::vector<float>> keypoints(2);
88+
std::vector<std::vector<float>> descriptors(2);
89+
90+
cv::Mat buffer;
91+
for (int i = 0; i < 2; ++i) {
92+
imageShapes[i] = {1, 1, static_cast<float>(images[0].rows), static_cast<float>(images[0].cols)};
93+
std::transform(superPointResults[i].first.begin(), superPointResults[i].first.end(),
94+
std::back_inserter(scores[i]), [](const cv::KeyPoint& keypoint) { return keypoint.response; });
95+
for (const auto& k : superPointResults[i].first) {
96+
keypoints[i].emplace_back(k.pt.y);
97+
keypoints[i].emplace_back(k.pt.x);
98+
}
99+
100+
transposeNDWrapper(superPointResults[i].second, {1, 0}, buffer);
101+
std::copy(buffer.begin<float>(), buffer.end<float>(), std::back_inserter(descriptors[i]));
102+
buffer.release();
103+
}
104+
std::vector<Ort::OrtSessionHandler::DataOutputType> superGlueOrtOutput =
105+
superGlueOsh({imageShapes[0].data(), scores[0].data(), keypoints[0].data(), descriptors[0].data(),
106+
imageShapes[1].data(), scores[1].data(), keypoints[1].data(), descriptors[1].data()});
107+
108+
// match keypoints 0 to keypoints 1
109+
std::vector<int64_t> matchIndices(reinterpret_cast<int64_t*>(superGlueOrtOutput[0].first),
110+
reinterpret_cast<int64_t*>(superGlueOrtOutput[0].first) + numKeypoints0);
111+
112+
std::vector<cv::DMatch> goodMatches;
113+
for (std::size_t i = 0; i < matchIndices.size(); ++i) {
114+
if (matchIndices[i] < 0) {
115+
continue;
116+
}
117+
cv::DMatch match;
118+
match.imgIdx = 0;
119+
match.queryIdx = i;
120+
match.trainIdx = matchIndices[i];
121+
goodMatches.emplace_back(match);
122+
}
123+
124+
cv::Mat matchesImage;
125+
cv::drawMatches(images[0], superPointResults[0].first, images[1], superPointResults[1].first, goodMatches,
126+
matchesImage, cv::Scalar::all(-1), cv::Scalar::all(-1), std::vector<char>(),
127+
cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
128+
cv::imwrite("super_point_good_matches.jpg", matchesImage);
129+
cv::imshow("super_point_good_matches", matchesImage);
130+
cv::waitKey();
131+
132+
return EXIT_SUCCESS;
133+
}
134+
135+
namespace
136+
{
137+
void normalizeDescriptors(cv::Mat& descriptors)
138+
{
139+
cv::Mat rsquaredSumMat;
140+
cv::reduce(descriptors.mul(descriptors), rsquaredSumMat, 1, cv::REDUCE_SUM);
141+
cv::sqrt(rsquaredSumMat, rsquaredSumMat);
142+
for (int i = 0; i < descriptors.rows; ++i) {
143+
float rsquaredSum = std::max<float>(rsquaredSumMat.ptr<float>()[i], 1e-12);
144+
descriptors.row(i) /= rsquaredSum;
145+
}
146+
}
147+
148+
KeyPointAndDesc processOneFrameSuperPoint(const Ort::SuperPoint& superPointOsh, const cv::Mat& inputImg, float* dst,
149+
int borderRemove, float confidenceThresh, bool alignCorners, int distThresh)
150+
{
151+
int origW = inputImg.cols, origH = inputImg.rows;
152+
cv::Mat scaledImg;
153+
cv::resize(inputImg, scaledImg, cv::Size(Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H), 0, 0, cv::INTER_CUBIC);
154+
superPointOsh.preprocess(dst, scaledImg.data, Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H,
155+
Ort::SuperPoint::IMG_CHANNEL);
156+
auto inferenceOutput = superPointOsh({dst});
157+
158+
std::vector<cv::KeyPoint> keyPoints = superPointOsh.getKeyPoints(inferenceOutput, borderRemove, confidenceThresh);
159+
160+
std::vector<int> descriptorShape(inferenceOutput[1].second.begin(), inferenceOutput[1].second.end());
161+
cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F,
162+
inferenceOutput[1].first); // 1 x 256 x H/8 x W/8
163+
164+
std::vector<int> keepIndices =
165+
superPointOsh.nmsFast(keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh);
166+
167+
std::vector<cv::KeyPoint> keepKeyPoints;
168+
keepKeyPoints.reserve(keepIndices.size());
169+
std::transform(keepIndices.begin(), keepIndices.end(), std::back_inserter(keepKeyPoints),
170+
[&keyPoints](int idx) { return keyPoints[idx]; });
171+
keyPoints = std::move(keepKeyPoints);
172+
173+
cv::Mat descriptors = superPointOsh.getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H,
174+
Ort::SuperPoint::IMG_W, alignCorners);
175+
176+
for (auto& keyPoint : keyPoints) {
177+
keyPoint.pt.x *= static_cast<float>(origW) / Ort::SuperPoint::IMG_W;
178+
keyPoint.pt.y *= static_cast<float>(origH) / Ort::SuperPoint::IMG_H;
179+
}
180+
181+
return {keyPoints, descriptors};
182+
}
183+
} // namespace

include/ort_utility/OrtSessionHandler.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,8 @@ class OrtSessionHandler
2929
// multiple inputs, multiple outputs
3030
std::vector<DataOutputType> operator()(const std::vector<float*>& inputImgData) const;
3131

32+
void updateInputShapes(const std::vector<std::vector<int64_t>>& inputShapes);
33+
3234
private:
3335
class OrtSessionHandlerIml;
3436
std::unique_ptr<OrtSessionHandlerIml> m_piml;

scripts/superglue/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
```bash
1414
git submodule update --init --recursive
1515

16+
python3 -m pip install -r requirements.txt
1617
python3 -m pip install -r SuperGluePretrainedNetwork/requirements.txt
1718
```
1819

scripts/superglue/convert_to_onnx.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
import torch
55

6+
import onnxruntime
67
from superglue_wrapper import SuperGlueWrapper as SuperGlue
78

89

@@ -25,7 +26,9 @@ def main():
2526
num_keypoints = 382
2627
data = {}
2728
for i in range(2):
28-
data[f"image{i}_shape"] = torch.tensor([batch_size, 1, height, width])
29+
data[f"image{i}_shape"] = torch.tensor(
30+
[batch_size, 1, height, width], dtype=torch.float32
31+
)
2932
data[f"scores{i}"] = torch.randn(batch_size, num_keypoints)
3033
data[f"keypoints{i}"] = torch.randn(batch_size, num_keypoints, 2)
3134
data[f"descriptors{i}"] = torch.randn(batch_size, 256, num_keypoints)
@@ -54,6 +57,15 @@ def main():
5457
)
5558
print(f"\nonnx model is saved to: {os.getcwd()}/super_glue.onnx")
5659

60+
print("\ntest inference using onnxruntime")
61+
sess = onnxruntime.InferenceSession("super_glue.onnx")
62+
for input in sess.get_inputs():
63+
print("input: ", input)
64+
65+
print("\n")
66+
for output in sess.get_outputs():
67+
print("output: ", output)
68+
5769

5870
if __name__ == "__main__":
5971
main()

scripts/superglue/requirements.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
onnxruntime

src/OrtSessionHandler.cpp

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,21 @@ class OrtSessionHandler::OrtSessionHandlerIml
8888

8989
std::vector<DataOutputType> operator()(const std::vector<float*>& inputData) const;
9090

91+
void updateInputShapes(const std::vector<std::vector<int64_t>>& inputShapes)
92+
{
93+
if (inputShapes.size() != m_numInputs) {
94+
DEBUG_LOG("inputShapes must be of size: %d", m_numInputs);
95+
return;
96+
}
97+
m_inputShapes = inputShapes;
98+
99+
for (int i = 0; i < m_numInputs; i++) {
100+
const auto& curInputShape = m_inputShapes[i];
101+
m_inputTensorSizes[i] =
102+
std::accumulate(std::begin(curInputShape), std::end(curInputShape), 1, std::multiplies<int64_t>());
103+
}
104+
}
105+
91106
private:
92107
void initSession();
93108
void initModelInfo();
@@ -272,7 +287,8 @@ std::vector<OrtSessionHandler::DataOutputType>
272287
OrtSessionHandler::OrtSessionHandlerIml::operator()(const std::vector<float*>& inputData) const
273288
{
274289
if (m_numInputs != inputData.size()) {
275-
throw std::runtime_error("Mismatch size of input data\n");
290+
DEBUG_LOG("m_numInputs:%d, input size:%ld", m_numInputs, inputData.size());
291+
throw std::runtime_error("Mismatch size of input data");
276292
}
277293

278294
Ort::MemoryInfo memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
@@ -302,4 +318,9 @@ OrtSessionHandler::OrtSessionHandlerIml::operator()(const std::vector<float*>& i
302318

303319
return outputData;
304320
}
321+
322+
void OrtSessionHandler::updateInputShapes(const std::vector<std::vector<int64_t>>& inputShapes)
323+
{
324+
m_piml->updateInputShapes(inputShapes);
325+
}
305326
} // namespace Ort

0 commit comments

Comments
 (0)