Skip to content

Commit

Permalink
Made some changes to support a configurable number of local start points (num_local_start)
Browse files Browse the repository at this point in the history
  • Loading branch information
rakri committed Oct 24, 2024
1 parent 1a89fc0 commit ca2fb3b
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 19 deletions.
9 changes: 8 additions & 1 deletion apps/search_memory_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -447,7 +447,7 @@ int main(int argc, char **argv)
{
std::string data_type, dist_fn, index_path_prefix, result_path, query_file, gt_file, filter_label, label_type,
query_filters_file;
uint32_t num_threads, K, filter_penalty_threshold, bruteforce_threshold, clustering_threshold, L_for_print;
uint32_t num_threads, K, filter_penalty_threshold, bruteforce_threshold, clustering_threshold, L_for_print, num_local;
std::vector<uint32_t> Lvec;
bool print_all_recalls, dynamic, tags, show_qps_per_thread, global_start;
float fail_if_recall_below = 0.0f;
Expand Down Expand Up @@ -499,6 +499,10 @@ int main(int argc, char **argv)
optional_configs.add_options()("use_global_start",
po::value<bool>(&global_start)->default_value(false),
"Whether or not to use global start or predicate-aware starting point in graph search");
optional_configs.add_options()("num_local_start",
po::value<uint32_t>(&num_local)->default_value(0),
"How many local start points to use");


optional_configs.add_options()("label_type", po::value<std::string>(&label_type)->default_value("uint"),
program_options_utils::LABEL_TYPE_DESCRIPTION);
Expand Down Expand Up @@ -607,6 +611,9 @@ int main(int argc, char **argv)
}

use_global_start = global_start;
num_start_points = num_local;

std::cout<<"Num local start points: " << num_start_points << std::endl;

try
{
Expand Down
3 changes: 2 additions & 1 deletion include/index.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ inline int64_t curr_query = -1;
inline uint32_t penalty_scale = 10;
inline uint32_t num_sp = 2;
inline bool use_global_start = false;
inline uint32_t num_start_points = 1;

namespace diskann
{
Expand Down Expand Up @@ -277,7 +278,7 @@ template <typename T, typename TagT = uint32_t, typename LabelT = uint32_t> clas

std::vector<std::pair<LabelT, uint32_t>> sort_filter_counts(const std::vector<LabelT> &filter_label);

std::pair<uint32_t, uint32_t> sample_intersection(roaring::Roaring &intersection_bitmap,
std::pair<uint32_t, std::vector<uint32_t>> sample_intersection(roaring::Roaring &intersection_bitmap,
const std::vector<LabelT> &filter_label);

std::unordered_map<std::string, LabelT> load_label_map(const std::string &map_file);
Expand Down
39 changes: 22 additions & 17 deletions src/index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2541,7 +2541,7 @@ std::vector<std::pair<LabelT, uint32_t>> Index<T, TagT, LabelT>::sort_filter_cou
}

template <typename T, typename TagT, typename LabelT>
std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::sample_intersection(roaring::Roaring &intersection_bitmap,
std::pair<uint32_t, std::vector<uint32_t>> Index<T, TagT, LabelT>::sample_intersection(roaring::Roaring &intersection_bitmap,
const std::vector<LabelT> &filter_label)
{
intersection_bitmap = _labels_to_points_sample[filter_label[0]];
Expand All @@ -2551,12 +2551,16 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::sample_intersection(roarin
}
uint32_t val = std::numeric_limits<uint32_t>::max();
auto x = intersection_bitmap.begin();
if (x != intersection_bitmap.end())
std::vector<uint32_t> results;
results.reserve(num_start_points);
while (x != intersection_bitmap.end() && results.size() < num_start_points)
{
val = _sample_map[*x];
results.emplace_back(val);
x++;
}
// std::cout<<intersection_bitmap.cardinality() << " " << val << std::endl;
return std::make_pair((uint32_t)(intersection_bitmap.cardinality() * (1.0 / (_sample_prob))), val);
return std::make_pair((uint32_t)(intersection_bitmap.cardinality() * (1.0 / (_sample_prob))), results);
}

template <typename T, typename TagT, typename LabelT>
Expand Down Expand Up @@ -2666,17 +2670,18 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
case 2:
num_graphs++;
auto [inter_estim, cand] = sample_intersection(scratch->get_valid_bitmap(), filter_label);
if (!use_global_start) {
if (cand < std::numeric_limits<uint32_t>::max())

if (cand.size() > 0)
{
init_ids.emplace_back(cand);
} else {
init_ids.insert(init_ids.end(), cand.begin(), cand.end());
// init_ids.emplace_back(cand);
} /*else {
if (_label_to_start_id.find(filter_label[0]) != _label_to_start_id.end())
{
init_ids.emplace_back(_label_to_start_id[filter_label[0]]);
}
}
} else {
} */
if (use_global_start) {
init_ids.emplace_back(_start);
}

Expand All @@ -2685,7 +2690,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
{
std::ofstream out("query_stats.txt", std::ios_base::app);
out << "estimated intersection size is " << inter_estim << std::endl;
out << "setting up init ids with id " << cand << std::endl;
//out << "setting up init ids with id " << cand << std::endl;
out.close();
}
retval = iterate_to_fixed_point(scratch, L, init_ids, true, filter_vec, true);
Expand Down Expand Up @@ -2737,17 +2742,17 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
/* if (_dynamic_index) */
/* tl.unlock(); */

if (!use_global_start) {
if (cand < std::numeric_limits<uint32_t>::max())
if (cand.size() > 0)
{
init_ids.emplace_back(cand);
} else {
init_ids.insert(init_ids.end(), cand.begin(), cand.end());
// init_ids.emplace_back(cand);
} /*else {
if (_label_to_start_id.find(filter_label[0]) != _label_to_start_id.end())
{
init_ids.emplace_back(_label_to_start_id[filter_label[0]]);
}
}
} else {
}*/
if (use_global_start) {
init_ids.emplace_back(_start);
}

Expand All @@ -2760,7 +2765,7 @@ std::pair<uint32_t, uint32_t> Index<T, TagT, LabelT>::search_with_filters(const
out << filt << "/" << _labels_to_points[filt].cardinality() << " ";
out << std::endl;
out << "estimated intersection size is " << estimated_match << std::endl;
out << "setting up init ids with id " << cand << std::endl;
//out << "setting up init ids with id " << cand << std::endl;
out << std::endl;
out.close();
}
Expand Down

0 comments on commit ca2fb3b

Please sign in to comment.