|
| 1 | +/** |
| 2 | + * @file SuperGlueApp.cpp |
| 3 | + * |
| 4 | + * @author btran |
| 5 | + * |
| 6 | + */ |
| 7 | + |
| 8 | +#include "SuperPoint.hpp" |
| 9 | +#include "Utility.hpp" |
| 10 | + |
| 11 | +namespace |
| 12 | +{ |
| 13 | +using KeyPointAndDesc = std::pair<std::vector<cv::KeyPoint>, cv::Mat>; |
| 14 | + |
| 15 | +KeyPointAndDesc processOneFrameSuperPoint(const Ort::SuperPoint& superPointOsh, const cv::Mat& inputImg, float* dst, |
| 16 | + int borderRemove = 4, float confidenceThresh = 0.015, |
| 17 | + bool alignCorners = true, int distThresh = 2); |
| 18 | + |
| 19 | +void normalizeDescriptors(cv::Mat& descriptors); |
| 20 | +} // namespace |
| 21 | + |
| 22 | +int main(int argc, char* argv[]) |
| 23 | +{ |
| 24 | + if (argc != 5) { |
| 25 | + std::cerr |
| 26 | + << "Usage: [apps] [path/to/onnx/super/point] [path/to/onnx/super/glue] [path/to/image1] [path/to/image2]" |
| 27 | + << std::endl; |
| 28 | + return EXIT_FAILURE; |
| 29 | + } |
| 30 | + |
| 31 | + const std::string ONNX_MODEL_PATH = argv[1]; |
| 32 | + const std::string SUPERGLUE_ONNX_MODEL_PATH = argv[2]; |
| 33 | + const std::vector<std::string> IMAGE_PATHS = {argv[3], argv[4]}; |
| 34 | + |
| 35 | + Ort::SuperPoint superPointOsh(ONNX_MODEL_PATH, 0, |
| 36 | + std::vector<std::vector<int64_t>>{{1, Ort::SuperPoint::IMG_CHANNEL, |
| 37 | + Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W}}); |
| 38 | + |
| 39 | + std::vector<cv::Mat> images; |
| 40 | + std::vector<cv::Mat> grays; |
| 41 | + std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(images), |
| 42 | + [](const auto& imagePath) { return cv::imread(imagePath); }); |
| 43 | + for (int i = 0; i < 2; ++i) { |
| 44 | + if (images[i].empty()) { |
| 45 | + throw std::runtime_error("failed to open " + IMAGE_PATHS[i]); |
| 46 | + } |
| 47 | + } |
| 48 | + std::transform(IMAGE_PATHS.begin(), IMAGE_PATHS.end(), std::back_inserter(grays), |
| 49 | + [](const auto& imagePath) { return cv::imread(imagePath, 0); }); |
| 50 | + |
| 51 | + std::vector<float> dst(Ort::SuperPoint::IMG_CHANNEL * Ort::SuperPoint::IMG_H * Ort::SuperPoint::IMG_W); |
| 52 | + |
| 53 | + std::vector<KeyPointAndDesc> superPointResults; |
| 54 | + std::transform(grays.begin(), grays.end(), std::back_inserter(superPointResults), |
| 55 | + [&superPointOsh, &dst](const auto& gray) { |
| 56 | + return processOneFrameSuperPoint(superPointOsh, gray, dst.data()); |
| 57 | + }); |
| 58 | + |
| 59 | + for (auto& curKeyPointAndDesc : superPointResults) { |
| 60 | + normalizeDescriptors(curKeyPointAndDesc.second); |
| 61 | + } |
| 62 | + |
| 63 | + // superglue |
| 64 | + static const int DUMMY_NUM_KEYPOINTS = 256; |
| 65 | + Ort::SuperPoint superGlueOsh(SUPERGLUE_ONNX_MODEL_PATH, 0, |
| 66 | + std::vector<std::vector<int64_t>>{ |
| 67 | + {4}, |
| 68 | + {1, DUMMY_NUM_KEYPOINTS}, |
| 69 | + {1, DUMMY_NUM_KEYPOINTS, 2}, |
| 70 | + {1, 256, DUMMY_NUM_KEYPOINTS}, |
| 71 | + {4}, |
| 72 | + {1, DUMMY_NUM_KEYPOINTS}, |
| 73 | + {1, DUMMY_NUM_KEYPOINTS, 2}, |
| 74 | + {1, 256, DUMMY_NUM_KEYPOINTS}, |
| 75 | + }); |
| 76 | + |
| 77 | + int numKeypoints0 = superPointResults[0].first.size(); |
| 78 | + int numKeypoints1 = superPointResults[1].first.size(); |
| 79 | + std::vector<std::vector<int64_t>> inputShapes = { |
| 80 | + {4}, {1, numKeypoints0}, {1, numKeypoints0, 2}, {1, 256, numKeypoints0}, |
| 81 | + {4}, {1, numKeypoints1}, {1, numKeypoints1, 2}, {1, 256, numKeypoints1}, |
| 82 | + }; |
| 83 | + superGlueOsh.updateInputShapes(inputShapes); |
| 84 | + |
| 85 | + std::vector<std::vector<float>> imageShapes(2); |
| 86 | + std::vector<std::vector<float>> scores(2); |
| 87 | + std::vector<std::vector<float>> keypoints(2); |
| 88 | + std::vector<std::vector<float>> descriptors(2); |
| 89 | + |
| 90 | + cv::Mat buffer; |
| 91 | + for (int i = 0; i < 2; ++i) { |
| 92 | + imageShapes[i] = {1, 1, static_cast<float>(images[0].rows), static_cast<float>(images[0].cols)}; |
| 93 | + std::transform(superPointResults[i].first.begin(), superPointResults[i].first.end(), |
| 94 | + std::back_inserter(scores[i]), [](const cv::KeyPoint& keypoint) { return keypoint.response; }); |
| 95 | + for (const auto& k : superPointResults[i].first) { |
| 96 | + keypoints[i].emplace_back(k.pt.y); |
| 97 | + keypoints[i].emplace_back(k.pt.x); |
| 98 | + } |
| 99 | + |
| 100 | + transposeNDWrapper(superPointResults[i].second, {1, 0}, buffer); |
| 101 | + std::copy(buffer.begin<float>(), buffer.end<float>(), std::back_inserter(descriptors[i])); |
| 102 | + buffer.release(); |
| 103 | + } |
| 104 | + std::vector<Ort::OrtSessionHandler::DataOutputType> superGlueOrtOutput = |
| 105 | + superGlueOsh({imageShapes[0].data(), scores[0].data(), keypoints[0].data(), descriptors[0].data(), |
| 106 | + imageShapes[1].data(), scores[1].data(), keypoints[1].data(), descriptors[1].data()}); |
| 107 | + |
| 108 | + // match keypoints 0 to keypoints 1 |
| 109 | + std::vector<int64_t> matchIndices(reinterpret_cast<int64_t*>(superGlueOrtOutput[0].first), |
| 110 | + reinterpret_cast<int64_t*>(superGlueOrtOutput[0].first) + numKeypoints0); |
| 111 | + |
| 112 | + std::vector<cv::DMatch> goodMatches; |
| 113 | + for (std::size_t i = 0; i < matchIndices.size(); ++i) { |
| 114 | + if (matchIndices[i] < 0) { |
| 115 | + continue; |
| 116 | + } |
| 117 | + cv::DMatch match; |
| 118 | + match.imgIdx = 0; |
| 119 | + match.queryIdx = i; |
| 120 | + match.trainIdx = matchIndices[i]; |
| 121 | + goodMatches.emplace_back(match); |
| 122 | + } |
| 123 | + |
| 124 | + cv::Mat matchesImage; |
| 125 | + cv::drawMatches(images[0], superPointResults[0].first, images[1], superPointResults[1].first, goodMatches, |
| 126 | + matchesImage, cv::Scalar::all(-1), cv::Scalar::all(-1), std::vector<char>(), |
| 127 | + cv::DrawMatchesFlags::NOT_DRAW_SINGLE_POINTS); |
| 128 | + cv::imwrite("super_point_good_matches.jpg", matchesImage); |
| 129 | + cv::imshow("super_point_good_matches", matchesImage); |
| 130 | + cv::waitKey(); |
| 131 | + |
| 132 | + return EXIT_SUCCESS; |
| 133 | +} |
| 134 | + |
| 135 | +namespace |
| 136 | +{ |
| 137 | +void normalizeDescriptors(cv::Mat& descriptors) |
| 138 | +{ |
| 139 | + cv::Mat rsquaredSumMat; |
| 140 | + cv::reduce(descriptors.mul(descriptors), rsquaredSumMat, 1, cv::REDUCE_SUM); |
| 141 | + cv::sqrt(rsquaredSumMat, rsquaredSumMat); |
| 142 | + for (int i = 0; i < descriptors.rows; ++i) { |
| 143 | + float rsquaredSum = std::max<float>(rsquaredSumMat.ptr<float>()[i], 1e-12); |
| 144 | + descriptors.row(i) /= rsquaredSum; |
| 145 | + } |
| 146 | +} |
| 147 | + |
| 148 | +KeyPointAndDesc processOneFrameSuperPoint(const Ort::SuperPoint& superPointOsh, const cv::Mat& inputImg, float* dst, |
| 149 | + int borderRemove, float confidenceThresh, bool alignCorners, int distThresh) |
| 150 | +{ |
| 151 | + int origW = inputImg.cols, origH = inputImg.rows; |
| 152 | + cv::Mat scaledImg; |
| 153 | + cv::resize(inputImg, scaledImg, cv::Size(Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H), 0, 0, cv::INTER_CUBIC); |
| 154 | + superPointOsh.preprocess(dst, scaledImg.data, Ort::SuperPoint::IMG_W, Ort::SuperPoint::IMG_H, |
| 155 | + Ort::SuperPoint::IMG_CHANNEL); |
| 156 | + auto inferenceOutput = superPointOsh({dst}); |
| 157 | + |
| 158 | + std::vector<cv::KeyPoint> keyPoints = superPointOsh.getKeyPoints(inferenceOutput, borderRemove, confidenceThresh); |
| 159 | + |
| 160 | + std::vector<int> descriptorShape(inferenceOutput[1].second.begin(), inferenceOutput[1].second.end()); |
| 161 | + cv::Mat coarseDescriptorMat(descriptorShape.size(), descriptorShape.data(), CV_32F, |
| 162 | + inferenceOutput[1].first); // 1 x 256 x H/8 x W/8 |
| 163 | + |
| 164 | + std::vector<int> keepIndices = |
| 165 | + superPointOsh.nmsFast(keyPoints, Ort::SuperPoint::IMG_H, Ort::SuperPoint::IMG_W, distThresh); |
| 166 | + |
| 167 | + std::vector<cv::KeyPoint> keepKeyPoints; |
| 168 | + keepKeyPoints.reserve(keepIndices.size()); |
| 169 | + std::transform(keepIndices.begin(), keepIndices.end(), std::back_inserter(keepKeyPoints), |
| 170 | + [&keyPoints](int idx) { return keyPoints[idx]; }); |
| 171 | + keyPoints = std::move(keepKeyPoints); |
| 172 | + |
| 173 | + cv::Mat descriptors = superPointOsh.getDescriptors(coarseDescriptorMat, keyPoints, Ort::SuperPoint::IMG_H, |
| 174 | + Ort::SuperPoint::IMG_W, alignCorners); |
| 175 | + |
| 176 | + for (auto& keyPoint : keyPoints) { |
| 177 | + keyPoint.pt.x *= static_cast<float>(origW) / Ort::SuperPoint::IMG_W; |
| 178 | + keyPoint.pt.y *= static_cast<float>(origH) / Ort::SuperPoint::IMG_H; |
| 179 | + } |
| 180 | + |
| 181 | + return {keyPoints, descriptors}; |
| 182 | +} |
| 183 | +} // namespace |
0 commit comments