|
| 1 | +/** |
| 2 | + * @file LoFTRApp.cpp |
| 3 | + * |
| 4 | + * @author btran |
| 5 | + * |
| 6 | + */ |
| 7 | + |
| 8 | +#include "LoFTR.hpp" |
| 9 | +#include "Utility.hpp" |
| 10 | +#include <utility> |
| 11 | + |
| 12 | +static constexpr float CONFIDENCE_THRESHOLD = 0.1; |
| 13 | + |
| 14 | +namespace |
| 15 | +{ |
| 16 | +std::pair<std::vector<cv::KeyPoint>, std::vector<cv::KeyPoint>> |
| 17 | +processOneImagePair(const Ort::LoFTR& loftrOsh, const cv::Mat& queryImg, const cv::Mat& refImg, float* queryData, |
| 18 | + float* refData, float confidenceThresh = CONFIDENCE_THRESHOLD); |
| 19 | +} // namespace |
| 20 | + |
| 21 | +int main(int argc, char* argv[]) |
| 22 | +{ |
| 23 | + if (argc != 4) { |
| 24 | + std::cerr << "Usage: [apps] [path/to/onnx/loftr] [path/to/image1] [path/to/image2]" << std::endl; |
| 25 | + return EXIT_FAILURE; |
| 26 | + } |
| 27 | + |
| 28 | + const std::string ONNX_MODEL_PATH = argv[1]; |
| 29 | + const std::vector<std::string> IMAGE_PATHS = {argv[2], argv[3]}; |
| 30 | + |
| 31 | + std::vector<cv::Mat> images; |
| 32 | + std::vector<cv::Mat> grays; |
| 33 | + std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images), |
| 34 | + [](const auto& imagePath) { return cv::imread(imagePath); }); |
| 35 | + for (int i = 0; i < 2; ++i) { |
| 36 | + if (images[i].empty()) { |
| 37 | + throw std::runtime_error("failed to open " + IMAGE_PATHS[i]); |
| 38 | + } |
| 39 | + } |
| 40 | + |
| 41 | + std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays), |
| 42 | + [](const auto& imagePath) { return cv::imread(imagePath, 0); }); |
| 43 | + |
| 44 | + std::vector<float> queryData(Ort::LoFTR::IMG_CHANNEL * Ort::LoFTR::IMG_H * Ort::LoFTR::IMG_W); |
| 45 | + std::vector<float> refData(Ort::LoFTR::IMG_CHANNEL * Ort::LoFTR::IMG_H * Ort::LoFTR::IMG_W); |
| 46 | + |
| 47 | + Ort::LoFTR osh( |
| 48 | + ONNX_MODEL_PATH, 0, |
| 49 | + std::vector<std::vector<int64_t>>{{1, Ort::LoFTR::IMG_CHANNEL, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_W}, |
| 50 | + {1, Ort::LoFTR::IMG_CHANNEL, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_W}}); |
| 51 | + |
| 52 | + auto matchedKpts = processOneImagePair(osh, grays[0], grays[1], queryData.data(), refData.data()); |
| 53 | + const std::vector<cv::KeyPoint>& queryKpts = matchedKpts.first; |
| 54 | + const std::vector<cv::KeyPoint>& refKpts = matchedKpts.second; |
| 55 | + std::vector<cv::DMatch> matches; |
| 56 | + for (int i = 0; i < queryKpts.size(); ++i) { |
| 57 | + cv::DMatch match; |
| 58 | + match.imgIdx = 0; |
| 59 | + match.queryIdx = i; |
| 60 | + match.trainIdx = i; |
| 61 | + matches.emplace_back(std::move(match)); |
| 62 | + } |
| 63 | + cv::Mat matchesImage; |
| 64 | + cv::drawMatches(images[0], queryKpts, images[1], refKpts, matches, matchesImage, cv::Scalar::all(-1), |
| 65 | + cv::Scalar::all(-1), std::vector<char>(), cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS); |
| 66 | + cv::imwrite("loftr.jpg", matchesImage); |
| 67 | + cv::imshow("loftr", matchesImage); |
| 68 | + cv::waitKey(); |
| 69 | + |
| 70 | + return EXIT_SUCCESS; |
| 71 | +} |
| 72 | + |
| 73 | +namespace |
| 74 | +{ |
| 75 | +std::pair<std::vector<cv::KeyPoint>, std::vector<cv::KeyPoint>> |
| 76 | +processOneImagePair(const Ort::LoFTR& loftrOsh, const cv::Mat& queryImg, const cv::Mat& refImg, float* queryData, |
| 77 | + float* refData, float confidenceThresh) |
| 78 | +{ |
| 79 | + int origQueryW = queryImg.cols, origQueryH = queryImg.rows; |
| 80 | + int origRefW = refImg.cols, origRefH = refImg.rows; |
| 81 | + |
| 82 | + cv::Mat scaledQueryImg, scaledRefImg; |
| 83 | + cv::resize(queryImg, scaledQueryImg, cv::Size(Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H), 0, 0, cv::INTER_CUBIC); |
| 84 | + cv::resize(refImg, scaledRefImg, cv::Size(Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H), 0, 0, cv::INTER_CUBIC); |
| 85 | + |
| 86 | + loftrOsh.preprocess(queryData, scaledQueryImg.data, Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_CHANNEL); |
| 87 | + loftrOsh.preprocess(refData, scaledRefImg.data, Ort::LoFTR::IMG_W, Ort::LoFTR::IMG_H, Ort::LoFTR::IMG_CHANNEL); |
| 88 | + auto inferenceOutput = loftrOsh({queryData, refData}); |
| 89 | + |
| 90 | + // inferenceOutput[0].second: keypoints0 of shape [num kpt x 2] |
| 91 | + // inferenceOutput[1].second: keypoints1 of shape [num kpt x 2] |
| 92 | + // inferenceOutput[2].second: confidences of shape [num kpt] |
| 93 | + |
| 94 | + int numKeyPoints = inferenceOutput[2].second[0]; |
| 95 | + std::vector<cv::KeyPoint> queryKpts, refKpts; |
| 96 | + queryKpts.reserve(numKeyPoints); |
| 97 | + refKpts.reserve(numKeyPoints); |
| 98 | + |
| 99 | + for (int i = 0; i < numKeyPoints; ++i) { |
| 100 | + float confidence = inferenceOutput[2].first[i]; |
| 101 | + if (confidence < confidenceThresh) { |
| 102 | + continue; |
| 103 | + } |
| 104 | + float queryX = inferenceOutput[0].first[i * 2 + 0]; |
| 105 | + float queryY = inferenceOutput[0].first[i * 2 + 1]; |
| 106 | + float refX = inferenceOutput[1].first[i * 2 + 0]; |
| 107 | + float refY = inferenceOutput[1].first[i * 2 + 1]; |
| 108 | + cv::KeyPoint queryKpt, refKpt; |
| 109 | + queryKpt.pt.x = queryX * origQueryW / Ort::LoFTR::IMG_W; |
| 110 | + queryKpt.pt.y = queryY * origQueryH / Ort::LoFTR::IMG_H; |
| 111 | + |
| 112 | + refKpt.pt.x = refX * origRefW / Ort::LoFTR::IMG_W; |
| 113 | + refKpt.pt.y = refY * origRefH / Ort::LoFTR::IMG_H; |
| 114 | + |
| 115 | + queryKpts.emplace_back(std::move(queryKpt)); |
| 116 | + refKpts.emplace_back(std::move(refKpt)); |
| 117 | + } |
| 118 | + |
| 119 | + return std::make_pair(queryKpts, refKpts); |
| 120 | +} |
| 121 | +} // namespace |
0 commit comments