Skip to content

Commit 861b44f

Browse files
committed
LoFTR App
1 parent 740afc2 commit 861b44f

File tree

4 files changed

+192
-0
lines changed

4 files changed

+192
-0
lines changed

examples/CMakeLists.txt

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,3 +177,21 @@ target_include_directories(super_glue
177177
PUBLIC
178178
${OpenCV_INCLUDE_DIRS}
179179
)
180+
181+
# ---------------------------------------------------------
182+
183+
add_executable(loftr
184+
${CMAKE_CURRENT_LIST_DIR}/LoFTR.cpp
185+
${CMAKE_CURRENT_LIST_DIR}/LoFTRApp.cpp
186+
)
187+
188+
target_link_libraries(loftr
189+
PUBLIC
190+
${PROJECT_NAME}
191+
${OpenCV_LIBS}
192+
)
193+
194+
target_include_directories(loftr
195+
PUBLIC
196+
${OpenCV_INCLUDE_DIRS}
197+
)

examples/LoFTR.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
/**
2+
* @file LoFTR.cpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#include "LoFTR.hpp"
9+
10+
namespace Ort
11+
{
12+
void LoFTR::preprocess(float* dst, const unsigned char* src, const int64_t targetImgWidth,
13+
const int64_t targetImgHeight, const int numChannels) const
14+
{
15+
for (int i = 0; i < targetImgHeight; ++i) {
16+
for (int j = 0; j < targetImgWidth; ++j) {
17+
for (int c = 0; c < numChannels; ++c) {
18+
dst[c * targetImgHeight * targetImgWidth + i * targetImgWidth + j] =
19+
(src[i * targetImgWidth * numChannels + j * numChannels + c] / 255.0);
20+
}
21+
}
22+
}
23+
}
24+
} // namespace Ort

examples/LoFTR.hpp

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
/**
2+
* @file LoFTR.hpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#pragma once
9+
10+
#include <ort_utility/ort_utility.hpp>
11+
12+
namespace Ort
13+
{
14+
class LoFTR : public OrtSessionHandler
15+
{
16+
public:
17+
static constexpr int64_t IMG_H = 480;
18+
static constexpr int64_t IMG_W = 640;
19+
static constexpr int64_t IMG_CHANNEL = 1;
20+
21+
using OrtSessionHandler::OrtSessionHandler;
22+
23+
void preprocess(float* dst, //
24+
const unsigned char* src, //
25+
const int64_t targetImgWidth, //
26+
const int64_t targetImgHeight, //
27+
const int numChannels) const;
28+
};
29+
} // namespace Ort

examples/LoFTRApp.cpp

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
/**
2+
* @file LoFTRApp.cpp
3+
*
4+
* @author btran
5+
*
6+
*/
7+
8+
#include "LoFTR.hpp"
9+
#include "Utility.hpp"
10+
#include <utility>
11+
12+
static constexpr float CONFIDENCE_THRESHOLD = 0.1;
13+
14+
namespace
15+
{
16+
std::pair<std::vector<cv::KeyPoint>, std::vector<cv::KeyPoint>>
17+
processOneImagePair(const Ort::LoFTR& loftrOsh, const cv::Mat& queryImg, const cv::Mat& refImg, float* queryData,
18+
float* refData, float confidenceThresh = CONFIDENCE_THRESHOLD);
19+
} // namespace
20+
21+
int main(int argc, char* argv[])
22+
{
23+
if (argc != 4) {
24+
std::cerr << "Usage: [apps] [path/to/onnx/loftr] [path/to/image1] [path/to/image2]" << std::endl;
25+
return EXIT_FAILURE;
26+
}
27+
28+
const std::string ONNX_MODEL_PATH = argv[1];
29+
const std::vector<std::string> IMAGE_PATHS = {argv[2], argv[3]};
30+
31+
std::vector<cv::Mat> images;
32+
std::vector<cv::Mat> grays;
33+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images),
34+
[](const auto& imagePath) { return cv::imread(imagePath); });
35+
for (int i = 0; i < 2; ++i) {
36+
if (images[i].empty()) {
37+
throw std::runtime_error("failed to open " + IMAGE_PATHS[i]);
38+
}
39+
}
40+
41+
std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays),
42+
[](const auto& imagePath) { return cv::imread(imagePath, 0); });
43+
44+
std::vector<float> queryData(Ort::LoFTR::IMG_CHANNEL * Ort::LoFTR::IMG_H * Ort::LoFTR::IMG_W);
45+
std::vector<float> refData(Ort::LoFTR::IMG_CHANNEL * Ort::LoFTR::IMG_H * Ort::LoFTR::IMG_W);
46+
47+
Ort::LoFTR osh(
48+
ONNX_MODEL_PATH, 0,
49+
std::vector<std::vector<int64_t>>{{1, Ort::LoFTR::IMG_CHANNEL, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_W},
50+
{1, Ort::LoFTR::IMG_CHANNEL, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_W}});
51+
52+
auto matchedKpts = processOneImagePair(osh, grays[0], grays[1], queryData.data(), refData.data());
53+
const std::vector<cv::KeyPoint>& queryKpts = matchedKpts.first;
54+
const std::vector<cv::KeyPoint>& refKpts = matchedKpts.second;
55+
std::vector<cv::DMatch> matches;
56+
for (int i = 0; i < queryKpts.size(); ++i) {
57+
cv::DMatch match;
58+
match.imgIdx = 0;
59+
match.queryIdx = i;
60+
match.trainIdx = i;
61+
matches.emplace_back(std::move(match));
62+
}
63+
cv::Mat matchesImage;
64+
cv::drawMatches(images[0], queryKpts, images[1], refKpts, matches, matchesImage, cv::Scalar::all(-1),
65+
cv::Scalar::all(-1), std::vector<char>(), cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS);
66+
cv::imwrite("loftr.jpg", matchesImage);
67+
cv::imshow("loftr", matchesImage);
68+
cv::waitKey();
69+
70+
return EXIT_SUCCESS;
71+
}
72+
73+
namespace
74+
{
75+
std::pair<std::vector<cv::KeyPoint>, std::vector<cv::KeyPoint>>
76+
processOneImagePair(const Ort::LoFTR& loftrOsh, const cv::Mat& queryImg, const cv::Mat& refImg, float* queryData,
77+
float* refData, float confidenceThresh)
78+
{
79+
int origQueryW = queryImg.cols, origQueryH = queryImg.rows;
80+
int origRefW = refImg.cols, origRefH = refImg.rows;
81+
82+
cv::Mat scaledQueryImg, scaledRefImg;
83+
cv::resize(queryImg, scaledQueryImg, cv::Size(Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H), 0, 0, cv::INTER_CUBIC);
84+
cv::resize(refImg, scaledRefImg, cv::Size(Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H), 0, 0, cv::INTER_CUBIC);
85+
86+
loftrOsh.preprocess(queryData, scaledQueryImg.data, Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_CHANNEL);
87+
loftrOsh.preprocess(refData, scaledRefImg.data, Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_CHANNEL);
88+
auto inferenceOutput = loftrOsh({queryData, refData});
89+
90+
// inferenceOutput[0].second: keypoints0 of shape [num kpt x 2]
91+
// inferenceOutput[1].second: keypoints1 of shape [num kpt x 2]
92+
// inferenceOutput[2].second: confidences of shape [num kpt]
93+
94+
int numKeyPoints = inferenceOutput[2].second[0];
95+
std::vector<cv::KeyPoint> queryKpts, refKpts;
96+
queryKpts.reserve(numKeyPoints);
97+
refKpts.reserve(numKeyPoints);
98+
99+
for (int i = 0; i < numKeyPoints; ++i) {
100+
float confidence = inferenceOutput[2].first[i];
101+
if (confidence < confidenceThresh) {
102+
continue;
103+
}
104+
float queryX = inferenceOutput[0].first[i * 2 + 0];
105+
float queryY = inferenceOutput[0].first[i * 2 + 1];
106+
float refX = inferenceOutput[1].first[i * 2 + 0];
107+
float refY = inferenceOutput[1].first[i * 2 + 1];
108+
cv::KeyPoint queryKpt, refKpt;
109+
queryKpt.pt.x = queryX * origQueryW / Ort::LoFTR::IMG_W;
110+
queryKpt.pt.y = queryY * origQueryH / Ort::LoFTR::IMG_H;
111+
112+
refKpt.pt.x = refX * origRefW / Ort::LoFTR::IMG_W;
113+
refKpt.pt.y = refY * origRefH / Ort::LoFTR::IMG_H;
114+
115+
queryKpts.emplace_back(std::move(queryKpt));
116+
refKpts.emplace_back(std::move(refKpt));
117+
}
118+
119+
return std::make_pair(queryKpts, refKpts);
120+
}
121+
} // namespace

0 commit comments

Comments
 (0)