@@ -964,6 +964,87 @@ class MultidimensionalPGMIndex {
964
964
* @return an iterator pointing to an element inside the query hyperrectangle
965
965
*/
966
966
// Forwards this index together with the two query bounds to the iterator constructor.
iterator range (const value_type &min, const value_type &max) { return iterator (this , min, max); }
967
+
968
+ /* *
969
+ * (approximate) k-nearest neighbor query.
970
+ * Returns @p k nearest points from query point @p p.
971
+ *
972
+ * @param p the query point.
973
+ * @param k the number of nearest points.
974
+ * @return a vector of k nearest points.
975
+ */
976
+ std::vector<value_type> knn (const value_type &p, uint32_t k){
977
+ // to access coordinate of point dynamically
978
+ using swallow = int [];
979
+ auto sequence = std::make_index_sequence<Dimensions>{};
980
+
981
+ // return euclidean distance between given point and query point p
982
+ auto dist_from_p = [&]<std::size_t ... indices>(value_type point, std::index_sequence<indices...>){
983
+ uint64_t squared_sum = 0 ;
984
+ squared_sum = (std::pow ((int64_t )std::get<indices>(point) - (int64_t )std::get<indices>(p), 2 ) + ... );
985
+ return std::sqrt (squared_sum);
986
+ };
987
+
988
+ // return first point of range query for knn query
989
+ auto k_range_first = [&]<std::size_t ... indices>(uint64_t dist, std::index_sequence<indices...>) -> value_type{
990
+ value_type point;
991
+ swallow{
992
+ (std::get<indices>(point) = std::max<int64_t >((int64_t )std::get<indices>(p) - dist, 0 ), 0 )...
993
+ };
994
+ return point;
995
+ };
996
+
997
+ // return end point of range query for knn query
998
+ auto k_range_end = [&]<std::size_t ... indices>(uint64_t dist, std::index_sequence<indices...>) -> value_type{
999
+ value_type point;
1000
+ swallow{
1001
+ (std::get<indices>(point) = std::min (std::get<indices>(p) + dist, this ->data .size ()), 0 )...
1002
+ };
1003
+ return point;
1004
+ };
1005
+
1006
+ // for debug
1007
+ auto print_point = [&]<std::size_t ... indices>(value_type point, std::index_sequence<indices...>){
1008
+ std::cout << " (" ;
1009
+ swallow{
1010
+ (std::cout << std::get<indices>(point) << " , " , 0 )...
1011
+ };
1012
+ std::cout << " )" ;
1013
+ };
1014
+
1015
+ // get 2k points around zp to make temporary answer
1016
+ auto zp = encode (p);
1017
+ auto range = pgm.search (zp);
1018
+ auto it = std::lower_bound (data.begin () + range.lo , data.begin () + range.hi , zp);
1019
+
1020
+ std::vector<value_type> tmp_ans;
1021
+ for (auto i = it - k >= data.begin () ? it - k : data.begin (); i != it + k && i != data.end (); ++i)
1022
+ tmp_ans.push_back (morton::Decode (*i));
1023
+
1024
+ std::sort (tmp_ans.begin (), tmp_ans.end (), [&](auto const & lhs, auto const & rhs) {
1025
+ double dist_l = dist_from_p (lhs, sequence);
1026
+ double dist_r = dist_from_p (rhs, sequence);
1027
+ return dist_l < dist_r;
1028
+ });
1029
+
1030
+ // calc distance of k nearest point in tmp_ans, and get a range(hyperrectangle) that contains more than k points around p
1031
+ uint64_t k_range_dist = dist_from_p (tmp_ans[k - 1 ], sequence) + 1 ;
1032
+ value_type first = k_range_first (k_range_dist, sequence);
1033
+ value_type end = k_range_end (k_range_dist, sequence);
1034
+
1035
+ // execute range query and get k nearest points
1036
+ std::vector<value_type> ans;
1037
+ for (auto it = this ->range (first, end); it != this ->end (); ++it)
1038
+ ans.push_back (*it);
1039
+
1040
+ std::sort (ans.begin (), ans.end (), [&](auto const & lhs, auto const & rhs) {
1041
+ double dist_l = dist_from_p (lhs, sequence);
1042
+ double dist_r = dist_from_p (rhs, sequence);
1043
+ return dist_l < dist_r;
1044
+ });
1045
+
1046
+ return std::vector<value_type> {ans.begin (), ans.begin () + k};
1047
+ }
967
1048
968
1049
private:
969
1050
0 commit comments