Skip to content

Commit bee019d

Browse files
authored
Merge pull request #6 from dbgroup-nagoya-u/add-knn
Add knn
2 parents d01e4bd + 76550c5 commit bee019d

File tree

3 files changed

+92
-1
lines changed

3 files changed

+92
-1
lines changed

.gitignore

+5-1
Original file line numberDiff line numberDiff line change
@@ -152,4 +152,8 @@ Network Trash Folder
152152
Temporary Items
153153
.apdisk
154154

155-
# End of https://www.gitignore.io/api/c++,clion,macos
155+
# End of https://www.gitignore.io/api/c++,clion,macos
156+
157+
.vscode
158+
dataset
159+
build

examples/multidimensional.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -33,5 +33,11 @@ int main() {
3333
auto count = std::distance(pgm_3d.range({0, 0, 0}, {5, 10, 15}), pgm_3d.end());
3434
std::cout << "points in range({0,0,0}, {5,10,15}) = " << count << std::endl;
3535

36+
auto knn = pgm_3d.knn({4, 1, 2}, 5);
37+
std::cout << "5 nearest points from {4,1,2} = ";
38+
for (auto point : knn){
39+
std::cout << "(" << std::get<0>(point) << "," << std::get<1>(point) << "," << std::get<2>(point) << ") ";
40+
}
41+
std::cout << std::endl;
3642
return 0;
3743
}

include/pgm/pgm_index_variants.hpp

+81
Original file line numberDiff line numberDiff line change
@@ -964,6 +964,87 @@ class MultidimensionalPGMIndex {
964964
* @return an iterator pointing to an element inside the query hyperrectangle
965965
*/
966966
iterator range(const value_type &min, const value_type &max) { return iterator(this, min, max); }
967+
968+
/**
969+
* (approximate) k-nearest neighbor query.
970+
* Returns @p k nearest points from query point @p p.
971+
*
972+
* @param p the query point.
973+
* @param k the number of nearest points.
974+
* @return a vector of k nearest points.
975+
*/
976+
std::vector<value_type> knn(const value_type &p, uint32_t k){
977+
// to access coordinate of point dynamically
978+
using swallow = int[];
979+
auto sequence = std::make_index_sequence<Dimensions>{};
980+
981+
// return euclidean distance between given point and query point p
982+
auto dist_from_p = [&]<std::size_t... indices>(value_type point, std::index_sequence<indices...>){
983+
uint64_t squared_sum = 0;
984+
squared_sum = (std::pow((int64_t)std::get<indices>(point) - (int64_t)std::get<indices>(p), 2) + ... );
985+
return std::sqrt(squared_sum);
986+
};
987+
988+
// return first point of range query for knn query
989+
auto k_range_first = [&]<std::size_t... indices>(uint64_t dist, std::index_sequence<indices...>) -> value_type{
990+
value_type point;
991+
swallow{
992+
(std::get<indices>(point) = std::max<int64_t>((int64_t)std::get<indices>(p) - dist, 0), 0)...
993+
};
994+
return point;
995+
};
996+
997+
// return end point of range query for knn query
998+
auto k_range_end = [&]<std::size_t... indices>(uint64_t dist, std::index_sequence<indices...>) -> value_type{
999+
value_type point;
1000+
swallow{
1001+
(std::get<indices>(point) = std::min(std::get<indices>(p) + dist, this->data.size()), 0)...
1002+
};
1003+
return point;
1004+
};
1005+
1006+
// for debug
1007+
auto print_point = [&]<std::size_t... indices>(value_type point, std::index_sequence<indices...>){
1008+
std::cout << "(";
1009+
swallow{
1010+
(std::cout << std::get<indices>(point) << " , ", 0)...
1011+
};
1012+
std::cout << ")";
1013+
};
1014+
1015+
// get 2k points around zp to make temporary answer
1016+
auto zp = encode(p);
1017+
auto range = pgm.search(zp);
1018+
auto it = std::lower_bound(data.begin() + range.lo, data.begin() + range.hi, zp);
1019+
1020+
std::vector<value_type> tmp_ans;
1021+
for (auto i = it - k >= data.begin() ? it - k : data.begin(); i != it + k && i != data.end(); ++i)
1022+
tmp_ans.push_back(morton::Decode(*i));
1023+
1024+
std::sort(tmp_ans.begin(), tmp_ans.end(), [&](auto const& lhs, auto const& rhs) {
1025+
double dist_l = dist_from_p(lhs, sequence);
1026+
double dist_r = dist_from_p(rhs, sequence);
1027+
return dist_l < dist_r;
1028+
});
1029+
1030+
// calc distance of k nearest point in tmp_ans, and get a range(hyperrectangle) that contains more than k points around p
1031+
uint64_t k_range_dist = dist_from_p(tmp_ans[k - 1], sequence) + 1;
1032+
value_type first = k_range_first(k_range_dist, sequence);
1033+
value_type end = k_range_end(k_range_dist, sequence);
1034+
1035+
// execute range query and get k nearest points
1036+
std::vector<value_type> ans;
1037+
for (auto it = this->range(first, end); it != this->end(); ++it)
1038+
ans.push_back(*it);
1039+
1040+
std::sort(ans.begin(), ans.end(), [&](auto const& lhs, auto const& rhs) {
1041+
double dist_l = dist_from_p(lhs, sequence);
1042+
double dist_r = dist_from_p(rhs, sequence);
1043+
return dist_l < dist_r;
1044+
});
1045+
1046+
return std::vector<value_type> {ans.begin(), ans.begin() + k};
1047+
}
9671048

9681049
private:
9691050

0 commit comments

Comments
 (0)