diff --git a/.github/workflows/pr-test.yml b/.github/workflows/pr-test.yml index f8c5b228b..46495dbf6 100644 --- a/.github/workflows/pr-test.yml +++ b/.github/workflows/pr-test.yml @@ -6,6 +6,7 @@ jobs: runs-on: ${{ matrix.os }} strategy: + fail-fast: false matrix: os: [ubuntu-latest, windows-2019, windows-latest] @@ -71,6 +72,9 @@ jobs: run: | ${{ env.diskann_built_tests }}/build_memory_index --data_type float --dist_fn l2 --data_path ./rand_float_10D_10K_norm1.0.bin --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32 + - name: Searching with fast_l2 distance function + if: runner.os != 'Windows' + run: | ${{ env.diskann_built_tests }}/search_memory_index --data_type float --dist_fn fast_l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_float_10D_10K_norm1.0 --query_file ./rand_float_10D_1K_norm1.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_float_10D_10K_norm1.0_10D_1K_norm1.0_gt100 -L 16 32 - name: build and search in-memory index with MIPS metric run: | @@ -117,7 +121,6 @@ jobs: - name: Generate 10K random int8 index vectors, 1K query vectors, in 10 dims and compute GT - run: | run: | ${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 ${{ env.diskann_built_utils }}/rand_data_gen --data_type int8 --output_file ./rand_int8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 @@ -154,10 +157,11 @@ jobs: ${{ env.diskann_built_tests }}/search_disk_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_int8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file 
./rand_int8_10D_1K_norm50.0.bin --gt_file ./l2_rand_int8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - name: build and search an incremental index run: | - ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200; + ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200 ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_1K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags ${{ env.diskann_built_tests }}/search_memory_index --data_type int8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_int8_10D_1K_norm50.0.bin --gt_file gt100_random10D_1K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1 - name: test a streaming index + if: success() || failure() run: | ${{ env.diskann_built_tests }}/test_streaming_scenario --data_type int8 --dist_fn l2 --data_path rand_int8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 
1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200 ${{ env.diskann_built_utils }}/compute_groundtruth --data_type int8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_int8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags @@ -165,6 +169,7 @@ jobs: - name: Generate 10K random uint8 index vectors, 1K query vectors, in 10 dims and compute GT + if: success() || failure() run: | ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 @@ -172,44 +177,77 @@ jobs: ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn mips --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./mips_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn cosine --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --K 100 - name: build and search in-memory index with L2 metrics + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file 
./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - name: build and search in-memory index with cosine metric + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn cosine --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_cosine_rand_uint8_10D_10K_norm50.0 ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn cosine --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./cosine_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - name: build and search in-memory index with L2 metrics with PQ base distance comparisons + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --build_PQ_bytes 5 ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50.0_buildpq5 --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 -L 16 32 - name: build and search disk index (one shot graph build, L2, no diskPQ) + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot -R 16 -L 32 -B 0.00003 -M 1 ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file 
./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - name: build and search disk index (one shot graph build, L2, no diskPQ, build with PQ distance comparisons) + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 -R 16 -L 32 -B 0.00003 -M 1 --build_PQ_bytes 5 ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_oneshot_buildpq5 --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - name: build and search disk index (sharded graph build, L2, no diskPQ) + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded -R 16 -L 32 -B 0.00003 -M 0.00006 ${{ env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskfull_sharded --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - name: build and search disk index (one shot graph build, L2, diskPQ) + if: success() || failure() run: | ${{ env.diskann_built_tests }}/build_disk_index --data_type uint8 --dist_fn l2 --data_path ./rand_uint8_10D_10K_norm50.0.bin --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot -R 16 -L 32 -B 0.00003 -M 1 --PQ_disk_bytes 5 ${{ 
env.diskann_built_tests }}/search_disk_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix ./disk_index_l2_rand_uint8_10D_10K_norm50.0_diskpq_oneshot --result_path /tmp/res --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100 --recall_at 5 -L 5 12 -W 2 --num_nodes_to_cache 10 -T 16 - name: build and search an incremental index + if: success() || failure() run: | ${{ env.diskann_built_tests }}/test_insert_deletes_consolidate --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_ins_del -R 64 -L 300 --alpha 1.2 -T 8 --points_to_skip 0 --max_points_to_insert 7500 --beginning_index_size 0 --points_per_checkpoint 1000 --checkpoints_per_snapshot 0 --points_to_delete_from_beginning 2500 --start_deletes_after 5000 --do_concurrent true --start_point_norm 200; ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_ins_del.after-concurrent-delete-del2500-7500.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_random10D_10K-conc-2500-7500 --tags_file index_ins_del.after-concurrent-delete-del2500-7500.tags ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_ins_del.after-concurrent-delete-del2500-7500 --result_path res_ins_del --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_random10D_10K-conc-2500-7500 -K 10 -L 20 40 60 80 100 -T 8 --dynamic true --tags 1 - name: test a streaming index + if: success() || failure() run: | ${{ env.diskann_built_tests }}/test_streaming_scenario --data_type uint8 --dist_fn l2 --data_path rand_uint8_10D_10K_norm50.0.bin --index_path_prefix index_stream -R 64 -L 600 --alpha 1.2 --insert_threads 4 --consolidate_threads 4 --max_points_to_insert 10000 --active_window 4000 --consolidate_interval 2000 --start_point_norm 200 ${{ env.diskann_built_utils 
}}/compute_groundtruth --data_type uint8 --dist_fn l2 --base_file index_stream.after-streaming-act4000-cons2000-max10000.data --query_file rand_uint8_10D_1K_norm50.0.bin --K 100 --gt_file gt100_base-act4000-cons2000-max10000 --tags_file index_stream.after-streaming-act4000-cons2000-max10000.tags ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --fail_if_recall_below 70 --index_path_prefix index_stream.after-streaming-act4000-cons2000-max10000 --result_path res_stream --query_file ./rand_uint8_10D_1K_norm50.0.bin --gt_file gt100_base-act4000-cons2000-max10000 -K 10 -L 20 40 60 80 100 -T 64 --dynamic true --tags 1 + - name: Generate 10K random uint8 index vectors, 1K query vectors, 10K Label Points (50 unique labels), in 10 dims and compute GT + if: success() || failure() + run: | + ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_10K_norm50.0.bin -D 10 -N 10000 --norm 50.0 + ${{ env.diskann_built_utils }}/rand_data_gen --data_type uint8 --output_file ./rand_uint8_10D_1K_norm50.0.bin -D 10 -N 1000 --norm 50.0 + ${{ env.diskann_built_utils }}/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_10_10K.txt + ${{ env.diskann_built_utils }}/compute_groundtruth --data_type uint8 --dist_fn l2 --universal_label 0 --filter_label 10 --base_file ./rand_uint8_10D_10K_norm50.0.bin --query_file ./rand_uint8_10D_1K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel --K 100 + - name: build and search in-memory index with labels using L2 metrics + if: success() || failure() + run: | + ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel + ${{ env.diskann_built_tests }}/search_memory_index 
--data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32 + + - name: build and search in-memory index with pq_dist of 5 with 10 dimensions + if: success() || failure() + run: | + ${{ env.diskann_built_tests }}/build_memory_index --data_type uint8 --dist_fn l2 --FilteredLbuild 90 --universal_label 0 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --build_PQ_bytes 5 + ${{ env.diskann_built_tests }}/search_memory_index --data_type uint8 --dist_fn l2 --filter_label 10 --fail_if_recall_below 70 --index_path_prefix ./index_l2_rand_uint8_10D_10K_norm50_wlabel --query_file ./rand_uint8_10D_1K_norm50.0.bin --recall_at 10 --result_path temp --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -L 16 32 + + - name: Build and search stitched vamana + if: success() || failure() + run: | + ${{ env.diskann_built_tests }}/build_stitched_index --num_threads 48 --data_type uint8 --data_path ./rand_uint8_10D_10K_norm50.0.bin --label_file ./rand_labels_10_10K.txt -R 32 -L 100 --alpha 1.2 --stitched_R 64 --index_path_prefix ./stit_32_100_64_new --universal_label 0 + ${{ env.diskann_built_tests }}/search_memory_index --num_threads 48 --data_type uint8 --dist_fn l2 --filter_label 10 --index_path_prefix ./stit_32_100_64_new --query_file ./rand_uint8_10D_1K_norm50.0.bin --result_path ./rand_stit_96_10_90_new --gt_file ./l2_rand_uint8_10D_10K_norm50.0_10D_1K_norm50.0_gt100_wlabel -K 10 -L 10 10 10 10 10 30 50 70 90 110 130 150 170 190 210 230 250 270 290 310 330 350 370 390 410 - uses: actions/setup-python@v3 - name: Install cibuildwheel run: python -m pip install cibuildwheel==2.11.3 diff --git a/CMakeSettings.json b/CMakeSettings.json 
new file mode 100644 index 000000000..af5d7b5c7 --- /dev/null +++ b/CMakeSettings.json @@ -0,0 +1,28 @@ +{ + "configurations": [ + { + "name": "x64-Release", + "generator": "Ninja", + "configurationType": "Release", + "inheritEnvironments": [ "msvc_x64" ], + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "" + }, + { + "name": "WSL-GCC-Release", + "generator": "Ninja", + "configurationType": "RelWithDebInfo", + "buildRoot": "${projectDir}\\out\\build\\${name}", + "installRoot": "${projectDir}\\out\\install\\${name}", + "cmakeExecutable": "cmake", + "cmakeCommandArgs": "", + "buildCommandArgs": "", + "ctestCommandArgs": "", + "inheritEnvironments": [ "linux_x64" ], + "wslPath": "${defaultWSLPath}" + } + ] +} \ No newline at end of file diff --git a/README.md b/README.md index 8294b9df3..88afa84d6 100644 --- a/README.md +++ b/README.md @@ -87,4 +87,17 @@ Please see the following pages on using the compiled code: - [Commandline interface for building and search SSD based indices](workflows/SSD_index.md) - [Commandline interface for building and search in memory indices](workflows/in_memory_index.md) - [Commandline examples for using in-memory streaming indices](workflows/dynamic_index.md) +- [Commandline interface for building and search in memory indices with label data and filters](workflows/filtered_in_memory.md) - To be added: Python interfaces and docker files + +Please cite this software in your work as: + +``` +@misc{diskann-github, + author = {Simhadri, Harsha Vardhan and Krishnaswamy, Ravishankar and Srinivasa, Gopal and Subramanya, Suhas Jayaram and Antonijevic, Andrija and Pryce, Dax and Kaczynski, David and Williams, Shane and Gollapudi, Siddarth and Sivashankar, Varun and Karia, Neel and Singh, Aditi and Jaiswal, Shikhar and Mahapatro, Neelam and Adams, Philip and Tower, Bryan}, + title = {{DiskANN: Scalable, efficient and 
Feature-rich ANNS}}, + url = {https://github.com/Microsoft/DiskANN}, + version = {0.5}, + year = {2023} +} +``` diff --git a/include/disk_utils.h b/include/disk_utils.h index 88305ca76..1bdaec02d 100644 --- a/include/disk_utils.h +++ b/include/disk_utils.h @@ -40,7 +40,7 @@ namespace diskann { const uint32_t WARMUP_L = 20; const uint32_t NUM_KMEANS_REPS = 12; - template + template class PQFlashIndex; DISKANN_DLLEXPORT double get_memory_budget(const std::string &mem_budget_str); @@ -68,38 +68,47 @@ namespace diskann { uint64_t warmup_aligned_dim); #endif - DISKANN_DLLEXPORT int merge_shards(const std::string &vamana_prefix, - const std::string &vamana_suffix, - const std::string &idmaps_prefix, - const std::string &idmaps_suffix, - const _u64 nshards, unsigned max_degree, - const std::string &output_vamana, - const std::string &medoids_file); + DISKANN_DLLEXPORT int merge_shards( + const std::string &vamana_prefix, const std::string &vamana_suffix, + const std::string &idmaps_prefix, const std::string &idmaps_suffix, + const _u64 nshards, unsigned max_degree, const std::string &output_vamana, + const std::string &medoids_file, bool use_filters = false, + const std::string &labels_to_medoids_file = std::string("")); + + DISKANN_DLLEXPORT void extract_shard_labels( + const std::string &in_label_file, const std::string &shard_ids_bin, + const std::string &shard_label_file); template DISKANN_DLLEXPORT std::string preprocess_base_file( const std::string &infile, const std::string &indexPrefix, diskann::Metric &distMetric); - template + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric _compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_file, - std::string centroids_file, size_t build_pq_bytes, bool use_opq); + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters = false, const std::string &label_file = std::string(""), 
+ const std::string &labels_to_medoids_file = std::string(""), + const std::string &universal_label = "", const _u32 Lf = 0); - template + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &_pFlashIndex, T *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, - uint32_t nthreads, uint32_t start_bw = 2); - - template - DISKANN_DLLEXPORT int build_disk_index(const char *dataFilePath, - const char *indexFilePath, - const char *indexBuildParameters, - diskann::Metric _compareMetric, - bool use_opq = false); + std::unique_ptr> &_pFlashIndex, + T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, + uint32_t L, uint32_t nthreads, uint32_t start_bw = 2); + + template + DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric _compareMetric, + bool use_opq = false, bool use_filters = false, + const std::string &label_file = + std::string(""), // default is empty string for no label_file + const std::string &universal_label = "", const _u32 filter_threshold = 0, + const _u32 Lf = 0); // default is empty string for no universal label template DISKANN_DLLEXPORT void create_disk_layout( diff --git a/include/index.h b/include/index.h index 4ba3d6dbd..4cda901d9 100644 --- a/include/index.h +++ b/include/index.h @@ -24,6 +24,7 @@ #define DEFAULT_MAXC 750 namespace diskann { + inline double estimate_ram_usage(_u64 size, _u32 dim, _u32 datasize, _u32 degree) { double size_of_data = ((double) size) * ROUND_UP(dim, 8) * datasize; @@ -60,7 +61,7 @@ namespace diskann { } }; - template + template class Index { /************************************************************************** * @@ -129,6 +130,17 @@ namespace diskann { Parameters ¶meters, const std::vector &tags); + // Filtered Support + DISKANN_DLLEXPORT void build_filtered_index( + const char *filename, const std::string &label_file, + const size_t 
num_points_to_load, Parameters ¶meters, + const std::vector &tags = std::vector()); + + DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); + + // Get converted integer label from string to int map (_label_map) + DISKANN_DLLEXPORT LabelT get_converted_label(const std::string &raw_label); + // Set starting point of an index before inserting any points incrementally DISKANN_DLLEXPORT void set_start_point(T *data); // Set starting point to a random point on a sphere of certain radius @@ -155,6 +167,12 @@ namespace diskann { float *distances, std::vector &res_vectors); + // Filter support search + template + DISKANN_DLLEXPORT std::pair search_with_filters( + const T *query, const LabelT &filter_label, const size_t K, + const unsigned L, IndexType *indices, float *distances); + // Will fail if tag already in the index or if tag=0. DISKANN_DLLEXPORT int insert_point(const T *point, const TagT tag); @@ -177,6 +195,8 @@ namespace diskann { DISKANN_DLLEXPORT consolidation_report consolidate_deletes(const Parameters ¶meters); + DISKANN_DLLEXPORT void prune_all_nbrs(const Parameters ¶meters); + DISKANN_DLLEXPORT bool is_index_saved(); // repositions frozen points to the end of _data - if they have been moved @@ -208,8 +228,8 @@ namespace diskann { protected: // No copy/assign. 
- Index(const Index &) = delete; - Index &operator=(const Index &) = delete; + Index(const Index &) = delete; + Index &operator=(const Index &) = delete; // Use after _data and _nd have been populated // Acquire exclusive _update_lock before calling @@ -223,14 +243,23 @@ namespace diskann { // determines navigating node of the graph by calculating medoid of datafopt unsigned calculate_entry_point(); + void parse_label_file(const std::string &label_file, + size_t &num_pts_labels); + + std::unordered_map load_label_map( + const std::string &map_file); + std::pair iterate_to_fixed_point( const T *node_coords, const unsigned Lindex, const std::vector &init_ids, InMemQueryScratch *scratch, + bool use_filter, const std::vector &filters, bool ret_frozen = true, bool search_invocation = false); void search_for_point_and_prune(int location, _u32 Lindex, std::vector &pruned_list, - InMemQueryScratch *scratch); + InMemQueryScratch *scratch, + bool use_filter = false, + _u32 filteredLindex = 0); void prune_neighbors(const unsigned location, std::vector &pool, std::vector &pruned_list, @@ -342,6 +371,19 @@ namespace diskann { bool _enable_tags = false; bool _normalize_vecs = false; // Using normalied L2 for cosine. 
+ // Filter Support + + bool _filtered_index = false; + std::vector> _pts_to_labels; + tsl::robin_set _labels; + std::string _labels_file; + std::unordered_map _label_to_medoid_id; + std::unordered_map<_u32, _u32> _medoid_counts; + bool _use_universal_label = false; + LabelT _universal_label = 0; + uint32_t _filterIndexingQueueSize; + std::unordered_map _label_map; + // Indexing parameters uint32_t _indexingQueueSize; uint32_t _indexingRange; diff --git a/include/pq_flash_index.h b/include/pq_flash_index.h index 62a6f730e..376b898ee 100644 --- a/include/pq_flash_index.h +++ b/include/pq_flash_index.h @@ -20,7 +20,7 @@ namespace diskann { - template + template class PQFlashIndex { public: DISKANN_DLLEXPORT PQFlashIndex( @@ -70,11 +70,26 @@ namespace diskann { float *res_dists, const _u64 beam_width, const bool use_reorder_data = false, QueryStats *stats = nullptr); + DISKANN_DLLEXPORT void cached_beam_search( + const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, + float *res_dists, const _u64 beam_width, const bool use_filter, + const LabelT &filter_label, const bool use_reorder_data = false, + QueryStats *stats = nullptr); + DISKANN_DLLEXPORT void cached_beam_search( const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, float *res_dists, const _u64 beam_width, const _u32 io_limit, const bool use_reorder_data = false, QueryStats *stats = nullptr); + DISKANN_DLLEXPORT void cached_beam_search( + const T *query, const _u64 k_search, const _u64 l_search, _u64 *res_ids, + float *res_dists, const _u64 beam_width, const bool use_filter, + const LabelT &filter_label, const _u32 io_limit, + const bool use_reorder_data = false, QueryStats *stats = nullptr); + + DISKANN_DLLEXPORT LabelT + get_converted_label(const std::string &filter_label); + DISKANN_DLLEXPORT _u32 range_search(const T *query1, const double range, const _u64 min_l_search, const _u64 max_l_search, @@ -94,12 +109,26 @@ namespace diskann { DISKANN_DLLEXPORT void 
setup_thread_data(_u64 nthreads, _u64 visited_reserve = 4096); + DISKANN_DLLEXPORT void set_universal_label(const LabelT &label); + private: + DISKANN_DLLEXPORT inline bool point_has_label(_u32 point_id, _u32 label_id); + std::unordered_map load_label_map( + const std::string &map_file); + DISKANN_DLLEXPORT void parse_label_file(const std::string &map_file, + size_t &num_pts_labels); + DISKANN_DLLEXPORT void get_label_file_metadata(std::string map_file, + _u32 &num_pts, + _u32 &num_total_labels); + DISKANN_DLLEXPORT inline int32_t get_filter_number( + const LabelT &filter_label); + // index info // nhood of node `i` is in sector: [i / nnodes_per_sector] // offset in sector: [(i % nnodes_per_sector) * max_node_len] // nnbrs of node `i`: *(unsigned*) (buf) // nbrs of node `i`: ((unsigned*)buf) + 1 + _u64 max_node_len = 0, nnodes_per_sector = 0, max_degree = 0; // Data used for searching with re-order vectors @@ -171,6 +200,20 @@ namespace diskann { bool reorder_data_exists = false; _u64 reoreder_data_offset = 0; + // filter support + _u32 *_pts_to_label_offsets = nullptr; + _u32 *_pts_to_labels = nullptr; + tsl::robin_set _labels; + std::unordered_map _filter_to_medoid_id; + bool _use_universal_label; + _u32 _universal_filter_num; + std::vector _filter_list; + tsl::robin_set<_u32> _dummy_pts; + tsl::robin_set<_u32> _has_dummy_pts; + tsl::robin_map<_u32, _u32> _dummy_to_real_map; + tsl::robin_map<_u32, std::vector<_u32>> _real_to_dummy_map; + std::unordered_map _label_map; + #ifdef EXEC_ENV_OLS // Set to a larger value than the actual header to accommodate // any additions we make to the header. 
This is an outer limit diff --git a/include/utils.h b/include/utils.h index 2ebe48507..d453ec808 100644 --- a/include/utils.h +++ b/include/utils.h @@ -31,19 +31,24 @@ typedef int FileHandle; #include "memory_mapped_files.h" #endif +#include +#include +#include + // taken from // https://github.com/Microsoft/BLAS-on-flash/blob/master/include/utils.h // round up X to the nearest multiple of Y #define ROUND_UP(X, Y) \ - ((((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) * (Y)) + ((((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0)) * (Y)) -#define DIV_ROUND_UP(X, Y) (((uint64_t)(X) / (Y)) + ((uint64_t)(X) % (Y) != 0)) +#define DIV_ROUND_UP(X, Y) \ + (((uint64_t) (X) / (Y)) + ((uint64_t) (X) % (Y) != 0)) // round down X to the nearest multiple of Y -#define ROUND_DOWN(X, Y) (((uint64_t)(X) / (Y)) * (Y)) +#define ROUND_DOWN(X, Y) (((uint64_t) (X) / (Y)) * (Y)) // alignment tests -#define IS_ALIGNED(X, Y) ((uint64_t)(X) % (uint64_t)(Y) == 0) +#define IS_ALIGNED(X, Y) ((uint64_t) (X) % (uint64_t) (Y) == 0) #define IS_512_ALIGNED(X) IS_ALIGNED(X, 512) #define IS_4096_ALIGNED(X) IS_ALIGNED(X, 4096) #define METADATA_SIZE \ @@ -92,8 +97,9 @@ typedef uint16_t _u16; typedef int16_t _s16; typedef uint8_t _u8; typedef int8_t _s8; -inline void open_file_to_write(std::ofstream& writer, - const std::string& filename) { + +inline void open_file_to_write(std::ofstream& writer, + const std::string& filename) { writer.exceptions(std::ofstream::failbit | std::ofstream::badbit); if (!file_exists(filename)) writer.open(filename, std::ios::binary | std::ios::out); @@ -144,6 +150,62 @@ inline int delete_file(const std::string& fileName) { } } +// Converts a label file whose lines are comma-separated string labels into +// an equivalent file of integer label ids (outFileName), then writes the +// string-to-id mapping to mapFileName as tab-separated "label id" lines. +// A non-empty unv_label is pinned to id 0; any other label seen for the +// first time gets an id one larger than the current size of the map. +// Calls exit(-1) if the input cannot be opened or a line yields no label. +inline void convert_labels_string_to_int(const std::string& inFileName, + const std::string& outFileName, + const std::string& mapFileName, + const std::string& unv_label) { + std::unordered_map<std::string, _u32> string_int_map; + std::ifstream label_reader(inFileName); + if (!label_reader.is_open()) { + // Without this check a missing input silently produced empty outputs. + std::cerr << "Failed to open label file " << inFileName << std::endl; + exit(-1); + } + std::ofstream label_writer(outFileName); + if (!unv_label.empty()) + string_int_map[unv_label] = 0; + std::string line, token; + while (std::getline(label_reader, line)) { + std::istringstream new_iss(line); + std::vector<_u32> lbls; + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + // Single lookup; an unseen label is inserted with the next id. + auto it = string_int_map.find(token); + if (it == string_int_map.end()) { + _u32 nextId = (_u32) string_int_map.size() + 1; + it = string_int_map.emplace(token, nextId).first; + } + lbls.push_back(it->second); + } + if (lbls.empty()) { + std::cerr << "No label found" << std::endl; + exit(-1); + } + for (size_t j = 0; j < lbls.size(); j++) { + if (j != lbls.size() - 1) + label_writer << lbls[j] << ","; + else + label_writer << lbls[j] << std::endl; + } + } + label_writer.close(); + + std::ofstream map_writer(mapFileName); + // Iterate by const reference so each (label, id) pair is not copied. + for (const auto& mp : string_int_map) { + map_writer << mp.first << "\t" << mp.second << std::endl; + } + map_writer.close(); +} + 
#ifdef EXEC_ENV_OLS class AlignedFileReader; #endif @@ -568,6 +630,27 @@ namespace diskann { } #endif + // Copies the contents of in_file to out_file; both streams are opened in + // binary mode. Reports to std::cerr and returns without copying when + // either stream fails to open. + inline void copy_file(std::string in_file, std::string out_file) { + std::ifstream source(in_file, std::ios::binary); + std::ofstream dest(out_file, std::ios::binary); + if (!source.is_open() || !dest.is_open()) { + std::cerr << "copy_file: could not open " << in_file << " or " + << out_file << std::endl; + return; + } + + std::istreambuf_iterator<char> begin_source(source); + std::istreambuf_iterator<char> end_source; + std::ostreambuf_iterator<char> begin_dest(dest); + std::copy(begin_source, end_source, begin_dest); + + source.close(); + dest.close(); + } + DISKANN_DLLEXPORT double calculate_recall( unsigned num_queries, unsigned* gold_std, float* gs_dist, unsigned dim_gs, unsigned* our_results, unsigned dim_or, unsigned recall_at); @@ -947,7 +1030,7 @@ inline void normalize(T* arr, size_t dim) { } sum = sqrt(sum); for (uint32_t i = 0; i < dim; i++) { - arr[i] = (T)(arr[i] / sum); + arr[i] = (T) (arr[i] / sum); } } diff --git a/python/src/diskann_bindings.cpp b/python/src/diskann_bindings.cpp index a2a81c8d0..6893bfdcd 100644 --- a/python/src/diskann_bindings.cpp +++ 
b/python/src/diskann_bindings.cpp @@ -25,13 +25,12 @@ PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); PYBIND11_MAKE_OPAQUE(std::vector); - namespace py = pybind11; using namespace diskann; template struct DiskANNIndex { - PQFlashIndex * pq_flash_index; + PQFlashIndex *pq_flash_index; std::shared_ptr reader; DiskANNIndex(diskann::Metric metric) { @@ -74,8 +73,7 @@ struct DiskANNIndex { const size_t num_nodes_to_cache, int cache_mechanism) { const std::string index_path = index_path_prefix + std::string("_disk.index"); - int load_success = - pq_flash_index->load(num_threads, index_path.c_str()); + int load_success = pq_flash_index->load(num_threads, index_path.c_str()); if (load_success != 0) { return load_success; } @@ -197,8 +195,8 @@ struct DiskANNIndex { const int num_threads) { py::array_t offsets(num_queries + 1); - std::vector > u64_ids(num_queries); - std::vector > dists(num_queries); + std::vector> u64_ids(num_queries); + std::vector> dists(num_queries); auto offsets_mutable = offsets.mutable_unchecked(); offsets_mutable(0) = 0; @@ -246,8 +244,9 @@ struct DiskANNIndex { stats, num_queries, [](const diskann::QueryStats &stats) { return stats.n_cmps; }); delete[] stats; - return std::make_pair(std::make_pair(offsets, std::make_pair(ids, res_dists)), - collective_stats); + return std::make_pair( + std::make_pair(offsets, std::make_pair(ids, res_dists)), + collective_stats); } }; @@ -259,16 +258,15 @@ PYBIND11_MODULE(diskannpy, m) { m.attr("__version__") = "dev"; #endif - py::bind_vector >(m, "VectorUnsigned"); - py::bind_vector >(m, "VectorFloat"); - py::bind_vector >(m, "VectorInt8"); - py::bind_vector >(m, "VectorUInt8"); - + py::bind_vector>(m, "VectorUnsigned"); + py::bind_vector>(m, "VectorFloat"); + py::bind_vector>(m, "VectorInt8"); + py::bind_vector>(m, "VectorUInt8"); py::enum_(m, "Metric") - .value("L2", Metric::L2) - .value("INNER_PRODUCT", Metric::INNER_PRODUCT) - .export_values(); + .value("L2", Metric::L2) + 
.value("INNER_PRODUCT", Metric::INNER_PRODUCT) + .export_values(); py::class_(m, "Parameters") .def(py::init<>()) @@ -294,9 +292,11 @@ PYBIND11_MODULE(diskannpy, m) { py::class_(m, "AlignedFileReader"); #ifdef _WINDOWS - py::class_(m, "WindowsAlignedFileReader").def(py::init<>()); + py::class_(m, "WindowsAlignedFileReader") + .def(py::init<>()); #else - py::class_(m, "LinuxAlignedFileReader").def(py::init<>()); + py::class_(m, "LinuxAlignedFileReader") + .def(py::init<>()); #endif m.def( @@ -327,7 +327,7 @@ PYBIND11_MODULE(diskannpy, m) { [](const std::string &path, std::vector &ids, std::vector &distances) { unsigned *id_ptr = nullptr; - float * dist_ptr = nullptr; + float *dist_ptr = nullptr; size_t num, dims; load_truthset(path, id_ptr, dist_ptr, num, dims); // TODO: Remove redundant copies. @@ -349,7 +349,7 @@ PYBIND11_MODULE(diskannpy, m) { const unsigned ground_truth_dims, std::vector &results, const unsigned result_dims, const unsigned recall_at) { unsigned *gti_ptr = ground_truth_ids.data(); - float * gtd_ptr = ground_truth_dists.data(); + float *gtd_ptr = ground_truth_dists.data(); unsigned *r_ptr = results.data(); double total_recall = 0; @@ -390,10 +390,10 @@ PYBIND11_MODULE(diskannpy, m) { std::vector &ground_truth_dists, const unsigned ground_truth_dims, py::array_t - & results, + &results, const unsigned result_dims, const unsigned recall_at) { unsigned *gti_ptr = ground_truth_ids.data(); - float * gtd_ptr = ground_truth_dists.data(); + float *gtd_ptr = ground_truth_dists.data(); unsigned *r_ptr = results.mutable_data(); double total_recall = 0; @@ -434,9 +434,9 @@ PYBIND11_MODULE(diskannpy, m) { size_t dims) { save_bin<_u32>(file_name, data.data(), npts, dims); }, py::arg("file_name"), py::arg("data"), py::arg("npts"), py::arg("dims")); - py::class_ >(m, "DiskANNFloatIndex") + py::class_>(m, "DiskANNFloatIndex") .def(py::init([](diskann::Metric metric) { - return std::unique_ptr >( + return std::unique_ptr>( new DiskANNIndex(metric)); })) 
.def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, @@ -462,8 +462,8 @@ PYBIND11_MODULE(diskannpy, m) { .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, py::arg("queries"), py::arg("dim"), py::arg("num_queries"), - py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"), - py::arg("num_threads")) + py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), + py::arg("beam_width"), py::arg("num_threads")) .def( "build", [](DiskANNIndex &self, const char *data_file_path, @@ -485,13 +485,13 @@ PYBIND11_MODULE(diskannpy, m) { py::arg("indexing_ram_limit"), py::arg("num_threads"), py::arg("pq_disk_bytes") = 0); - py::class_ >(m, "DiskANNInt8Index") + py::class_>(m, "DiskANNInt8Index") .def(py::init([](diskann::Metric metric) { - return std::unique_ptr >( + return std::unique_ptr>( new DiskANNIndex(metric)); })) .def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, - py::arg("num_nodes_to_cache")) + py::arg("num_nodes_to_cache")) .def("load_index", &DiskANNIndex::load_index, py::arg("index_path_prefix"), py::arg("num_threads"), py::arg("num_nodes_to_cache"), py::arg("cache_mechanism") = 1) @@ -513,8 +513,8 @@ PYBIND11_MODULE(diskannpy, m) { .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, py::arg("queries"), py::arg("dim"), py::arg("num_queries"), - py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"), - py::arg("num_threads")) + py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), + py::arg("beam_width"), py::arg("num_threads")) .def( "build", [](DiskANNIndex &self, const char *data_file_path, @@ -536,10 +536,9 @@ PYBIND11_MODULE(diskannpy, m) { py::arg("indexing_ram_limit"), py::arg("num_threads"), py::arg("pq_disk_bytes") = 0); - - py::class_ >(m, "DiskANNUInt8Index") + py::class_>(m, "DiskANNUInt8Index") .def(py::init([](diskann::Metric metric) { - return std::unique_ptr >( + return 
std::unique_ptr>( new DiskANNIndex(metric)); })) .def("cache_bfs_levels", &DiskANNIndex::cache_bfs_levels, @@ -565,8 +564,8 @@ PYBIND11_MODULE(diskannpy, m) { .def("batch_range_search_numpy_input", &DiskANNIndex::batch_range_search_numpy_input, py::arg("queries"), py::arg("dim"), py::arg("num_queries"), - py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), py::arg("beam_width"), - py::arg("num_threads")) + py::arg("range"), py::arg("min_list_size"), py::arg("max_list_size"), + py::arg("beam_width"), py::arg("num_threads")) .def( "build", [](DiskANNIndex &self, const char *data_file_path, diff --git a/src/disk_utils.cpp b/src/disk_utils.cpp index 73dd7f449..14aa102ae 100644 --- a/src/disk_utils.cpp +++ b/src/disk_utils.cpp @@ -242,7 +242,8 @@ namespace diskann { const std::string &idmaps_prefix, const std::string &idmaps_suffix, const _u64 nshards, unsigned max_degree, const std::string &output_vamana, - const std::string &medoids_file) { + const std::string &medoids_file, bool use_filters, + const std::string &labels_to_medoids_file) { // Read ID maps std::vector vamana_names(nshards); std::vector> idmaps(nshards); @@ -283,6 +284,57 @@ namespace diskann { }); diskann::cout << "Finished computing node -> shards map" << std::endl; + // will merge all the labels to medoids files of each shard into one + // combined file + if (use_filters) { + std::unordered_map> global_label_to_medoids; + + for (_u64 i = 0; i < nshards; i++) { + std::ifstream mapping_reader; + std::string map_file = vamana_names[i] + "_labels_to_medoids.txt"; + mapping_reader.open(map_file); + + std::string line, token; + unsigned line_cnt = 0; + + while (std::getline(mapping_reader, line)) { + std::istringstream iss(line); + _u32 cnt = 0; + _u32 medoid; + _u32 label; + while (std::getline(iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), + token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), + token.end()); + + unsigned 
token_as_num = std::stoul(token); + + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + global_label_to_medoids[label].push_back(idmaps[i][medoid]); + line_cnt++; + } + mapping_reader.close(); + } + + std::ofstream mapping_writer(labels_to_medoids_file); + assert(mapping_writer.is_open()); + for (auto iter : global_label_to_medoids) { + mapping_writer << iter.first << ", "; + auto &vec = iter.second; + for (_u32 idx = 0; idx < vec.size() - 1; idx++) { + mapping_writer << vec[idx] << ", "; + } + mapping_writer << vec[vec.size() - 1] << std::endl; + } + mapping_writer.close(); + } + // create cached vamana readers std::vector vamana_readers(nshards); for (_u64 i = 0; i < nshards; i++) { @@ -384,10 +436,16 @@ namespace diskann { } // read from shard_id ifstream vamana_readers[shard_id].read((char *) &shard_nnbrs, sizeof(unsigned)); - std::vector shard_nhood(shard_nnbrs); - vamana_readers[shard_id].read((char *) shard_nhood.data(), - shard_nnbrs * sizeof(unsigned)); + if (shard_nnbrs == 0) { + diskann::cout << "WARNING: shard #" << shard_id << ", node_id " + << node_id << " has 0 nbrs" << std::endl; + } + + std::vector shard_nhood(shard_nnbrs); + if (shard_nnbrs > 0) + vamana_readers[shard_id].read((char *) shard_nhood.data(), + shard_nnbrs * sizeof(unsigned)); // rename nodes for (_u64 j = 0; j < shard_nnbrs; j++) { if (nhood_set[idmaps[shard_id][shard_nhood[j]]] == 0) { @@ -402,8 +460,10 @@ namespace diskann { nnbrs = (unsigned) (std::min)(final_nhood.size(), (uint64_t) max_degree); // write into merged ofstream merged_vamana_writer.write((char *) &nnbrs, sizeof(unsigned)); - merged_vamana_writer.write((char *) final_nhood.data(), - nnbrs * sizeof(unsigned)); + if (nnbrs > 0) { + merged_vamana_writer.write((char *) final_nhood.data(), + nnbrs * sizeof(unsigned)); + } merged_index_size += (sizeof(unsigned) + nnbrs * sizeof(unsigned)); for (auto &p : final_nhood) nhood_set[p] = 0; @@ -418,47 +478,204 @@ namespace diskann { return 0; } 
+ // TODO: Make this a streaming implementation to avoid exceeding the memory + // budget + /* If the number of filters per point N exceeds the graph degree R, + then it is difficult to have edges to all labels from this point. + This function breaks up such dense points so that each node carries at most + `density` labels. It divides one graph node into multiple nodes and appends + the new nodes at the end. The dummy map records, for each new node id, the + id of the original point it was split from */ template - int build_merged_vamana_index(std::string base_file, - diskann::Metric compareMetric, unsigned L, - unsigned R, double sampling_rate, - double ram_budget, std::string mem_index_path, - std::string medoids_file, - std::string centroids_file, - size_t build_pq_bytes, bool use_opq) { + void breakup_dense_points(const std::string data_file, + const std::string labels_file, _u32 density, + const std::string out_data_file, + const std::string out_labels_file, + const std::string out_metadata_file) { + std::string token, line; + std::ifstream labels_stream(labels_file); + T *data; + _u64 npts, ndims; + diskann::load_bin(data_file, data, npts, ndims); + + std::unordered_map<_u32, _u32> dummy_pt_ids; + _u32 next_dummy_id = (_u32) npts; + + _u32 point_cnt = 0; + + std::vector> labels_per_point; + labels_per_point.resize(npts); + + _u32 dense_pts = 0; + if (labels_stream.is_open()) { + while (getline(labels_stream, line)) { + std::stringstream iss(line); + _u32 lbl_cnt = 0; + _u32 label_host = point_cnt; + while (getline(iss, token, ',')) { + if (lbl_cnt == density) { + if (label_host == point_cnt) + dense_pts++; + label_host = next_dummy_id; + labels_per_point.resize(next_dummy_id + 1); + dummy_pt_ids[next_dummy_id] = (_u32) point_cnt; + next_dummy_id++; + lbl_cnt = 0; + } + token.erase(std::remove(token.begin(), token.end(), '\n'), + token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), + token.end()); + unsigned token_as_num = std::stoul(token); +
labels_per_point[label_host].push_back(token_as_num); + lbl_cnt++; + } + point_cnt++; + } + } + diskann::cout << "fraction of dense points with >= " << density + << " labels = " << (float) dense_pts / (float) npts + << std::endl; + + if (labels_per_point.size() != 0) { + diskann::cout << labels_per_point.size() << " is the new number of points" + << std::endl; + std::ofstream label_writer(out_labels_file); + assert(label_writer.is_open()); + for (_u32 i = 0; i < labels_per_point.size(); i++) { + for (_u32 j = 0; j + 1 < labels_per_point[i].size(); j++) { // j + 1 < size(): no unsigned wrap-around for a point without labels + label_writer << labels_per_point[i][j] << ","; + } + if (labels_per_point[i].size() != 0) + label_writer << labels_per_point[i][labels_per_point[i].size() - 1]; + label_writer << std::endl; + } + label_writer.close(); + } + + if (dummy_pt_ids.size() != 0) { + diskann::cout << dummy_pt_ids.size() + << " is the number of dummy points created" << std::endl; + data = (T *) std::realloc((void *) data, + labels_per_point.size() * ndims * sizeof(T)); + std::ofstream dummy_writer(out_metadata_file); + assert(dummy_writer.is_open()); + for (auto i = dummy_pt_ids.begin(); i != dummy_pt_ids.end(); i++) { + dummy_writer << i->first << "," << i->second << std::endl; + std::memcpy(data + i->first * ndims, data + i->second * ndims, + ndims * sizeof(T)); + } + dummy_writer.close(); + } + + diskann::save_bin(out_data_file, data, labels_per_point.size(), ndims); + } + + void extract_shard_labels( + const std::string &in_label_file, const std::string &shard_ids_bin, + const std::string &shard_label_file) { // assumes ith row is for ith + // point in labels file + diskann::cout << "Extracting labels for shard" << std::endl; + + _u32 *ids = nullptr; + _u64 num_ids, tmp_dim; + diskann::load_bin(shard_ids_bin, ids, num_ids, tmp_dim); + + _u32 counter = 0, shard_counter = 0; + std::string cur_line; + + std::ifstream label_reader(in_label_file); + std::ofstream label_writer(shard_label_file); + assert(label_reader.is_open()); +
assert(label_writer.is_open()); + if (label_reader && label_writer) { + while (std::getline(label_reader, cur_line)) { + if (shard_counter >= num_ids) { + break; + } + if (counter == ids[shard_counter]) { + label_writer << cur_line << "\n"; + shard_counter++; + } + counter++; + } + } + if (ids != nullptr) + delete[] ids; + } + + template + int build_merged_vamana_index( + std::string base_file, diskann::Metric compareMetric, unsigned L, + unsigned R, double sampling_rate, double ram_budget, + std::string mem_index_path, std::string medoids_file, + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf) { size_t base_num, base_dim; diskann::get_bin_metadata(base_file, base_num, base_dim); double full_index_ram = estimate_ram_usage(base_num, base_dim, sizeof(T), R); + + // TODO: Make this honest when there is filter support if (full_index_ram < ram_budget * 1024 * 1024 * 1024) { diskann::cout << "Full index fits in RAM budget, should consume at most " << full_index_ram / (1024 * 1024 * 1024) << "GiBs, so building in one shot" << std::endl; diskann::Parameters paras; paras.Set("L", (unsigned) L); + paras.Set("Lf", (unsigned) Lf); paras.Set("R", (unsigned) R); paras.Set("C", 750); paras.Set("alpha", 1.2f); paras.Set("num_rnds", 2); - paras.Set("saturate_graph", 1); + if (!use_filters) + paras.Set("saturate_graph", 1); + else + paras.Set("saturate_graph", 0); + using TagT = uint32_t; paras.Set("save_path", mem_index_path); + std::unique_ptr> _pvamanaIndex = + std::unique_ptr>( + new diskann::Index( + compareMetric, base_dim, base_num, false, false, false, + build_pq_bytes > 0, build_pq_bytes, use_opq)); + if (!use_filters) + _pvamanaIndex->build(base_file.c_str(), base_num, paras); + else { + if (universal_label != "") { // indicates no universal label + LabelT unv_label_as_num = 0; +
_pvamanaIndex->set_universal_label(unv_label_as_num); + } + _pvamanaIndex->build_filtered_index(base_file.c_str(), label_file, + base_num, paras); + } + _pvamanaIndex->save(mem_index_path.c_str()); - std::unique_ptr> _pvamanaIndex = - std::unique_ptr>(new diskann::Index( - compareMetric, base_dim, base_num, false, false, false, - build_pq_bytes > 0, build_pq_bytes, use_opq)); - _pvamanaIndex->build(base_file.c_str(), base_num, paras); + if (use_filters) { + // need to copy the labels_to_medoids file to the specified input file + std::remove(labels_to_medoids_file.c_str()); + std::string mem_labels_to_medoid_file = + mem_index_path + "_labels_to_medoids.txt"; + copy_file(mem_labels_to_medoid_file, labels_to_medoids_file); + std::remove(mem_labels_to_medoid_file.c_str()); + } - _pvamanaIndex->save(mem_index_path.c_str()); std::remove(medoids_file.c_str()); std::remove(centroids_file.c_str()); return 0; } + + // where the universal label is to be saved in the final graph + std::string final_index_universal_label_file = + mem_index_path + "_universal_label.txt"; + std::string merged_index_prefix = mem_index_path + "_tempFiles"; Timer timer; - int num_parts = + int num_parts = partition_with_ram_budget(base_file, sampling_rate, ram_budget, 2 * R / 3, merged_index_prefix, 2); diskann::cout << timer.elapsed_seconds_for_step("partitioning data") @@ -475,6 +692,9 @@ namespace diskann { std::string shard_ids_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; + std::string shard_labels_file = merged_index_prefix + "_subshard-" + + std::to_string(p) + "_labels.txt"; + retrieve_shard_data_from_ids(base_file, shard_ids_file, shard_base_file); @@ -483,6 +703,7 @@ namespace diskann { diskann::Parameters paras; paras.Set("L", L); + paras.Set("Lf", Lf); paras.Set("R", (2 * (R / 3))); paras.Set("C", 750); paras.Set("alpha", 1.2f); @@ -496,17 +717,43 @@ namespace diskann { std::unique_ptr>(new diskann::Index( compareMetric, shard_base_dim, 
shard_base_pts, false, false, false, build_pq_bytes > 0, build_pq_bytes, use_opq)); - _pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras); + if (!use_filters) { + _pvamanaIndex->build(shard_base_file.c_str(), shard_base_pts, paras); + } else { + diskann::extract_shard_labels(label_file, shard_ids_file, + shard_labels_file); + if (universal_label != "") { // indicates no universal label + LabelT unv_label_as_num = 0; + _pvamanaIndex->set_universal_label(unv_label_as_num); + } + _pvamanaIndex->build_filtered_index( + shard_base_file.c_str(), shard_labels_file, shard_base_pts, paras); + } _pvamanaIndex->save(shard_index_file.c_str()); + // copy universal label file from first shard to the final destination + // index, since all shards anyway share the universal label + if (p == 0) { + std::string shard_universal_label_file = + shard_index_file + "_universal_label.txt"; + if (universal_label != "") { + copy_file(shard_universal_label_file, + final_index_universal_label_file); + } + } + std::remove(shard_base_file.c_str()); } - diskann::cout << timer.elapsed_seconds_for_step("building indices on shards") << std::endl; + diskann::cout << timer.elapsed_seconds_for_step( + "building indices on shards") + << std::endl; timer.reset(); diskann::merge_shards(merged_index_prefix + "_subshard-", "_mem.index", merged_index_prefix + "_subshard-", "_ids_uint32.bin", - num_parts, R, mem_index_path, medoids_file); - diskann::cout << timer.elapsed_seconds_for_step("merging indices") << std::endl; + num_parts, R, mem_index_path, medoids_file, + use_filters, labels_to_medoids_file); + diskann::cout << timer.elapsed_seconds_for_step("merging indices") + << std::endl; // delete tempFiles for (int p = 0; p < num_parts; p++) { @@ -514,6 +761,8 @@ namespace diskann { merged_index_prefix + "_subshard-" + std::to_string(p) + ".bin"; std::string shard_id_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_ids_uint32.bin"; + std::string shard_labels_file = 
merged_index_prefix + "_subshard-" + + std::to_string(p) + "_labels.txt"; std::string shard_index_file = merged_index_prefix + "_subshard-" + std::to_string(p) + "_mem.index"; std::string shard_index_file_data = shard_index_file + ".data"; @@ -522,6 +771,17 @@ namespace diskann { std::remove(shard_id_file.c_str()); std::remove(shard_index_file.c_str()); std::remove(shard_index_file_data.c_str()); + if (use_filters) { + std::string shard_index_label_file = shard_index_file + "_labels.txt"; + std::string shard_index_univ_label_file = + shard_index_file + "_universal_label.txt"; + std::string shard_index_label_map_file = + shard_index_file + "_labels_to_medoids.txt"; + std::remove(shard_labels_file.c_str()); + std::remove(shard_index_label_file.c_str()); + std::remove(shard_index_label_map_file.c_str()); + std::remove(shard_index_univ_label_file.c_str()); + } } return 0; } @@ -530,11 +790,11 @@ namespace diskann { // optimizes the beamwidth to maximize QPS for a given L_search subject to // 99.9 latency not blowing up - template + template uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, T *tuning_sample, - _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, - uint32_t nthreads, uint32_t start_bw) { + std::unique_ptr> &pFlashIndex, + T *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, + uint32_t L, uint32_t nthreads, uint32_t start_bw) { uint32_t cur_bw = start_bw; double max_qps = 0; uint32_t best_bw = start_bw; @@ -799,10 +1059,13 @@ namespace diskann { << std::endl; } - template + template int build_disk_index(const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, - diskann::Metric compareMetric, bool use_opq) { + diskann::Metric compareMetric, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &universal_label, + const _u32 filter_threshold, const _u32 Lf) { std::stringstream parser; parser << std::string(indexBuildParameters); std::string 
cur_param; @@ -863,7 +1126,9 @@ namespace diskann { std::string base_file(dataFilePath); std::string data_file_to_use = base_file; + std::string labels_file_original = label_file; std::string index_prefix_path(indexFilePath); + std::string labels_file_to_use = index_prefix_path + "_label_formatted.txt"; std::string pq_pivots_path = index_prefix_path + "_pq_pivots.bin"; std::string pq_compressed_vectors_path = index_prefix_path + "_pq_compressed.bin"; @@ -871,6 +1136,19 @@ namespace diskann { std::string disk_index_path = index_prefix_path + "_disk.index"; std::string medoids_path = disk_index_path + "_medoids.bin"; std::string centroids_path = disk_index_path + "_centroids.bin"; + + std::string labels_to_medoids_path = + disk_index_path + "_labels_to_medoids.txt"; + std::string mem_labels_file = mem_index_path + "_labels.txt"; + std::string disk_labels_file = disk_index_path + "_labels.txt"; + std::string mem_univ_label_file = mem_index_path + "_universal_label.txt"; + std::string disk_univ_label_file = disk_index_path + "_universal_label.txt"; + std::string disk_labels_int_map_file = disk_index_path + "_labels_map.txt"; + std::string dummy_remap_file = + disk_index_path + + "_dummy_remap.txt"; // remap will be used if we break-up points of high + // label-density to create copies + std::string sample_base_prefix = index_prefix_path + "_sample"; // optional, used if disk index file must store pq data std::string disk_pq_pivots_path = @@ -930,6 +1208,27 @@ namespace diskann { auto s = std::chrono::high_resolution_clock::now(); + // If there is filter support, we break-up points which have too many labels + // into replica dummy points which evenly distribute the filters. 
The rest + // of index build happens on the augmented base and labels + std::string augmented_data_file, augmented_labels_file; + if (use_filters) { + convert_labels_string_to_int(labels_file_original, labels_file_to_use, + disk_labels_int_map_file, universal_label); + augmented_data_file = index_prefix_path + "_augmented_data.bin"; + augmented_labels_file = index_prefix_path + "_augmented_labels.txt"; + if (filter_threshold != 0) { + dummy_remap_file = index_prefix_path + "_dummy_remap.txt"; + breakup_dense_points( + data_file_to_use, labels_file_to_use, filter_threshold, + augmented_data_file, augmented_labels_file, + dummy_remap_file); // RKNOTE: This has large memory footprint, need + // to make this streaming + data_file_to_use = augmented_data_file; + labels_file_to_use = augmented_labels_file; + } + } + size_t points_num, dim; Timer timer; @@ -956,7 +1255,8 @@ namespace diskann { generate_quantized_data(data_file_to_use, pq_pivots_path, pq_compressed_vectors_path, compareMetric, p_val, num_pq_chunks, use_opq); - diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") << std::endl; + diskann::cout << timer.elapsed_seconds_for_step("generating quantized data") + << std::endl; // Gopal. Splitting diskann_dll into separate DLLs for search and build. // This code should only be available in the "build" DLL. 
@@ -966,10 +1266,11 @@ namespace diskann { #endif timer.reset(); - diskann::build_merged_vamana_index( + diskann::build_merged_vamana_index( data_file_to_use.c_str(), diskann::Metric::L2, L, R, p_val, indexing_ram_budget, mem_index_path, medoids_path, centroids_path, - build_pq_bytes, use_opq); + build_pq_bytes, use_opq, use_filters, labels_file_to_use, + labels_to_medoids_path, universal_label, Lf); diskann::cout << timer.elapsed_seconds_for_step( "building merged vamana index") << std::endl; @@ -997,6 +1298,17 @@ namespace diskann { double sample_sampling_rate = num_sample_points / points_num; gen_random_slice(data_file_to_use.c_str(), sample_base_prefix, sample_sampling_rate); + if (use_filters) { + copy_file(labels_file_to_use, disk_labels_file); + std::remove(mem_labels_file.c_str()); + if (universal_label != "") { + copy_file(mem_univ_label_file, disk_univ_label_file); + std::remove(mem_univ_label_file.c_str()); + } + std::remove(augmented_data_file.c_str()); + std::remove(augmented_labels_file.c_str()); + std::remove(labels_file_to_use.c_str()); + } std::remove(mem_index_path.c_str()); if (use_disk_pq) @@ -1041,48 +1353,123 @@ namespace diskann { uint64_t &warmup_num, uint64_t warmup_dim, uint64_t warmup_aligned_dim); #endif - template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, + int8_t *tuning_sample, _u64 tuning_sample_num, + _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, + uint32_t start_bw); + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, + uint8_t *tuning_sample, _u64 tuning_sample_num, + _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, + uint32_t start_bw); + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, + float *tuning_sample, _u64 tuning_sample_num, + _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t 
nthreads, + uint32_t start_bw); + + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, int8_t *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); - template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, uint8_t *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); - template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( - std::unique_ptr> &pFlashIndex, + template DISKANN_DLLEXPORT uint32_t optimize_beamwidth( + std::unique_ptr> &pFlashIndex, float *tuning_sample, _u64 tuning_sample_num, _u64 tuning_sample_aligned_dim, uint32_t L, uint32_t nthreads, uint32_t start_bw); - template DISKANN_DLLEXPORT int build_disk_index( + template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); + template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); + template DISKANN_DLLEXPORT int build_disk_index( const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq); - template DISKANN_DLLEXPORT int build_disk_index( + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); + // LabelT = uint16 + template 
DISKANN_DLLEXPORT int build_disk_index( const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq); - template DISKANN_DLLEXPORT int build_disk_index( + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); + template DISKANN_DLLEXPORT int build_disk_index( const char *dataFilePath, const char *indexFilePath, const char *indexBuildParameters, diskann::Metric compareMetric, - bool use_opq); + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); + template DISKANN_DLLEXPORT int build_disk_index( + const char *dataFilePath, const char *indexFilePath, + const char *indexBuildParameters, diskann::Metric compareMetric, + bool use_opq, bool use_filters, const std::string &label_file, + const std::string &universal_label, const _u32 filter_threshold, + const _u32 Lf); - template DISKANN_DLLEXPORT int build_merged_vamana_index( + template DISKANN_DLLEXPORT int build_merged_vamana_index( + std::string base_file, diskann::Metric compareMetric, unsigned L, + unsigned R, double sampling_rate, double ram_budget, + std::string mem_index_path, std::string medoids_path, + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf); + template DISKANN_DLLEXPORT int build_merged_vamana_index( + std::string base_file, diskann::Metric compareMetric, unsigned L, + unsigned R, double sampling_rate, double ram_budget, + std::string mem_index_path, std::string medoids_path, + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const 
_u32 Lf); + template DISKANN_DLLEXPORT int build_merged_vamana_index( + std::string base_file, diskann::Metric compareMetric, unsigned L, + unsigned R, double sampling_rate, double ram_budget, + std::string mem_index_path, std::string medoids_path, + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf); + // Label=16_t + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq); - template DISKANN_DLLEXPORT int build_merged_vamana_index( + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf); + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq); - template DISKANN_DLLEXPORT int build_merged_vamana_index( + std::string centroids_file, size_t build_pq_bytes, bool use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf); + template DISKANN_DLLEXPORT int build_merged_vamana_index( std::string base_file, diskann::Metric compareMetric, unsigned L, unsigned R, double sampling_rate, double ram_budget, std::string mem_index_path, std::string medoids_path, - std::string centroids_file, size_t build_pq_bytes, bool use_opq); + std::string centroids_file, size_t build_pq_bytes, bool 
use_opq, + bool use_filters, const std::string &label_file, + const std::string &labels_to_medoids_file, + const std::string &universal_label, const _u32 Lf); }; // namespace diskann diff --git a/src/index.cpp b/src/index.cpp index 22f6c8ab0..34d63a478 100644 --- a/src/index.cpp +++ b/src/index.cpp @@ -26,19 +26,20 @@ namespace diskann { // Initialize an index with metric m, load the data of type T with filename // (bin), and initialize max_points - template - Index::Index(Metric m, const size_t dim, const size_t max_points, - const bool dynamic_index, const Parameters &indexParams, - const Parameters &searchParams, const bool enable_tags, - const bool concurrent_consolidate, - const bool pq_dist_build, const size_t num_pq_chunks, - const bool use_opq) + template + Index::Index( + Metric m, const size_t dim, const size_t max_points, + const bool dynamic_index, const Parameters &indexParams, + const Parameters &searchParams, const bool enable_tags, + const bool concurrent_consolidate, const bool pq_dist_build, + const size_t num_pq_chunks, const bool use_opq) : Index(m, dim, max_points, dynamic_index, enable_tags, concurrent_consolidate) { _indexingQueueSize = indexParams.Get("L"); _indexingRange = indexParams.Get("R"); _indexingMaxC = indexParams.Get("C"); _indexingAlpha = indexParams.Get("alpha"); + _filterIndexingQueueSize = indexParams.Get("Lf"); uint32_t num_threads_srch = searchParams.Get("num_threads"); uint32_t num_threads_indx = indexParams.Get("num_threads"); @@ -49,12 +50,14 @@ namespace diskann { _indexingRange, _indexingMaxC, dim); } - template - Index::Index(Metric m, const size_t dim, const size_t max_points, - const bool dynamic_index, const bool enable_tags, - const bool concurrent_consolidate, - const bool pq_dist_build, const size_t num_pq_chunks, - const bool use_opq) + template + Index::Index(Metric m, const size_t dim, + const size_t max_points, + const bool dynamic_index, + const bool enable_tags, + const bool concurrent_consolidate, + const 
bool pq_dist_build, + const size_t num_pq_chunks, const bool use_opq) : _dist_metric(m), _dim(dim), _max_points(max_points), _dynamic_index(dynamic_index), _enable_tags(enable_tags), _indexingMaxC(DEFAULT_MAXC), _query_scratch(nullptr), @@ -130,8 +133,8 @@ namespace diskann { } } - template - Index::~Index() { + template + Index::~Index() { // Ensure that no other activity is happening before dtor() std::unique_lock ul(_update_lock); std::unique_lock cl(_consolidate_lock); @@ -154,15 +157,16 @@ namespace diskann { delete[] _opt_graph; } - ScratchStoreManager> manager(_query_scratch); - manager.destroy(); + if (!_query_scratch.empty()) { + ScratchStoreManager> manager(_query_scratch); + manager.destroy(); + } } - template - void Index::initialize_query_scratch(uint32_t num_threads, - uint32_t search_l, - uint32_t indexing_l, uint32_t r, - uint32_t maxc, size_t dim) { + template + void Index::initialize_query_scratch( + uint32_t num_threads, uint32_t search_l, uint32_t indexing_l, uint32_t r, + uint32_t maxc, size_t dim) { for (uint32_t i = 0; i < num_threads; i++) { auto scratch = new InMemQueryScratch(search_l, indexing_l, r, maxc, dim, _pq_dist); @@ -170,8 +174,8 @@ namespace diskann { } } - template - _u64 Index::save_tags(std::string tags_file) { + template + _u64 Index::save_tags(std::string tags_file) { if (!_enable_tags) { diskann::cout << "Not saving tags as they are not enabled." << std::endl; return 0; @@ -200,8 +204,8 @@ namespace diskann { return tag_bytes_written; } - template - _u64 Index::save_data(std::string data_file) { + template + _u64 Index::save_data(std::string data_file) { return save_data_in_base_dimensions(data_file, _data, _nd + _num_frozen_pts, _dim, _aligned_dim); } @@ -209,8 +213,8 @@ namespace diskann { // save the graph index on a file as an adjacency list. 
For each point, // first store the number of neighbors, and then the neighbor list (each as // 4 byte unsigned) - template - _u64 Index::save_graph(std::string graph_file) { + template + _u64 Index::save_graph(std::string graph_file) { std::ofstream out; open_file_to_write(out, graph_file); @@ -239,8 +243,8 @@ namespace diskann { return index_size; // number of bytes written } - template - _u64 Index::save_delete_list(const std::string &filename) { + template + _u64 Index::save_delete_list(const std::string &filename) { if (_delete_set->size() == 0) { return 0; } @@ -253,8 +257,9 @@ namespace diskann { return save_bin<_u32>(filename, delete_list.get(), _delete_set->size(), 1); } - template - void Index::save(const char *filename, bool compact_before_save) { + template + void Index::save(const char *filename, + bool compact_before_save) { diskann::Timer timer; std::unique_lock ul(_update_lock); @@ -274,6 +279,43 @@ namespace diskann { } if (!_save_as_one_file) { + if (_filtered_index) { + if (_label_to_medoid_id.size() > 0) { + std::ofstream medoid_writer(std::string(filename) + + "_labels_to_medoids.txt"); + if (medoid_writer.fail()) { + throw diskann::ANNException( + std::string("Failed to open file ") + filename, -1); + } + for (auto iter : _label_to_medoid_id) { + medoid_writer << iter.first << ", " << iter.second << std::endl; + } + medoid_writer.close(); + } + + if (_use_universal_label) { + std::ofstream universal_label_writer(std::string(filename) + + "_universal_label.txt"); + assert(universal_label_writer.is_open()); + universal_label_writer << _universal_label << std::endl; + universal_label_writer.close(); + } + + if (_pts_to_labels.size() > 0) { + std::ofstream label_writer(std::string(filename) + "_labels.txt"); + assert(label_writer.is_open()); + for (_u32 i = 0; i < _pts_to_labels.size(); i++) { + for (_u32 j = 0; j < (_pts_to_labels[i].size() - 1); j++) { + label_writer << _pts_to_labels[i][j] << ","; + } + if (_pts_to_labels[i].size() != 0) + 
label_writer << _pts_to_labels[i][_pts_to_labels[i].size() - 1]; + label_writer << std::endl; + } + label_writer.close(); + } + } + std::string graph_file = std::string(filename); std::string tags_file = std::string(filename) + ".tags"; std::string data_file = std::string(filename) + ".data"; @@ -304,11 +346,11 @@ namespace diskann { } #ifdef EXEC_ENV_OLS - template - size_t Index::load_tags(AlignedFileReader &reader) { + template + size_t Index::load_tags(AlignedFileReader &reader) { #else - template - size_t Index::load_tags(const std::string tag_filename) { + template + size_t Index::load_tags(const std::string tag_filename) { if (_enable_tags && !file_exists(tag_filename)) { diskann::cerr << "Tag file provided does not exist!" << std::endl; throw diskann::ANNException("Tag file provided does not exist!", -1, @@ -355,11 +397,11 @@ namespace diskann { return file_num_points; } - template + template #ifdef EXEC_ENV_OLS - size_t Index::load_data(AlignedFileReader &reader) { + size_t Index::load_data(AlignedFileReader &reader) { #else - size_t Index::load_data(std::string filename) { + size_t Index::load_data(std::string filename) { #endif size_t file_dim, file_num_points; #ifdef EXEC_ENV_OLS @@ -406,11 +448,11 @@ namespace diskann { } #ifdef EXEC_ENV_OLS - template - size_t Index::load_delete_set(AlignedFileReader &reader) { + template + size_t Index::load_delete_set(AlignedFileReader &reader) { #else - template - size_t Index::load_delete_set(const std::string &filename) { + template + size_t Index::load_delete_set(const std::string &filename) { #endif std::unique_ptr<_u32[]> delete_list; _u64 npts, ndim; @@ -429,13 +471,13 @@ namespace diskann { // load the index from file and update the max_degree, cur (navigating // node loc), and _final_graph (adjacency list) - template + template #ifdef EXEC_ENV_OLS - void Index::load(AlignedFileReader &reader, uint32_t num_threads, - uint32_t search_l) { + void Index::load(AlignedFileReader &reader, + uint32_t num_threads, 
uint32_t search_l) { #else - void Index::load(const char *filename, uint32_t num_threads, - uint32_t search_l) { + void Index::load(const char *filename, uint32_t num_threads, + uint32_t search_l) { #endif std::unique_lock ul(_update_lock); std::unique_lock cl(_consolidate_lock); @@ -444,7 +486,13 @@ namespace diskann { _has_built = true; - size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0; + size_t tags_file_num_pts = 0, graph_num_pts = 0, data_file_num_pts = 0, + label_num_pts = 0; + + std::string mem_index_file(filename); + std::string labels_file = mem_index_file + "_labels.txt"; + std::string labels_to_medoids = mem_index_file + "_labels_to_medoids.txt"; + std::string labels_map_file = mem_index_file + "_labels_map.txt"; if (!_save_as_one_file) { // For DLVS Store, we will not support saving the index in multiple files. @@ -484,6 +532,51 @@ namespace diskann { __LINE__); } + if (file_exists(labels_file)) { + _label_map = load_label_map(labels_map_file); + parse_label_file(labels_file, label_num_pts); + assert(label_num_pts == data_file_num_pts); + if (file_exists(labels_to_medoids)) { + std::ifstream medoid_stream(labels_to_medoids); + assert(label_num_pts == data_file_num_pts); + std::string line, token; + unsigned line_cnt = 0; + + _label_to_medoid_id.clear(); + + while (std::getline(medoid_stream, line)) { + std::istringstream iss(line); + _u32 cnt = 0; + _u32 medoid = 0; + LabelT label; + while (std::getline(iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), + token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), + token.end()); + LabelT token_as_num = std::stoul(token); + if (cnt == 0) + label = token_as_num; + else + medoid = token_as_num; + cnt++; + } + _label_to_medoid_id[label] = medoid; + line_cnt++; + } + } + + std::string universal_label_file(filename); + universal_label_file += "_universal_label.txt"; + if (file_exists(universal_label_file)) { + std::ifstream 
universal_label_reader(universal_label_file); + assert(label_num_pts == data_file_num_pts); + universal_label_reader >> _universal_label; + _use_universal_label = true; + universal_label_reader.close(); + } + } + _nd = data_file_num_pts - _num_frozen_pts; _empty_slots.clear(); _empty_slots.reserve(_max_points); @@ -511,14 +604,14 @@ namespace diskann { } #ifdef EXEC_ENV_OLS - template - size_t Index::load_graph(AlignedFileReader &reader, - size_t expected_num_points) { + template + size_t Index::load_graph(AlignedFileReader &reader, + size_t expected_num_points) { #else - template - size_t Index::load_graph(std::string filename, - size_t expected_num_points) { + template + size_t Index::load_graph(std::string filename, + size_t expected_num_points) { #endif size_t expected_file_size; _u64 file_frozen_pts; @@ -610,13 +703,13 @@ namespace diskann { } } #else - size_t bytes_read = vamana_metadata_size; size_t cc = 0; unsigned nodes_read = 0; while (bytes_read != expected_file_size) { unsigned k; in.read((char *) &k, sizeof(unsigned)); + if (k == 0) { diskann::cerr << "ERROR: Point found with no out-neighbors, point#" << nodes_read << std::endl; @@ -642,8 +735,8 @@ namespace diskann { return nodes_read; } - template - int Index::get_vector_by_tag(TagT &tag, T *vec) { + template + int Index::get_vector_by_tag(TagT &tag, T *vec) { std::shared_lock lock(_tag_lock); if (_tag_to_location.find(tag) == _tag_to_location.end()) { diskann::cout << "Tag " << tag << " does not exist" << std::endl; @@ -656,8 +749,8 @@ namespace diskann { return 0; } - template - unsigned Index::calculate_entry_point() { + template + unsigned Index::calculate_entry_point() { // TODO: need to compute medoid with PQ data too, for now sample at random if (_pq_dist) { size_t r = (size_t) rand() * (size_t) RAND_MAX + (size_t) rand(); @@ -706,11 +799,12 @@ namespace diskann { return min_idx; } - template - std::pair Index::iterate_to_fixed_point( + template + std::pair Index::iterate_to_fixed_point( const 
T *query, const unsigned Lsize, const std::vector &init_ids, InMemQueryScratch *scratch, - bool ret_frozen, bool search_invocation) { + bool use_filter, const std::vector &filter_label, bool ret_frozen, + bool search_invocation) { std::vector &expanded_nodes = scratch->pool(); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); best_L_nodes.reserve(Lsize); @@ -719,7 +813,7 @@ namespace diskann { boost::dynamic_bitset<> &inserted_into_pool_bs = scratch->inserted_into_pool_bs(); std::vector &id_scratch = scratch->id_scratch(); - std::vector &dist_scratch = scratch->dist_scratch(); + std::vector &dist_scratch = scratch->dist_scratch(); assert(id_scratch.size() == 0); T *aligned_query = scratch->aligned_query(); memcpy(aligned_query, query, _dim * sizeof(T)); @@ -799,6 +893,23 @@ namespace diskann { __FILE__, __LINE__); } + if (use_filter) { + std::vector common_filters; + auto &x = _pts_to_labels[id]; + std::set_intersection(filter_label.begin(), filter_label.end(), + x.begin(), x.end(), + std::back_inserter(common_filters)); + if (_use_universal_label) { + if (std::find(filter_label.begin(), filter_label.end(), + _universal_label) != filter_label.end() || + std::find(x.begin(), x.end(), _universal_label) != x.end()) + common_filters.emplace_back(_universal_label); + } + + if (common_filters.size() == 0) + continue; + } + if (is_not_visited(id)) { if (fast_iterate) { inserted_into_pool_bs[id] = 1; @@ -822,13 +933,24 @@ namespace diskann { uint32_t cmps = 0; while (best_L_nodes.has_unexpanded_node()) { - auto nbr = best_L_nodes.closest_unexpanded(); + auto nbr = best_L_nodes.closest_unexpanded(); auto n = nbr.id; + // Add node to expanded nodes to create pool for prune later if (!search_invocation && (n != _start || _num_frozen_pts == 0 || ret_frozen)) { - expanded_nodes.emplace_back(nbr); + if (!use_filter) { + expanded_nodes.emplace_back(nbr); + } else { // in filter based indexing, the same point might invoke + // multiple iterate_to_fixed_points, so need 
to be careful + // not to add the same item to pool multiple times. + if (std::find(expanded_nodes.begin(), expanded_nodes.end(), nbr) == + expanded_nodes.end()) { + expanded_nodes.emplace_back(nbr); + } + } } + // Find which of the nodes in des have not been visited before id_scratch.clear(); dist_scratch.clear(); @@ -837,6 +959,25 @@ namespace diskann { _locks[n].lock(); for (auto id : _final_graph[n]) { assert(id < _max_points + _num_frozen_pts); + + if (use_filter) { + // NOTE: NEED TO CHECK IF THIS CORRECT WITH NEW LOCKS. + std::vector common_filters; + auto &x = _pts_to_labels[id]; + std::set_intersection(filter_label.begin(), filter_label.end(), + x.begin(), x.end(), + std::back_inserter(common_filters)); + if (_use_universal_label) { + if (std::find(filter_label.begin(), filter_label.end(), + _universal_label) != filter_label.end() || + std::find(x.begin(), x.end(), _universal_label) != x.end()) + common_filters.emplace_back(_universal_label); + } + + if (common_filters.size() == 0) + continue; + } + if (is_not_visited(id)) { id_scratch.push_back(id); } @@ -871,7 +1012,7 @@ namespace diskann { sizeof(T) * _aligned_dim); } - dist_scratch.push_back( _distance->compare( + dist_scratch.push_back(_distance->compare( aligned_query, _data + _aligned_dim * (size_t) id, (unsigned) _aligned_dim)); } @@ -886,15 +1027,26 @@ namespace diskann { return std::make_pair(hops, cmps); } - template - void Index::search_for_point_and_prune( + template + void Index::search_for_point_and_prune( int location, _u32 Lindex, std::vector &pruned_list, - InMemQueryScratch *scratch) { + InMemQueryScratch *scratch, bool use_filter, _u32 filteredLindex) { std::vector init_ids; init_ids.emplace_back(_start); - iterate_to_fixed_point(_data + _aligned_dim * location, Lindex, init_ids, - scratch, true, false); + std::vector dummy; + + if (!use_filter) { + iterate_to_fixed_point(_data + _aligned_dim * location, Lindex, init_ids, + scratch, false, dummy, true, false); + } else { + 
std::vector<_u32> filter_specific_start_nodes; + for (auto &x : _pts_to_labels[location]) + filter_specific_start_nodes.emplace_back(_label_to_medoid_id[x]); + iterate_to_fixed_point(_data + _aligned_dim * location, filteredLindex, + filter_specific_start_nodes, scratch, true, + _pts_to_labels[location], true, false); + } auto &pool = scratch->pool(); @@ -916,8 +1068,8 @@ namespace diskann { assert(_final_graph.size() == _max_points + _num_frozen_pts); } - template - void Index::occlude_list( + template + void Index::occlude_list( const unsigned location, std::vector &pool, const float alpha, const unsigned degree, const unsigned maxc, std::vector &result, InMemQueryScratch *scratch, @@ -937,7 +1089,6 @@ namespace diskann { // Initialize occlude_factor to pool.size() many 0.0f values for correctness occlude_factor.insert(occlude_factor.end(), pool.size(), 0.0f); - float cur_alpha = 1; while (cur_alpha <= alpha && result.size() < degree) { // used for MIPS, where we store a value of eps in cur_alpha to @@ -951,7 +1102,8 @@ namespace diskann { } // Set the entry to float::max so that is not considered again occlude_factor[iter - pool.begin()] = std::numeric_limits::max(); - // Add the entry to the result if its not been deleted, and doesn't add a self loop + // Add the entry to the result if its not been deleted, and doesn't add + // a self loop if (delete_set_ptr == nullptr || delete_set_ptr->find(iter->id) == delete_set_ptr->end()) { if (iter->id != location) { @@ -964,6 +1116,23 @@ namespace diskann { auto t = iter2 - pool.begin(); if (occlude_factor[t] > alpha) continue; + + bool prune_allowed = true; + if (_filtered_index) { + _u32 a = iter->id; + _u32 b = iter2->id; + for (auto &x : _pts_to_labels[b]) { + if (std::find(_pts_to_labels[a].begin(), _pts_to_labels[a].end(), + x) == _pts_to_labels[a].end()) { + prune_allowed = false; + } + if (!prune_allowed) + break; + } + } + if (!prune_allowed) + continue; + float djk = _distance->compare(_data + _aligned_dim * 
(size_t) iter2->id, _data + _aligned_dim * (size_t) iter->id, @@ -987,30 +1156,30 @@ namespace diskann { } } - template - void Index::prune_neighbors(const unsigned location, - std::vector &pool, - std::vector &pruned_list, - InMemQueryScratch *scratch) { + template + void Index::prune_neighbors( + const unsigned location, std::vector &pool, + std::vector &pruned_list, InMemQueryScratch *scratch) { prune_neighbors(location, pool, _indexingRange, _indexingMaxC, _indexingAlpha, pruned_list, scratch); } - template - void Index::prune_neighbors( + template + void Index::prune_neighbors( const unsigned location, std::vector &pool, const _u32 range, const _u32 max_candidate_size, const float alpha, std::vector &pruned_list, InMemQueryScratch *scratch) { if (pool.size() == 0) { - throw diskann::ANNException("Pool passed to prune_neighbors is empty", -1, - __FUNCSIG__, __FILE__, __LINE__); + // if the pool is empty, behave like a noop + pruned_list.clear(); + return; } _max_observed_degree = (std::max)(_max_observed_degree, range); // If using _pq_build, over-write the PQ distances with actual distances if (_pq_dist) { - for (auto& ngh : pool) + for (auto &ngh : pool) ngh.distance = _distance->compare( _data + _aligned_dim * (size_t) ngh.id, _data + _aligned_dim * (size_t) location, (unsigned) _aligned_dim); @@ -1036,11 +1205,11 @@ namespace diskann { } } - template - void Index::inter_insert(unsigned n, - std::vector &pruned_list, - const _u32 range, - InMemQueryScratch *scratch) { + template + void Index::inter_insert(unsigned n, + std::vector &pruned_list, + const _u32 range, + InMemQueryScratch *scratch) { const auto &src_pool = pruned_list; assert(!src_pool.empty()); @@ -1098,15 +1267,15 @@ namespace diskann { } } - template - void Index::inter_insert(unsigned n, - std::vector &pruned_list, - InMemQueryScratch *scratch) { + template + void Index::inter_insert(unsigned n, + std::vector &pruned_list, + InMemQueryScratch *scratch) { inter_insert(n, pruned_list, 
_indexingRange, scratch); } - template - void Index::link(Parameters ¶meters) { + template + void Index::link(Parameters ¶meters) { unsigned num_threads = parameters.Get("num_threads"); if (num_threads != 0) omp_set_num_threads(num_threads); @@ -1117,6 +1286,7 @@ namespace diskann { omp_set_num_threads(num_threads); _indexingQueueSize = parameters.Get("L"); // Search list size + _filterIndexingQueueSize = parameters.Get("Lf"); _indexingRange = parameters.Get("R"); _indexingMaxC = parameters.Get("C"); _indexingAlpha = parameters.Get("alpha"); @@ -1158,8 +1328,15 @@ namespace diskann { auto scratch = manager.scratch_space(); std::vector pruned_list; - search_for_point_and_prune(node, _indexingQueueSize, pruned_list, - scratch); + if (_filtered_index) { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, + scratch, _filtered_index, + _filterIndexingQueueSize); + + } else { + search_for_point_and_prune(node, _indexingQueueSize, pruned_list, + scratch); + } { LockGuard guard(_locks[node]); _final_graph[node].reserve( @@ -1216,8 +1393,68 @@ namespace diskann { } } - template - void Index::set_start_point(T *data) { + template + void Index::prune_all_nbrs(const Parameters ¶meters) { + const unsigned range = parameters.Get("R"); + const unsigned maxc = parameters.Get("C"); + const float alpha = parameters.Get("alpha"); + _filtered_index = true; + + diskann::Timer timer; +#pragma omp parallel for + for (_s64 node = 0; node < (_s64) (_max_points + _num_frozen_pts); node++) { + if ((size_t) node < _nd || (size_t) node == _max_points) { + if (_final_graph[node].size() > range) { + tsl::robin_set dummy_visited(0); + std::vector dummy_pool(0); + std::vector new_out_neighbors; + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + for (auto cur_nbr : _final_graph[node]) { + if (dummy_visited.find(cur_nbr) == dummy_visited.end() && + cur_nbr != node) { + float dist = + _distance->compare(_data + _aligned_dim * (size_t) node, 
+ _data + _aligned_dim * (size_t) cur_nbr, + (unsigned) _aligned_dim); + dummy_pool.emplace_back(Neighbor(cur_nbr, dist)); + dummy_visited.insert(cur_nbr); + } + } + + prune_neighbors((_u32) node, dummy_pool, range, maxc, alpha, + new_out_neighbors, scratch); + _final_graph[node].clear(); + for (auto id : new_out_neighbors) + _final_graph[node].emplace_back(id); + } + } + } + + diskann::cout << "Prune time : " << timer.elapsed() / 1000 << "ms" + << std::endl; + size_t max = 0, min = 1 << 30, total = 0, cnt = 0; + for (size_t i = 0; i < (_nd + _num_frozen_pts); i++) { + std::vector pool = _final_graph[i]; + max = (std::max)(max, pool.size()); + min = (std::min)(min, pool.size()); + total += pool.size(); + if (pool.size() < 2) + cnt++; + } + if (min > max) + min = max; + if (_nd > 0) { + diskann::cout << "Index built with degree: max:" << max << " avg:" + << (float) total / (float) (_nd + _num_frozen_pts) + << " min:" << min << " count(deg<2):" << cnt << std::endl; + } + } + + template + void Index::set_start_point(T *data) { std::unique_lock ul(_update_lock); std::unique_lock tl(_tag_lock); if (_nd > 0) @@ -1229,8 +1466,8 @@ namespace diskann { diskann::cout << "Index start point set" << std::endl; } - template - void Index::set_start_point_at_random(T radius) { + template + void Index::set_start_point_at_random(T radius) { std::vector real_vec; std::random_device rd{}; std::mt19937 gen{rd()}; @@ -1250,8 +1487,8 @@ namespace diskann { set_start_point(start_vec.data()); } - template - void Index::build_with_data_populated( + template + void Index::build_with_data_populated( Parameters ¶meters, const std::vector &tags) { diskann::cout << "Starting index build with " << _nd << " points... 
" << std::endl; @@ -1307,10 +1544,11 @@ namespace diskann { _has_built = true; } - template - void Index::build(const T *data, const size_t num_points_to_load, - Parameters ¶meters, - const std::vector &tags) { + template + void Index::build(const T *data, + const size_t num_points_to_load, + Parameters ¶meters, + const std::vector &tags) { if (num_points_to_load == 0) { throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1322,7 +1560,7 @@ namespace diskann { } std::unique_lock ul(_update_lock); - + { std::unique_lock tl(_tag_lock); _nd = num_points_to_load; @@ -1339,11 +1577,11 @@ namespace diskann { build_with_data_populated(parameters, tags); } - template - void Index::build(const char *filename, - const size_t num_points_to_load, - Parameters ¶meters, - const std::vector &tags) { + template + void Index::build(const char *filename, + const size_t num_points_to_load, + Parameters ¶meters, + const std::vector &tags) { std::unique_lock ul(_update_lock); if (num_points_to_load == 0) throw ANNException("Do not call build with 0 points", -1, __FUNCSIG__, @@ -1450,10 +1688,11 @@ namespace diskann { build_with_data_populated(parameters, tags); } - template - void Index::build(const char *filename, - const size_t num_points_to_load, - Parameters ¶meters, const char *tag_filename) { + template + void Index::build(const char *filename, + const size_t num_points_to_load, + Parameters ¶meters, + const char *tag_filename) { std::vector tags; if (_enable_tags) { @@ -1490,13 +1729,164 @@ namespace diskann { build(filename, num_points_to_load, parameters, tags); } - template + template + std::unordered_map + Index::load_label_map(const std::string &labels_map_file) { + std::unordered_map string_to_int_mp; + std::ifstream map_reader(labels_map_file); + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) { + std::istringstream iss(line); + getline(iss, token, '\t'); + 
label_str = token; + getline(iss, token, '\t'); + token_as_num = std::stoul(token); + string_to_int_mp[label_str] = token_as_num; + } + return string_to_int_mp; + } + + template + LabelT Index::get_converted_label( + const std::string &raw_label) { + if (_label_map.find(raw_label) != _label_map.end()) { + return _label_map[raw_label]; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + exit(-1); + } + + template + void Index::parse_label_file(const std::string &label_file, + size_t &num_points) { + // Format of Label txt file: filters with comma separators + + std::ifstream infile(label_file); + if (infile.fail()) { + throw diskann::ANNException( + std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + unsigned line_cnt = 0; + + while (std::getline(infile, line)) { + line_cnt++; + } + _pts_to_labels.resize(line_cnt, std::vector()); + + infile.clear(); + infile.seekg(0, std::ios::beg); + line_cnt = 0; + + while (std::getline(infile, line)) { + std::istringstream iss(line); + std::vector lbls(0); + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = std::stoul(token); + lbls.push_back(token_as_num); + _labels.insert(token_as_num); + } + if (lbls.size() <= 0) { + diskann::cout << "No label found"; + exit(-1); + } + std::sort(lbls.begin(), lbls.end()); + _pts_to_labels[line_cnt] = lbls; + line_cnt++; + } + diskann::cout << "Identified " << _labels.size() << " distinct label(s)" + << std::endl; + } + + template + void Index::set_universal_label(const LabelT &label) { + _use_universal_label = true; + _universal_label = label; + } + + template + void 
Index::build_filtered_index( + const char *filename, const std::string &label_file, + const size_t num_points_to_load, Parameters ¶meters, + const std::vector &tags) { + _labels_file = label_file; + _filtered_index = true; + _label_to_medoid_id.clear(); + size_t num_points_labels = 0; + parse_label_file( + label_file, + num_points_labels); // determines medoid for each label and + // identifies the points to label mapping + + std::unordered_map> label_to_points; + + for (int lbl = 0; lbl < _labels.size(); lbl++) { + auto itr = _labels.begin(); + std::advance(itr, lbl); + auto &x = *itr; + + std::vector<_u32> labeled_points; + for (_u32 point_id = 0; point_id < num_points_to_load; point_id++) { + bool pt_has_lbl = std::find(_pts_to_labels[point_id].begin(), + _pts_to_labels[point_id].end(), + x) != _pts_to_labels[point_id].end(); + + bool pt_has_univ_lbl = + (_use_universal_label && + (std::find(_pts_to_labels[point_id].begin(), + _pts_to_labels[point_id].end(), + _universal_label) != _pts_to_labels[point_id].end())); + + if (pt_has_lbl || pt_has_univ_lbl) { + labeled_points.emplace_back(point_id); + } + } + label_to_points[x] = labeled_points; + } + + _u32 num_cands = 25; + for (auto itr = _labels.begin(); itr != _labels.end(); itr++) { + _u32 best_medoid_count = std::numeric_limits<_u32>::max(); + auto &curr_label = *itr; + _u32 best_medoid; + auto labeled_points = label_to_points[curr_label]; + for (_u32 cnd = 0; cnd < num_cands; cnd++) { + _u32 cur_cnd = labeled_points[rand() % labeled_points.size()]; + _u32 cur_cnt = std::numeric_limits<_u32>::max(); + if (_medoid_counts.find(cur_cnd) == _medoid_counts.end()) { + _medoid_counts[cur_cnd] = 0; + cur_cnt = 0; + } else { + cur_cnt = _medoid_counts[cur_cnd]; + } + if (cur_cnt < best_medoid_count) { + best_medoid_count = cur_cnt; + best_medoid = cur_cnd; + } + } + _label_to_medoid_id[curr_label] = best_medoid; + _medoid_counts[best_medoid]++; + } + + this->build(filename, num_points_to_load, parameters, tags); + } + 
+ template template - std::pair Index::search(const T *query, - const size_t K, - const unsigned L, - IdType *indices, - float *distances) { + std::pair Index::search( + const T *query, const size_t K, const unsigned L, IdType *indices, + float *distances) { if (K > (uint64_t) L) { throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1514,17 +1904,93 @@ namespace diskann { << scratch->get_L() << std::endl; } + std::vector dummy; std::vector init_ids; init_ids.push_back(_start); std::shared_lock lock(_update_lock); - auto retval = - iterate_to_fixed_point(query, L, init_ids, scratch, true, true); + + auto retval = iterate_to_fixed_point(query, L, init_ids, scratch, false, + dummy, true, true); + NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); size_t pos = 0; for (int i = 0; i < best_L_nodes.size(); ++i) { if (best_L_nodes[i].id < _max_points) { - // safe because Index uses uint32_t ids internally + // safe because Index uses uint32_t ids internally + // and IDType will be uint32_t or uint64_t + indices[pos] = (IdType) best_L_nodes[i].id; + if (distances != nullptr) { +#ifdef EXEC_ENV_OLS + // DLVS expects negative distances + distances[pos] = best_L_nodes[i].distance; +#else + distances[pos] = _dist_metric == diskann::Metric::INNER_PRODUCT + ? 
-1 * best_L_nodes[i].distance + : best_L_nodes[i].distance; +#endif + } + pos++; + } + if (pos == K) + break; + } + if (pos < K) { + diskann::cerr << "Found fewer than K elements for query" << std::endl; + } + + return retval; + } + + template + template + std::pair Index::search_with_filters( + const T *query, const LabelT &filter_label, const size_t K, + const unsigned L, IdType *indices, float *distances) { + if (K > (uint64_t) L) { + throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, + __FILE__, __LINE__); + } + + ScratchStoreManager> manager(_query_scratch); + auto scratch = manager.scratch_space(); + + if (L > scratch->get_L()) { + diskann::cout << "Attempting to expand query scratch_space. Was created " + << "with Lsize: " << scratch->get_L() + << " but search L is: " << L << std::endl; + scratch->resize_for_new_L(L); + diskann::cout << "Resize completed. New scratch->L is " + << scratch->get_L() << std::endl; + } + + std::vector filter_vec; + std::vector init_ids; + init_ids.push_back(_start); + std::shared_lock lock(_update_lock); + + if (_label_to_medoid_id.find(filter_label) != _label_to_medoid_id.end()) { + init_ids.emplace_back(_label_to_medoid_id[filter_label]); + } else { + diskann::cout + << "No filtered medoid found. exitting " + << std::endl; // RKNOTE: If universal label found start there + throw diskann::ANNException("No filtered medoid found. 
exitting ", -1); + } + filter_vec.emplace_back(filter_label); + + T *aligned_query = scratch->aligned_query(); + memcpy(aligned_query, query, _dim * sizeof(T)); + + auto retval = iterate_to_fixed_point(aligned_query, L, init_ids, scratch, + true, filter_vec, true, true); + + auto best_L_nodes = scratch->best_l_nodes(); + + size_t pos = 0; + for (int i = 0; i < best_L_nodes.size(); ++i) { + if (best_L_nodes[i].id < _max_points) { + // safe because Index uses uint32_t ids internally // and IDType will be uint32_t or uint64_t indices[pos] = (IdType) best_L_nodes[i].id; if (distances != nullptr) { @@ -1549,11 +2015,10 @@ namespace diskann { return retval; } - template - size_t Index::search_with_tags(const T *query, const uint64_t K, - const unsigned L, TagT *tags, - float *distances, - std::vector &res_vectors) { + template + size_t Index::search_with_tags( + const T *query, const uint64_t K, const unsigned L, TagT *tags, + float *distances, std::vector &res_vectors) { if (K > (uint64_t) L) { throw ANNException("Set L to a value of at least K", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1573,7 +2038,10 @@ namespace diskann { std::shared_lock ul(_update_lock); std::vector init_ids(1, _start); - iterate_to_fixed_point(query, L, init_ids, scratch, true, true); + std::vector dummy; + + iterate_to_fixed_point(query, L, init_ids, scratch, false, dummy, true, + true); NeighborPriorityQueue &best_L_nodes = scratch->best_l_nodes(); assert(best_L_nodes.size() <= L); @@ -1610,20 +2078,20 @@ namespace diskann { return pos; } - template - size_t Index::get_num_points() { + template + size_t Index::get_num_points() { std::shared_lock tl(_tag_lock); return _nd; } - template - size_t Index::get_max_points() { + template + size_t Index::get_max_points() { std::shared_lock tl(_tag_lock); return _max_points; } - template - int Index::generate_frozen_point() { + template + int Index::generate_frozen_point() { if (_num_frozen_pts == 0) return 0; @@ -1647,8 +2115,8 @@ namespace diskann { 
return 0; } - template - int Index::enable_delete() { + template + int Index::enable_delete() { assert(_enable_tags); if (!_enable_tags) { @@ -1669,8 +2137,8 @@ namespace diskann { return 0; } - template - inline void Index::process_delete( + template + inline void Index::process_delete( const tsl::robin_set &old_delete_set, size_t loc, const unsigned range, const unsigned maxc, const float alpha, InMemQueryScratch *scratch) { @@ -1731,11 +2199,10 @@ namespace diskann { } } } - - + // Returns number of live points left after consolidation - template - consolidation_report Index::consolidate_deletes( + template + consolidation_report Index::consolidate_deletes( const Parameters ¶ms) { if (!_enable_tags) throw diskann::ANNException("Point tag array not instantiated", -1, @@ -1846,8 +2313,8 @@ namespace diskann { num_calls_to_process_delete, duration); } - template - void Index::compact_frozen_point() { + template + void Index::compact_frozen_point() { if (_nd < _max_points) { if (_num_frozen_pts == 1) { // set new _start to be frozen point @@ -1874,8 +2341,8 @@ namespace diskann { } // Should be called after acquiring _update_lock - template - void Index::compact_data() { + template + void Index::compact_data() { if (!_dynamic_index) throw ANNException("Can not compact a non-dynamic index", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -1982,11 +2449,11 @@ namespace diskann { << timer.elapsed() / 1000000. << "s." 
<< std::endl; } - // + // // Caller must hold unique _tag_lock and _delete_lock before calling this // - template - int Index::reserve_location() { + template + int Index::reserve_location() { if (_nd >= _max_points) { return -1; } @@ -2009,8 +2476,8 @@ namespace diskann { return location; } - template - size_t Index::release_location(int location) { + template + size_t Index::release_location(int location) { if (_empty_slots.is_in_set(location)) throw ANNException( "Trying to release location, but location already in empty slots", -1, @@ -2021,8 +2488,8 @@ namespace diskann { return _nd; } - template - size_t Index::release_locations( + template + size_t Index::release_locations( const tsl::robin_set &locations) { for (auto location : locations) { if (_empty_slots.is_in_set(location)) @@ -2041,9 +2508,9 @@ namespace diskann { return _nd; } - template - void Index::reposition_point(unsigned old_location, - unsigned new_location) { + template + void Index::reposition_point(unsigned old_location, + unsigned new_location) { for (unsigned i = 0; i < _nd; i++) for (unsigned j = 0; j < _final_graph[i].size(); j++) if (_final_graph[i][j] == old_location) @@ -2062,8 +2529,8 @@ namespace diskann { sizeof(T) * _aligned_dim); } - template - void Index::reposition_frozen_point_to_end() { + template + void Index::reposition_frozen_point_to_end() { if (_num_frozen_pts == 0) return; @@ -2077,8 +2544,8 @@ namespace diskann { _start = (_u32) _max_points; } - template - void Index::resize(size_t new_max_points) { + template + void Index::resize(size_t new_max_points) { auto start = std::chrono::high_resolution_clock::now(); assert(_empty_slots.size() == 0); // should not resize if there are empty slots. 
@@ -2113,8 +2580,8 @@ namespace diskann { << std::endl; } - template - int Index::insert_point(const T *point, const TagT tag) { + template + int Index::insert_point(const T *point, const TagT tag) { assert(_has_built); if (tag == static_cast(0)) { throw diskann::ANNException( @@ -2191,8 +2658,13 @@ namespace diskann { ScratchStoreManager> manager(_query_scratch); auto scratch = manager.scratch_space(); std::vector pruned_list; - search_for_point_and_prune(location, _indexingQueueSize, pruned_list, - scratch); + if (_filtered_index) { + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, + scratch, true, _filterIndexingQueueSize); + } else { + search_for_point_and_prune(location, _indexingQueueSize, pruned_list, + scratch); + } { std::shared_lock tlock(_tag_lock, std::defer_lock); @@ -2221,8 +2693,8 @@ namespace diskann { return 0; } - template - int Index::lazy_delete(const TagT &tag) { + template + int Index::lazy_delete(const TagT &tag) { std::shared_lock ul(_update_lock); std::unique_lock tl(_tag_lock); std::unique_lock dl(_delete_lock); @@ -2242,9 +2714,9 @@ namespace diskann { return 0; } - template - void Index::lazy_delete(const std::vector &tags, - std::vector &failed_tags) { + template + void Index::lazy_delete(const std::vector &tags, + std::vector &failed_tags) { if (failed_tags.size() > 0) { throw ANNException("failed_tags should be passed as an empty list", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -2266,13 +2738,14 @@ namespace diskann { } } - template - bool Index::is_index_saved() { + template + bool Index::is_index_saved() { return _is_saved; } - template - void Index::get_active_tags(tsl::robin_set &active_tags) { + template + void Index::get_active_tags( + tsl::robin_set &active_tags) { active_tags.clear(); std::shared_lock tl(_tag_lock); for (auto iter : _tag_to_location) { @@ -2280,8 +2753,8 @@ namespace diskann { } } - template - void Index::print_status() { + template + void Index::print_status() { std::shared_lock 
ul(_update_lock); std::shared_lock cl(_consolidate_lock); std::shared_lock tl(_tag_lock); @@ -2304,8 +2777,8 @@ namespace diskann { << std::endl; } - template - void Index::count_nodes_at_bfs_levels() { + template + void Index::count_nodes_at_bfs_levels() { std::unique_lock ul(_update_lock); boost::dynamic_bitset<> visited(_max_points + _num_frozen_pts); @@ -2341,8 +2814,9 @@ namespace diskann { delete[] bfs_sets; } - template - void Index::optimize_index_layout() { // use after build or load + template + void + Index::optimize_index_layout() { // use after build or load if (_dynamic_index) { throw diskann::ANNException( "Optimize_index_layout not implemented for dyanmic indices", -1, @@ -2372,10 +2846,10 @@ namespace diskann { _final_graph.shrink_to_fit(); } - template - void Index::search_with_optimized_layout(const T *query, size_t K, - size_t L, - unsigned *indices) { + template + void Index::search_with_optimized_layout(const T *query, + size_t K, size_t L, + unsigned *indices) { DistanceFastL2 *dist_fast = (DistanceFastL2 *) _distance; NeighborPriorityQueue retset(L); @@ -2454,75 +2928,282 @@ namespace diskann { } /* Internals of the library */ - template - const float Index::INDEX_GROWTH_FACTOR = 1.5f; + template + const float Index::INDEX_GROWTH_FACTOR = 1.5f; // EXPORTS - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; - template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template 
DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + // Label with short int 2 byte + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + template DISKANN_DLLEXPORT class Index; + + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + // TagT==uint32_t + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template 
DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const float *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const float *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const uint8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const uint8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const int8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const int8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + // TagT==uint32_t + template DISKANN_DLLEXPORT 
std::pair + Index::search_with_filters( + const float *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const float *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const uint8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const uint8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const int8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search_with_filters( + const int8_t *query, const uint32_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); + + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t 
*query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + // TagT==uint32_t + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const float *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const uint8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint64_t *indices, + float *distances); + template DISKANN_DLLEXPORT std::pair + Index::search(const int8_t *query, + const size_t K, + const unsigned L, + uint32_t *indices, + float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const float *query, const size_t K, - const unsigned L, uint64_t *indices, - float *distances); + Index::search_with_filters( + const float *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const float *query, const size_t K, - const unsigned L, uint32_t *indices, - float *distances); + Index::search_with_filters( + const float *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const uint8_t *query, - const size_t K, const unsigned L, - uint64_t *indices, - float *distances); + Index::search_with_filters( + const uint8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT 
std::pair - Index::search(const uint8_t *query, - const size_t K, const unsigned L, - uint32_t *indices, - float *distances); + Index::search_with_filters( + const uint8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const int8_t *query, const size_t K, - const unsigned L, uint64_t *indices, - float *distances); + Index::search_with_filters( + const int8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const int8_t *query, const size_t K, - const unsigned L, uint32_t *indices, - float *distances); + Index::search_with_filters( + const int8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); // TagT==uint32_t template DISKANN_DLLEXPORT std::pair - Index::search(const float *query, const size_t K, - const unsigned L, uint64_t *indices, - float *distances); + Index::search_with_filters( + const float *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const float *query, const size_t K, - const unsigned L, uint32_t *indices, - float *distances); + Index::search_with_filters( + const float *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const uint8_t *query, - const size_t K, const unsigned L, - uint64_t *indices, - float *distances); + Index::search_with_filters( + const uint8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const uint8_t *query, - const size_t K, const unsigned L, - uint32_t *indices, - float *distances); + 
Index::search_with_filters( + const uint8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const int8_t *query, const size_t K, - const unsigned L, uint64_t *indices, - float *distances); + Index::search_with_filters( + const int8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint64_t *indices, float *distances); template DISKANN_DLLEXPORT std::pair - Index::search(const int8_t *query, const size_t K, - const unsigned L, uint32_t *indices, - float *distances); + Index::search_with_filters( + const int8_t *query, const uint16_t &filter_label, const size_t K, + const unsigned L, uint32_t *indices, float *distances); } // namespace diskann diff --git a/src/partition.cpp b/src/partition.cpp index be6f49a0e..aa2bf1f91 100644 --- a/src/partition.cpp +++ b/src/partition.cpp @@ -29,7 +29,7 @@ // block size for reading/ processing large files and matrices in blocks #define BLOCK_SIZE 5000000 -//#define SAVE_INFLATED_PQ true +// #define SAVE_INFLATED_PQ true template void gen_random_slice(const std::string base_file, @@ -591,9 +591,9 @@ int partition_with_ram_budget(const std::string data_file, train_dim, k_base, cluster_sizes); for (auto &p : cluster_sizes) { - p = (_u64) (p / - sampling_rate); // to account for the fact that p is the size - // of the shard over the testing sample. + // to account for the fact that p is the size of the shard over the + // testing sample. 
+ p = (_u64) (p / sampling_rate); double cur_shard_ram_estimate = diskann::estimate_ram_usage(p, train_dim, sizeof(T), graph_degree); diff --git a/src/pq_flash_index.cpp b/src/pq_flash_index.cpp index e648d75d6..f3a93144a 100644 --- a/src/pq_flash_index.cpp +++ b/src/pq_flash_index.cpp @@ -41,9 +41,9 @@ namespace diskann { - template - PQFlashIndex::PQFlashIndex(std::shared_ptr &fileReader, - diskann::Metric m) + template + PQFlashIndex::PQFlashIndex( + std::shared_ptr &fileReader, diskann::Metric m) : reader(fileReader), metric(m) { if (m == diskann::Metric::COSINE || m == diskann::Metric::INNER_PRODUCT) { if (std::is_floating_point::value) { @@ -63,8 +63,8 @@ namespace diskann { this->dist_cmp_float.reset(diskann::get_distance_function(metric)); } - template - PQFlashIndex::~PQFlashIndex() { + template + PQFlashIndex::~PQFlashIndex() { #ifndef EXEC_ENV_OLS if (data != nullptr) { delete[] data; @@ -86,10 +86,18 @@ namespace diskann { this->reader->deregister_all_threads(); reader->close(); } + if (_pts_to_label_offsets != nullptr) { + delete[] _pts_to_label_offsets; + } + + if (_pts_to_labels != nullptr) { + delete[] _pts_to_labels; + } } - template - void PQFlashIndex::setup_thread_data(_u64 nthreads, _u64 visited_reserve) { + template + void PQFlashIndex::setup_thread_data(_u64 nthreads, + _u64 visited_reserve) { diskann::cout << "Setting up thread-specific contexts for nthreads: " << nthreads << std::endl; // omp parallel for to generate unique thread IDs @@ -107,8 +115,9 @@ namespace diskann { load_flag = true; } - template - void PQFlashIndex::load_cache_list(std::vector &node_list) { + template + void PQFlashIndex::load_cache_list( + std::vector &node_list) { diskann::cout << "Loading the cache list into memory.." 
<< std::flush; _u64 num_cached_nodes = node_list.size(); @@ -180,14 +189,14 @@ namespace diskann { } #ifdef EXEC_ENV_OLS - template - void PQFlashIndex::generate_cache_list_from_sample_queries( + template + void PQFlashIndex::generate_cache_list_from_sample_queries( MemoryMappedFiles &files, std::string sample_bin, _u64 l_search, _u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads, std::vector &node_list) { #else - template - void PQFlashIndex::generate_cache_list_from_sample_queries( + template + void PQFlashIndex::generate_cache_list_from_sample_queries( std::string sample_bin, _u64 l_search, _u64 beamwidth, _u64 num_nodes_to_cache, uint32_t nthreads, std::vector &node_list) { @@ -245,10 +254,10 @@ namespace diskann { diskann::aligned_free(samples); } - template - void PQFlashIndex::cache_bfs_levels(_u64 num_nodes_to_cache, - std::vector &node_list, - const bool shuffle) { + template + void PQFlashIndex::cache_bfs_levels( + _u64 num_nodes_to_cache, std::vector &node_list, + const bool shuffle) { std::random_device rng; std::mt19937 urng(rng()); @@ -379,8 +388,8 @@ namespace diskann { diskann::cout << "done" << std::endl; } - template - void PQFlashIndex::use_medoids_data_as_centroids() { + template + void PQFlashIndex::use_medoids_data_as_centroids() { if (centroid_data != nullptr) aligned_free(centroid_data); alloc_aligned(((void **) ¢roid_data), @@ -425,13 +434,170 @@ namespace diskann { } } + template + inline int32_t PQFlashIndex::get_filter_number( + const LabelT &filter_label) { + int idx = -1; + for (_u32 i = 0; i < _filter_list.size(); i++) { + if (_filter_list[i] == filter_label) { + idx = i; + break; + } + } + return idx; + } + + template + std::unordered_map + PQFlashIndex::load_label_map(const std::string &labels_map_file) { + std::unordered_map string_to_int_mp; + std::ifstream map_reader(labels_map_file); + std::string line, token; + LabelT token_as_num; + std::string label_str; + while (std::getline(map_reader, line)) { + std::istringstream 
iss(line); + getline(iss, token, '\t'); + label_str = token; + getline(iss, token, '\t'); + token_as_num = std::stoul(token); + string_to_int_mp[label_str] = token_as_num; + } + return string_to_int_mp; + } + + template + LabelT PQFlashIndex::get_converted_label( + const std::string &filter_label) { + if (_label_map.find(filter_label) != _label_map.end()) { + return _label_map[filter_label]; + } + std::stringstream stream; + stream << "Unable to find label in the Label Map"; + diskann::cerr << stream.str() << std::endl; + throw diskann::ANNException(stream.str(), -1, __FUNCSIG__, __FILE__, + __LINE__); + exit(-1); + } + + template + void PQFlashIndex::get_label_file_metadata( + std::string map_file, _u32 &num_pts, _u32 &num_total_labels) { + std::ifstream infile(map_file); + std::string line, token; + num_pts = 0; + num_total_labels = 0; + + while (std::getline(infile, line)) { + std::istringstream iss(line); + while (getline(iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + num_total_labels++; + } + num_pts++; + } + + diskann::cout << "Labels file metadata: num_points: " << num_pts + << ", #total_labels: " << num_total_labels << std::endl; + infile.close(); + } + + template + inline bool PQFlashIndex::point_has_label(_u32 point_id, + _u32 label_id) { + _u32 start_vec = _pts_to_label_offsets[point_id]; + _u32 num_lbls = _pts_to_labels[start_vec]; + bool ret_val = false; + for (_u32 i = 0; i < num_lbls; i++) { + if (_pts_to_labels[start_vec + 1 + i] == label_id) { + ret_val = true; + break; + } + } + return ret_val; + } + + template + void PQFlashIndex::parse_label_file(const std::string &label_file, + size_t &num_points_labels) { + std::ifstream infile(label_file); + if (infile.fail()) { + throw diskann::ANNException( + std::string("Failed to open file ") + label_file, -1); + } + + std::string line, token; + _u32 line_cnt = 0; + + _u32 
num_pts_in_label_file; + _u32 num_total_labels; + get_label_file_metadata(label_file, num_pts_in_label_file, + num_total_labels); + + _pts_to_label_offsets = new _u32[num_pts_in_label_file]; + _pts_to_labels = new _u32[num_pts_in_label_file + num_total_labels]; + _u32 counter = 0; + + while (std::getline(infile, line)) { + std::istringstream iss(line); + std::vector<_u32> lbls(0); + + _pts_to_label_offsets[line_cnt] = counter; + _u32 &num_lbls_in_cur_pt = _pts_to_labels[counter]; + num_lbls_in_cur_pt = 0; + counter++; + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + LabelT token_as_num = std::stoul(token); + if (_labels.find(token_as_num) == _labels.end()) { + _filter_list.emplace_back(token_as_num); + } + int32_t filter_num = get_filter_number(token_as_num); + if (filter_num == -1) { + diskann::cout << "Error!! " << std::endl; + exit(-1); + } + _pts_to_labels[counter++] = filter_num; + num_lbls_in_cur_pt++; + _labels.insert(token_as_num); + } + + if (num_lbls_in_cur_pt == 0) { + diskann::cout << "No label found for point " << line_cnt << std::endl; + exit(-1); + } + line_cnt++; + } + infile.close(); + num_points_labels = line_cnt; + } + + template + void PQFlashIndex::set_universal_label(const LabelT &label) { + int32_t temp_filter_num = get_filter_number(label); + if (temp_filter_num == -1) { + diskann::cout << "Error, could not find universal label. Exitting." 
+ << std::endl; + exit(-1); + } else { + _use_universal_label = true; + _universal_filter_num = (_u32) temp_filter_num; + } + } + #ifdef EXEC_ENV_OLS - template - int PQFlashIndex::load(MemoryMappedFiles &files, uint32_t num_threads, - const char *index_prefix) { + template + int PQFlashIndex::load(MemoryMappedFiles &files, + uint32_t num_threads, + const char *index_prefix) { #else - template - int PQFlashIndex::load(uint32_t num_threads, const char *index_prefix) { + template + int PQFlashIndex::load(uint32_t num_threads, + const char *index_prefix) { #endif std::string pq_table_bin = std::string(index_prefix) + "_pq_pivots.bin"; std::string pq_compressed_vectors = @@ -449,14 +615,14 @@ namespace diskann { } #ifdef EXEC_ENV_OLS - template - int PQFlashIndex::load_from_separate_paths( + template + int PQFlashIndex::load_from_separate_paths( diskann::MemoryMappedFiles &files, uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, const char *compressed_filepath) { #else - template - int PQFlashIndex::load_from_separate_paths( + template + int PQFlashIndex::load_from_separate_paths( uint32_t num_threads, const char *index_filepath, const char *pivots_filepath, const char *compressed_filepath) { #endif @@ -467,6 +633,15 @@ namespace diskann { std::string centroids_file = std::string(disk_index_file) + "_centroids.bin"; + std::string labels_file = std ::string(disk_index_file) + "_labels.txt"; + std::string labels_to_medoids = + std ::string(disk_index_file) + "_labels_to_medoids.txt"; + std::string dummy_map_file = + std ::string(disk_index_file) + "_dummy_map.txt"; + std::string labels_map_file = + std ::string(disk_index_file) + "_labels_map.txt"; + size_t num_pts_in_label_file = 0; + size_t pq_file_dim, pq_file_num_centroids; #ifdef EXEC_ENV_OLS get_bin_metadata(files, pq_table_bin, pq_file_num_centroids, pq_file_dim, @@ -503,6 +678,77 @@ namespace diskann { this->num_points = npts_u64; this->n_chunks = nchunks_u64; + if 
(file_exists(labels_file)) { + parse_label_file(labels_file, num_pts_in_label_file); + assert(num_pts_in_label_file == this->num_points); + _label_map = load_label_map(labels_map_file); + if (file_exists(labels_to_medoids)) { + std::ifstream medoid_stream(labels_to_medoids); + assert(medoid_stream.is_open()); + std::string line, token; + + _filter_to_medoid_id.clear(); + try { + while (std::getline(medoid_stream, line)) { + std::istringstream iss(line); + _u32 cnt = 0; + _u32 medoid; + LabelT label; + while (std::getline(iss, token, ',')) { + if (cnt == 0) + label = std::stoul(token); + else + medoid = (_u32) stoul(token); + cnt++; + } + _filter_to_medoid_id[label] = medoid; + } + } catch (std::system_error &e) { + throw FileException(labels_to_medoids, e, __FUNCSIG__, __FILE__, + __LINE__); + } + } + std::string univ_label_file = + std ::string(disk_index_file) + "_universal_label.txt"; + if (file_exists(univ_label_file)) { + std::ifstream universal_label_reader(univ_label_file); + assert(universal_label_reader.is_open()); + std::string univ_label; + universal_label_reader >> univ_label; + universal_label_reader.close(); + LabelT label_as_num = std::stoul(univ_label); + set_universal_label(label_as_num); + } + if (file_exists(dummy_map_file)) { + std::ifstream dummy_map_stream(dummy_map_file); + assert(dummy_map_stream.is_open()); + std::string line, token; + + while (std::getline(dummy_map_stream, line)) { + std::istringstream iss(line); + _u32 cnt = 0; + _u32 dummy_id; + _u32 real_id; + while (std::getline(iss, token, ',')) { + if (cnt == 0) + dummy_id = (_u32) stoul(token); + else + real_id = (_u32) stoul(token); + cnt++; + } + _dummy_pts.insert(dummy_id); + _has_dummy_pts.insert(real_id); + _dummy_to_real_map[dummy_id] = real_id; + + if (_real_to_dummy_map.find(real_id) == _real_to_dummy_map.end()) + _real_to_dummy_map[real_id] = std::vector<_u32>(); + + _real_to_dummy_map[real_id].emplace_back(dummy_id); + } + dummy_map_stream.close(); + diskann::cout << 
"Loaded dummy map" << std::endl; + } + } #ifdef EXEC_ENV_OLS pq_table.load_pq_centroid_bin(files, pq_table_bin.c_str(), nchunks_u64); @@ -541,8 +787,8 @@ namespace diskann { disk_bytes_per_point = disk_pq_n_chunks * sizeof(_u8); // revising disk_bytes_per_point since DISK PQ is used. - std::cout << "Disk index uses PQ data compressed down to " - << disk_pq_n_chunks << " bytes per point." << std::endl; + diskann::cout << "Disk index uses PQ data compressed down to " + << disk_pq_n_chunks << " bytes per point." << std::endl; } // read index metadata @@ -704,8 +950,8 @@ namespace diskann { float *norm_val; diskann::load_bin(norm_file, norm_val, dumr, dumc); this->max_base_norm = norm_val[0]; - std::cout << "Setting re-scaling factor of base vectors to " - << this->max_base_norm << std::endl; + diskann::cout << "Setting re-scaling factor of base vectors to " + << this->max_base_norm << std::endl; delete[] norm_val; } diskann::cout << "done.." << std::endl; @@ -742,23 +988,58 @@ namespace diskann { } #endif - template - void PQFlashIndex::cached_beam_search(const T *query1, const _u64 k_search, - const _u64 l_search, _u64 *indices, - float *distances, - const _u64 beam_width, - const bool use_reorder_data, - QueryStats *stats) { + template + void PQFlashIndex::cached_beam_search( + const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices, + float *distances, const _u64 beam_width, const bool use_reorder_data, + QueryStats *stats) { cached_beam_search(query1, k_search, l_search, indices, distances, beam_width, std::numeric_limits<_u32>::max(), use_reorder_data, stats); } - template - void PQFlashIndex::cached_beam_search( + template + void PQFlashIndex::cached_beam_search( + const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices, + float *distances, const _u64 beam_width, const bool use_filter, + const LabelT &filter_label, const bool use_reorder_data, + QueryStats *stats) { + cached_beam_search(query1, k_search, l_search, indices, 
distances, + beam_width, use_filter, filter_label, + std::numeric_limits<_u32>::max(), use_reorder_data, + stats); + } + + template + void PQFlashIndex::cached_beam_search( const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices, float *distances, const _u64 beam_width, const _u32 io_limit, const bool use_reorder_data, QueryStats *stats) { + LabelT dummy_filter = 0; + cached_beam_search(query1, k_search, l_search, indices, distances, + beam_width, false, dummy_filter, + io_limit, use_reorder_data,  // forward the caller's IO cap; do not reset it to "unlimited" + stats); + } + + template + void PQFlashIndex::cached_beam_search( + const T *query1, const _u64 k_search, const _u64 l_search, _u64 *indices, + float *distances, const _u64 beam_width, const bool use_filter, + const LabelT &filter_label, const _u32 io_limit, + const bool use_reorder_data, QueryStats *stats) { + int32_t filter_num = 0; + if (use_filter) { + filter_num = get_filter_number(filter_label); + if (filter_num < 0) { + if (!_use_universal_label) { + return; + } else { + filter_num = _universal_filter_num; + } + } + } + if (beam_width > MAX_N_SECTOR_READS) throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS", -1, __FUNCSIG__, __FILE__, __LINE__); @@ -838,14 +1119,22 @@ namespace diskann { _u32 best_medoid = 0; float best_dist = (std::numeric_limits::max)(); - for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) { - float cur_expanded_dist = dist_cmp_float->compare( - query_float, centroid_data + aligned_dim * cur_m, - (unsigned) aligned_dim); - if (cur_expanded_dist < best_dist) { - best_medoid = medoids[cur_m]; - best_dist = cur_expanded_dist; + if (!use_filter) { + for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) { + float cur_expanded_dist = dist_cmp_float->compare( + query_float, centroid_data + aligned_dim * cur_m, + (unsigned) aligned_dim); + if (cur_expanded_dist < best_dist) { + best_medoid = medoids[cur_m]; + best_dist = cur_expanded_dist; + } } + } else if
(_filter_to_medoid_id.find(filter_label) != + _filter_to_medoid_id.end()) { + best_medoid = _filter_to_medoid_id[filter_label]; + } else { + throw ANNException("Cannot find medoid for specified filter.", -1, + __FUNCSIG__, __FILE__, __LINE__); } compute_dists(&best_medoid, 1, dist_scratch); @@ -963,6 +1252,12 @@ namespace diskann { for (_u64 m = 0; m < nnbrs; ++m) { unsigned id = node_nbrs[m]; if (visited.insert(id).second) { + if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + continue; + + if (use_filter && !point_has_label(id, filter_num) && + !point_has_label(id, _universal_filter_num)) + continue; cmps++; float dist = dist_scratch[m]; Neighbor nn(id, dist); @@ -1024,6 +1319,12 @@ namespace diskann { for (_u64 m = 0; m < nnbrs; ++m) { unsigned id = node_nbrs[m]; if (visited.insert(id).second) { + if (!use_filter && _dummy_pts.find(id) != _dummy_pts.end()) + continue; + + if (use_filter && !point_has_label(id, filter_num) && + !point_has_label(id, _universal_filter_num)) + continue; cmps++; float dist = dist_scratch[m]; if (stats != nullptr) { @@ -1096,6 +1397,11 @@ namespace diskann { // copy k_search values for (_u64 i = 0; i < k_search; i++) { indices[i] = full_retset[i].id; + + if (_dummy_pts.find(indices[i]) != _dummy_pts.end()) { + indices[i] = _dummy_to_real_map[indices[i]]; + } + if (distances != nullptr) { distances[i] = full_retset[i].distance; if (metric == diskann::Metric::INNER_PRODUCT) { @@ -1121,14 +1427,12 @@ namespace diskann { // range search returns results of all neighbors within distance of range. // indices and distances need to be pre-allocated of size l_search and the // return value is the number of matching hits. 
- template - _u32 PQFlashIndex::range_search(const T *query1, const double range, - const _u64 min_l_search, - const _u64 max_l_search, - std::vector<_u64> &indices, - std::vector &distances, - const _u64 min_beam_width, - QueryStats *stats) { + template + _u32 PQFlashIndex::range_search( + const T *query1, const double range, const _u64 min_l_search, + const _u64 max_l_search, std::vector<_u64> &indices, + std::vector &distances, const _u64 min_beam_width, + QueryStats *stats) { _u32 res_count = 0; bool stop_flag = false; @@ -1162,23 +1466,23 @@ namespace diskann { return res_count; } - template - _u64 PQFlashIndex::get_data_dim() { + template + _u64 PQFlashIndex::get_data_dim() { return data_dim; } - template - diskann::Metric PQFlashIndex::get_metric() { + template + diskann::Metric PQFlashIndex::get_metric() { return this->metric; } #ifdef EXEC_ENV_OLS - template - char *PQFlashIndex::getHeaderBytes() { + template + char *PQFlashIndex::getHeaderBytes() { IOContext &ctx = reader->get_ctx(); AlignedRead readReq; - readReq.buf = new char[PQFlashIndex::HEADER_SIZE]; - readReq.len = PQFlashIndex::HEADER_SIZE; + readReq.buf = new char[PQFlashIndex::HEADER_SIZE]; + readReq.len = PQFlashIndex::HEADER_SIZE; readReq.offset = 0; std::vector readReqs; @@ -1194,5 +1498,8 @@ namespace diskann { template class PQFlashIndex<_u8>; template class PQFlashIndex<_s8>; template class PQFlashIndex; + template class PQFlashIndex<_u8, uint16_t>; + template class PQFlashIndex<_s8, uint16_t>; + template class PQFlashIndex; } // namespace diskann diff --git a/src/restapi/search_wrapper.cpp b/src/restapi/search_wrapper.cpp index 624a3bbb5..9b11a1fbb 100644 --- a/src/restapi/search_wrapper.cpp +++ b/src/restapi/search_wrapper.cpp @@ -87,15 +87,11 @@ namespace diskann { } template - InMemorySearch::InMemorySearch( - const std::string& baseFile, - const std::string& indexFile, - const std::string& tagsFile, - Metric m, - uint32_t num_threads, - uint32_t search_l - ): BaseSearch(tagsFile) { - 
+ InMemorySearch::InMemorySearch(const std::string& baseFile, + const std::string& indexFile, + const std::string& tagsFile, Metric m, + uint32_t num_threads, uint32_t search_l) + : BaseSearch(tagsFile) { size_t dimensions, total_points = 0; diskann::get_bin_metadata(baseFile, total_points, dimensions); _index = std::unique_ptr>( @@ -136,9 +132,9 @@ namespace diskann { } template - PQFlashSearch::PQFlashSearch(const std::string& indexPrefix, - const unsigned num_nodes_to_cache, - const unsigned num_threads, + PQFlashSearch::PQFlashSearch(const std::string& indexPrefix, + const unsigned num_nodes_to_cache, + const unsigned num_threads, const std::string& tagsFile, Metric m) : BaseSearch(tagsFile) { #ifdef _WINDOWS @@ -162,7 +158,8 @@ namespace diskann { int res = _index->load(num_threads, index_prefix_path.c_str()); if (res != 0) { - std::cerr << "Unable to load index. Status code: " << res << "." << std::endl; + std::cerr << "Unable to load index. Status code: " << res << "." + << std::endl; } std::vector node_list; @@ -174,13 +171,13 @@ namespace diskann { } template - SearchResult PQFlashSearch::search(const T* query, + SearchResult PQFlashSearch::search(const T* query, const unsigned int dimensions, const unsigned int K, const unsigned int Ls) { - _u64* indices_u64 = new _u64[K]; + _u64* indices_u64 = new _u64[K]; unsigned* indices = new unsigned[K]; - float* distances = new float[K]; + float* distances = new float[K]; auto startTime = std::chrono::high_resolution_clock::now(); _index->cached_beam_search(query, K, Ls, indices_u64, distances, DEFAULT_W); diff --git a/src/restapi/server.cpp b/src/restapi/server.cpp index aec949946..3ba2b6c78 100644 --- a/src/restapi/server.cpp +++ b/src/restapi/server.cpp @@ -62,7 +62,7 @@ namespace diskann { std::vector pos(numsearchers, 0); for (size_t k = 0; k < K; ++k) { - float best_distance = std::numeric_limits::max(); + float best_distance = std::numeric_limits::max(); unsigned best_partition = 0; for (size_t i = 0; i < 
numsearchers; ++i) { @@ -71,20 +71,23 @@ namespace diskann { best_partition = i; } } - best_distances[k] = best_distance; - best_indices[k] = results[best_partition].get_indices()[pos[best_partition]]; - best_partitions[k] = best_partition; - if (results[best_partition].tags_enabled()) - best_tags[k] = results[best_partition].get_tags()[pos[best_partition]]; - std::cout << best_partition << " " << pos[best_partition] << std::endl; - pos[best_partition]++; + best_distances[k] = best_distance; + best_indices[k] = + results[best_partition].get_indices()[pos[best_partition]]; + best_partitions[k] = best_partition; + if (results[best_partition].tags_enabled()) + best_tags[k] = + results[best_partition].get_tags()[pos[best_partition]]; + std::cout << best_partition << " " << pos[best_partition] << std::endl; + pos[best_partition]++; } unsigned int total_time = 0; for (size_t i = 0; i < numsearchers; ++i) total_time += results[i].get_time(); - diskann::SearchResult result = SearchResult( - K, total_time, best_indices, best_distances, best_tags, best_partitions); + diskann::SearchResult result = + SearchResult(K, total_time, best_indices, best_distances, best_tags, + best_partitions); delete[] best_indices; delete[] best_distances; @@ -101,8 +104,8 @@ namespace diskann { void Server::handle_post(web::http::http_request message) { message.extract_string(true) .then([=](utility::string_t body) { - int64_t queryId = -1; - unsigned int K = 0; + int64_t queryId = -1; + unsigned int K = 0; try { T* queryVector = nullptr; unsigned int dimensions = 0; @@ -113,7 +116,8 @@ namespace diskann { std::vector results; for (auto& searcher : _multi_searcher) - results.push_back(searcher->search(queryVector, dimensions, (unsigned int) K, Ls)); + results.push_back(searcher->search(queryVector, dimensions, + (unsigned int) K, Ls)); diskann::SearchResult result = aggregate_results(K, results); diskann::aligned_free(queryVector); web::json::value response = prepareResponse(queryId, K); @@ 
-139,13 +143,17 @@ namespace diskann { return std::make_pair(web::http::status_codes::InternalError, response); } catch (...) { - std::cerr << "Uncaught exception while processing query: " << queryId; + std::cerr << "Uncaught exception while processing query: " + << queryId; web::json::value response = prepareResponse(queryId, K); - response[ERROR_MESSAGE_KEY] = web::json::value::string(UNKNOWN_ERROR); - return std::make_pair(web::http::status_codes::InternalError, response); + response[ERROR_MESSAGE_KEY] = + web::json::value::string(UNKNOWN_ERROR); + return std::make_pair(web::http::status_codes::InternalError, + response); } }) - .then([=](std::pair response_status) { + .then([=](std::pair + response_status) { try { message.reply(response_status.first, response_status.second).wait(); } catch (const std::exception& ex) { @@ -180,7 +188,8 @@ namespace diskann { if (k <= 0 || k > Ls) { throw new std::invalid_argument( - "Num of expected NN (k) must be greater than zero and less than or equal to Ls."); + "Num of expected NN (k) must be greater than zero and less than or " + "equal to Ls."); } if (queryArr.size() == 0) { throw new std::invalid_argument("Query vector has zero elements."); diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index c34be5690..6aa5532ef 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -6,6 +6,9 @@ set(CMAKE_CXX_STANDARD 14) add_executable(build_memory_index build_memory_index.cpp) target_link_libraries(build_memory_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) +add_executable(build_stitched_index build_stitched_index.cpp) +target_link_libraries(build_stitched_index ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) + add_executable(search_memory_index search_memory_index.cpp) target_link_libraries(search_memory_index ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS} Boost::program_options) diff --git a/tests/build_disk_index.cpp 
b/tests/build_disk_index.cpp index 743c067a5..8f8eef925 100644 --- a/tests/build_disk_index.cpp +++ b/tests/build_disk_index.cpp @@ -13,11 +13,12 @@ namespace po = boost::program_options; int main(int argc, char** argv) { - std::string data_type, dist_fn, data_path, index_path_prefix; - unsigned num_threads, R, L, disk_PQ, build_PQ; - float B, M; - bool append_reorder_data = false; - bool use_opq = false; + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, + universal_label, label_type; + unsigned num_threads, R, L, disk_PQ, build_PQ, Lf, filter_threshold; + float B, M; + bool append_reorder_data = false; + bool use_opq = false; po::options_description desc{"Arguments"}; try { @@ -60,9 +61,34 @@ int main(int argc, char** argv) { desc.add_options()( "build_PQ_bytes", po::value(&build_PQ)->default_value(0), "Number of PQ bytes to build the index; 0 for full precision build"); - desc.add_options()("use_opq", po::bool_switch()->default_value(false), "Use Optimized Product Quantization (OPQ)."); + desc.add_options()( + "label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index build ." + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + desc.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, Use only in conjuction with label file for filtered " + "index build. 
If a graph node has all the labels against it, we can " + "assign a special universal filter to the point instead of comma " + "separated filters for that point"); + desc.add_options()("filtered_Lbuild,Lf", + po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + desc.add_options()( + "filter_threshold,F", + po::value(&filter_threshold)->default_value(0), + "Threshold to break up the existing nodes to generate new graph " + "internally where each node has a maximum F labels."); + desc.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -80,6 +106,11 @@ int main(int argc, char** argv) { return -1; } + bool use_filters = false; + if (label_file != "") { + use_filters = true; + } + diskann::Metric metric; if (dist_fn == std::string("l2")) metric = diskann::Metric::L2; @@ -116,21 +147,46 @@ int main(int argc, char** argv) { std::string(std::to_string(build_PQ)); try { - if (data_type == std::string("int8")) - return diskann::build_disk_index(data_path.c_str(), - index_path_prefix.c_str(), - params.c_str(), metric, use_opq); - else if (data_type == std::string("uint8")) - return diskann::build_disk_index( - data_path.c_str(), index_path_prefix.c_str(), params.c_str(), metric, - use_opq); - else if (data_type == std::string("float")) - return diskann::build_disk_index(data_path.c_str(), - index_path_prefix.c_str(), - params.c_str(), metric, use_opq); - else { - diskann::cerr << "Error. 
Unsupported data type" << std::endl; - return -1; + if (label_file != "" && label_type == "ushort") { + if (data_type == std::string("int8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else { + diskann::cerr << "Error. Unsupported data type" << std::endl; + return -1; + } + } else { + if (data_type == std::string("int8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else if (data_type == std::string("uint8")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else if (data_type == std::string("float")) + return diskann::build_disk_index( + data_path.c_str(), index_path_prefix.c_str(), params.c_str(), + metric, use_opq, use_filters, label_file, universal_label, + filter_threshold, Lf); + else { + diskann::cerr << "Error. 
Unsupported data type" << std::endl; + return -1; + } } } catch (const std::exception& e) { std::cout << std::string(e.what()) << std::endl; diff --git a/tests/build_memory_index.cpp b/tests/build_memory_index.cpp index b830026f0..50b7282c5 100644 --- a/tests/build_memory_index.cpp +++ b/tests/build_memory_index.cpp @@ -20,44 +20,62 @@ namespace po = boost::program_options; -template +template int build_in_memory_index(const diskann::Metric& metric, const std::string& data_path, const unsigned R, const unsigned L, const float alpha, const std::string& save_path, const unsigned num_threads, const bool use_pq_build, - const size_t num_pq_bytes, const bool use_opq) { + const size_t num_pq_bytes, const bool use_opq, + const std::string& label_file, + const std::string& universal_label, const _u32 Lf) { diskann::Parameters paras; paras.Set("R", R); paras.Set("L", L); + paras.Set("Lf", Lf); paras.Set( "C", 750); // maximum candidate set size during pruning procedure paras.Set("alpha", alpha); paras.Set("saturate_graph", 0); paras.Set("num_threads", num_threads); + std::string labels_file_to_use = save_path + "_label_formatted.txt"; + std::string mem_labels_int_map_file = save_path + "_labels_map.txt"; _u64 data_num, data_dim; diskann::get_bin_metadata(data_path, data_num, data_dim); - diskann::Index index(metric, data_dim, data_num, false, false, false, - use_pq_build, num_pq_bytes, use_opq); - auto s = std::chrono::high_resolution_clock::now(); - index.build(data_path.c_str(), data_num, paras); - + diskann::Index index(metric, data_dim, data_num, false, + false, false, use_pq_build, + num_pq_bytes, use_opq); + auto s = std::chrono::high_resolution_clock::now(); + if (label_file == "") { + index.build(data_path.c_str(), data_num, paras); + } else { + convert_labels_string_to_int(label_file, labels_file_to_use, + mem_labels_int_map_file, universal_label); + if (universal_label != "") { + LabelT unv_label_as_num = std::stoul(universal_label); + 
index.set_universal_label(unv_label_as_num); + } + index.build_filtered_index(data_path.c_str(), labels_file_to_use, data_num, + paras); + } std::chrono::duration diff = std::chrono::high_resolution_clock::now() - s; std::cout << "Indexing time: " << diff.count() << "\n"; index.save(save_path.c_str()); - + if (label_file != "") + std::remove(labels_file_to_use.c_str()); return 0; } int main(int argc, char** argv) { - std::string data_type, dist_fn, data_path, index_path_prefix; - unsigned num_threads, R, L, build_PQ_bytes; - float alpha; - bool use_pq_build, use_opq; + std::string data_type, dist_fn, data_path, index_path_prefix, label_file, + universal_label, label_type; + unsigned num_threads, R, L, Lf, build_PQ_bytes; + float alpha; + bool use_pq_build, use_opq; po::options_description desc{"Arguments"}; try { @@ -89,12 +107,31 @@ int main(int argc, char** argv) { "Number of threads used for building index (defaults to " "omp_get_num_procs())"); desc.add_options()( - "build_PQ_bytes", po::value(&build_PQ_bytes)->default_value(0), + "build_PQ_bytes", + po::value(&build_PQ_bytes)->default_value(0), "Number of PQ bytes to build the index; 0 for full precision build"); desc.add_options()( "use_opq", po::bool_switch()->default_value(false), "Set true for OPQ compression while using PQ distance comparisons for " "building the index, and false for PQ compression"); + desc.add_options()( + "label_file", po::value(&label_file)->default_value(""), + "Input label file in txt format for Filtered Index search. 
" + "The file should contain comma separated filters for each node " + "with each line corresponding to a graph node"); + desc.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with labels_file"); + desc.add_options()("FilteredLbuild,Lf", + po::value(&Lf)->default_value(0), + "Build complexity for filtered points, higher value " + "results in better graphs"); + desc.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); @@ -128,22 +165,48 @@ int main(int argc, char** argv) { diskann::cout << "Starting index build with R: " << R << " Lbuild: " << L << " alpha: " << alpha << " #threads: " << num_threads << std::endl; - if (data_type == std::string("int8")) - return build_in_memory_index(metric, data_path, R, L, alpha, - index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq); - else if (data_type == std::string("uint8")) - return build_in_memory_index( - metric, data_path, R, L, alpha, index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq); - else if (data_type == std::string("float")) - return build_in_memory_index(metric, data_path, R, L, alpha, - index_path_prefix, num_threads, - use_pq_build, build_PQ_bytes, use_opq); - else { - std::cout << "Unsupported type. Use one of int8, uint8 or float." 
- << std::endl; - return -1; + if (label_file != "" && label_type == "ushort") { + if (data_type == std::string("int8")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else if (data_type == std::string("uint8")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else if (data_type == std::string("float")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else { + std::cout << "Unsupported type. Use one of int8, uint8 or float." + << std::endl; + return -1; + } + } else { + if (data_type == std::string("int8")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else if (data_type == std::string("uint8")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else if (data_type == std::string("float")) + return build_in_memory_index( + metric, data_path, R, L, alpha, index_path_prefix, num_threads, + use_pq_build, build_PQ_bytes, use_opq, label_file, universal_label, + Lf); + else { + std::cout << "Unsupported type. Use one of int8, uint8 or float." + << std::endl; + return -1; + } } } catch (const std::exception& e) { std::cout << std::string(e.what()) << std::endl; diff --git a/tests/build_stitched_index.cpp b/tests/build_stitched_index.cpp new file mode 100644 index 000000000..fcd2fba7d --- /dev/null +++ b/tests/build_stitched_index.cpp @@ -0,0 +1,855 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. 
+// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include + +#include +#ifndef _WINDOWS +#include +#endif + +#include "index.h" +#include "memory_mapper.h" +#include "parameters.h" +#include "utils.h" + +namespace po = boost::program_options; + +// macros +#define PBSTR "||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||" +#define PBWIDTH 60 + +// custom types (for readability) +typedef tsl::robin_set label_set; +typedef std::string path; + +// structs for returning multiple items from a function +typedef std::tuple, tsl::robin_map, + label_set> + parse_label_file_return_values; +typedef std::tuple>, _u64> + load_label_index_return_values; +typedef std::tuple>, _u64> + stitch_indices_return_values; + +/* + * Inline function to display progress bar. + */ +inline void print_progress(double percentage) { + int val = (int) (percentage * 100); + int lpad = (int) (percentage * PBWIDTH); + int rpad = PBWIDTH - lpad; + printf("\r%3d%% [%.*s%*s]", val, lpad, PBSTR, rpad, ""); + fflush(stdout); +} + +/* + * Inline function to generate a random integer in a range. + */ +inline size_t random(size_t range_from, size_t range_to) { + std::random_device rand_dev; + std::mt19937 generator(rand_dev()); + std::uniform_int_distribution distr(range_from, range_to); + return distr(generator); +} + +/* + * function to handle command line parsing. + * + * Arguments are merely the inputs from the command line. 
+ */ +void handle_args(int argc, char **argv, std::string &data_type, + path &input_data_path, path &final_index_path_prefix, + path &label_data_path, std::string &universal_label, + unsigned &num_threads, unsigned &R, unsigned &L, + unsigned &stitched_R, float &alpha) { + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("data_path", + po::value(&input_data_path)->required(), + "Input data file in bin format"); + desc.add_options()("index_path_prefix", + po::value(&final_index_path_prefix)->required(), + "Path prefix for saving index file components"); + desc.add_options()("max_degree,R", + po::value(&R)->default_value(64), + "Maximum graph degree"); + desc.add_options()( + "Lbuild,L", po::value(&L)->default_value(100), + "Build complexity, higher value results in better graphs"); + desc.add_options()("stitched_R", + po::value(&stitched_R)->default_value(100), + "Degree to prune final graph down to"); + desc.add_options()( + "alpha", po::value(&alpha)->default_value(1.2f), + "alpha controls density and diameter of graph, set 1 for sparse graph, " + "1.2 or 1.4 for denser graphs with lower diameter"); + desc.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + "Number of threads used for building index (defaults to " + "omp_get_num_procs())"); + desc.add_options()("label_file", + po::value(&label_data_path)->default_value(""), + "Input label file in txt format if present"); + desc.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "If a point comes with the specified universal label (and only the " + "univ. 
" + "label), then the point is considered to have every possible label"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + exit(0); + } + po::notify(vm); + } catch (const std::exception &ex) { + std::cerr << ex.what() << '\n'; + throw; + } +} + +/* + * Parses the label datafile, which has comma-separated labels on + * each line. Line i corresponds to point id i. + * + * Returns three objects via std::tuple: + * 1. map: key is point id, value is vector of labels said point has + * 2. map: key is label, value is number of points with the label + * 3. the label universe as a set + */ +parse_label_file_return_values parse_label_file(path label_data_path, + std::string universal_label) { + std::ifstream label_data_stream(label_data_path); + std::string line, token; + unsigned line_cnt = 0; + + // allows us to reserve space for the points_to_labels vector + while (std::getline(label_data_stream, line)) + line_cnt++; + label_data_stream.clear(); + label_data_stream.seekg(0, std::ios::beg); + + // values to return + std::vector point_ids_to_labels(line_cnt); + tsl::robin_map labels_to_number_of_points; + label_set all_labels; + + std::vector<_u32> points_with_universal_label; + line_cnt = 0; + while (std::getline(label_data_stream, line)) { + std::istringstream current_labels_comma_separated(line); + label_set current_labels; + + // get point id + _u32 point_id = line_cnt; + + // parse comma separated labels + bool current_universal_label_check = false; + while (getline(current_labels_comma_separated, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + + // if token is empty, there's no labels for the point + if (token == universal_label) { + points_with_universal_label.push_back(point_id); + current_universal_label_check = true; + } else { + all_labels.insert(token); + 
current_labels.insert(token); + labels_to_number_of_points[token]++; + } + } + + if (current_labels.size() <= 0 && !current_universal_label_check) { + std::cerr << "Error: " << point_id << " has no labels." << std::endl; + exit(-1); + } + point_ids_to_labels[point_id] = current_labels; + line_cnt++; + } + + // for every point with universal label, set its label set to all labels + // also, increment the count for number of points a label has + for (const auto &point_id : points_with_universal_label) { + point_ids_to_labels[point_id] = all_labels; + for (const auto &lbl : all_labels) + labels_to_number_of_points[lbl]++; + } + + std::cout << "Identified " << all_labels.size() << " distinct label(s) for " + << point_ids_to_labels.size() << " points\n" + << std::endl; + + return std::make_tuple(point_ids_to_labels, labels_to_number_of_points, + all_labels); +} + +/* + * For each label, generates a file containing all vectors that have said label. + * Also copies data from original bin file to new dimension-aligned file. + * + * Utilizes POSIX functions mmap and writev in order to minimize memory + * overhead, so we include an STL version as well. 
+ * + * Each data file is saved under the following format: + * input_data_path + "_" + label + */ +template +tsl::robin_map> +generate_label_specific_vector_files( + path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) { + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + diskann::MemoryMapper input_data(input_data_path); + char *input_start = input_data.getBuf(); + + _u32 number_of_points, dimension; + std::memcpy(&number_of_points, input_start, sizeof(_u32)); + std::memcpy(&dimension, input_start + sizeof(_u32), sizeof(_u32)); + const _u32 VECTOR_SIZE = dimension * sizeof(T); + const size_t METADATA = 2 * sizeof(_u32); + if (number_of_points != point_ids_to_labels.size()) { + std::cerr << "Error: number of points in labels file and data file differ." + << std::endl; + throw; + } + + tsl::robin_map label_to_iovec_map; + tsl::robin_map label_to_curr_iovec; + tsl::robin_map> label_id_to_orig_id; + + // setup iovec list for each label + for (const auto &lbl : all_labels) { + iovec *label_iovecs = + (iovec *) malloc(labels_to_number_of_points[lbl] * sizeof(iovec)); + if (label_iovecs == nullptr) { + throw; + } + label_to_iovec_map[lbl] = label_iovecs; + label_to_curr_iovec[lbl] = 0; + label_id_to_orig_id[lbl].reserve(labels_to_number_of_points[lbl]); + } + + // each point added to corresponding per-label iovec list + for (_u32 point_id = 0; point_id < number_of_points; point_id++) { + char *curr_point = input_start + METADATA + (VECTOR_SIZE * point_id); + iovec curr_iovec; + + curr_iovec.iov_base = curr_point; + curr_iovec.iov_len = VECTOR_SIZE; + for (const auto &lbl : point_ids_to_labels[point_id]) { + *(label_to_iovec_map[lbl] + label_to_curr_iovec[lbl]) = curr_iovec; + label_to_curr_iovec[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } + } + + // write each label iovec to resp. 
file + for (const auto &lbl : all_labels) { + int label_input_data_fd; + path curr_label_input_data_path(input_data_path + "_" + lbl); + _u32 curr_num_pts = labels_to_number_of_points[lbl]; + + label_input_data_fd = + open(curr_label_input_data_path.c_str(), + O_CREAT | O_WRONLY | O_TRUNC | O_APPEND, (mode_t) 0644); + if (label_input_data_fd == -1) + throw; + + // write metadata + _u32 metadata[2] = {curr_num_pts, dimension}; + int return_value = write(label_input_data_fd, metadata, sizeof(_u32) * 2); + if (return_value == -1) { + throw; + } + + // limits on number of iovec structs per writev means we need to perform + // multiple writevs + size_t i = 0; + while (curr_num_pts > IOV_MAX) { + return_value = writev(label_input_data_fd, + (label_to_iovec_map[lbl] + (IOV_MAX * i)), IOV_MAX); + if (return_value == -1) { + close(label_input_data_fd); + throw; + } + curr_num_pts -= IOV_MAX; + i += 1; + } + return_value = + writev(label_input_data_fd, (label_to_iovec_map[lbl] + (IOV_MAX * i)), + curr_num_pts); + if (return_value == -1) { + close(label_input_data_fd); + throw; + } + + free(label_to_iovec_map[lbl]); + close(label_input_data_fd); + } + + std::chrono::duration file_writing_time = + std::chrono::high_resolution_clock::now() - file_writing_timer; + std::cout << "generated " << all_labels.size() + << " label-specific vector files for index building in time " + << file_writing_time.count() << "\n" + << std::endl; + + return label_id_to_orig_id; +} + +// for use on systems without writev (i.e. 
Windows) +template +tsl::robin_map> +generate_label_specific_vector_files_compat( + path input_data_path, + tsl::robin_map labels_to_number_of_points, + std::vector point_ids_to_labels, label_set all_labels) { + auto file_writing_timer = std::chrono::high_resolution_clock::now(); + std::ifstream input_data_stream(input_data_path); + + _u32 number_of_points, dimension; + input_data_stream.read((char *) &number_of_points, sizeof(_u32)); + input_data_stream.read((char *) &dimension, sizeof(_u32)); + const _u32 VECTOR_SIZE = dimension * sizeof(T); + if (number_of_points != point_ids_to_labels.size()) { + std::cerr << "Error: number of points in labels file and data file differ." + << std::endl; + throw; + } + + tsl::robin_map labels_to_vectors; + tsl::robin_map labels_to_curr_vector; + tsl::robin_map> label_id_to_orig_id; + + for (const auto &lbl : all_labels) { + _u32 number_of_label_pts = labels_to_number_of_points[lbl]; + char *vectors = (char *) malloc(number_of_label_pts * VECTOR_SIZE); + if (vectors == nullptr) { + throw; + } + labels_to_vectors[lbl] = vectors; + labels_to_curr_vector[lbl] = 0; + label_id_to_orig_id[lbl].reserve(number_of_label_pts); + } + + for (_u32 point_id = 0; point_id < number_of_points; point_id++) { + char *curr_vector = (char *) malloc(VECTOR_SIZE); + input_data_stream.read(curr_vector, VECTOR_SIZE); + for (const auto &lbl : point_ids_to_labels[point_id]) { + char *curr_label_vector_ptr = + labels_to_vectors[lbl] + (labels_to_curr_vector[lbl] * VECTOR_SIZE); + memcpy(curr_label_vector_ptr, curr_vector, VECTOR_SIZE); + labels_to_curr_vector[lbl]++; + label_id_to_orig_id[lbl].push_back(point_id); + } + free(curr_vector); + } + + for (const auto &lbl : all_labels) { + path curr_label_input_data_path(input_data_path + "_" + lbl); + _u32 number_of_label_pts = labels_to_number_of_points[lbl]; + + std::ofstream label_file_stream; + label_file_stream.exceptions(std::ios::badbit | std::ios::failbit); + 
label_file_stream.open(curr_label_input_data_path, std::ios_base::binary); + label_file_stream.write((char *) &number_of_label_pts, sizeof(_u32)); + label_file_stream.write((char *) &dimension, sizeof(_u32)); + label_file_stream.write((char *) labels_to_vectors[lbl], + number_of_label_pts * VECTOR_SIZE); + + label_file_stream.close(); + free(labels_to_vectors[lbl]); + } + input_data_stream.close(); + + std::chrono::duration file_writing_time = + std::chrono::high_resolution_clock::now() - file_writing_timer; + std::cout << "generated " << all_labels.size() + << " label-specific vector files for index building in time " + << file_writing_time.count() << "\n" + << std::endl; + + return label_id_to_orig_id; +} + +/* + * Using passed in parameters and files generated from step 3, + * builds a vanilla diskANN index for each label. + * + * Each index is saved under the following path: + * final_index_path_prefix + "_" + label + */ +template +void generate_label_indices(path input_data_path, path final_index_path_prefix, + label_set all_labels, unsigned R, unsigned L, + float alpha, unsigned num_threads) { + diskann::Parameters label_index_build_parameters; + label_index_build_parameters.Set("R", R); + label_index_build_parameters.Set("L", L); + label_index_build_parameters.Set("C", 750); + label_index_build_parameters.Set("Lf", 0); + label_index_build_parameters.Set("saturate_graph", 0); + label_index_build_parameters.Set("alpha", alpha); + label_index_build_parameters.Set("num_threads", num_threads); + + std::cout << "Generating indices per label..." << std::endl; + // for each label, build an index on resp. 
points + double total_indexing_time = 0.0, indexing_percentage = 0.0; + std::cout.setstate(std::ios_base::failbit); + diskann::cout.setstate(std::ios_base::failbit); + for (const auto &lbl : all_labels) { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + + size_t number_of_label_points, dimension; + diskann::get_bin_metadata(curr_label_input_data_path, + number_of_label_points, dimension); + diskann::Index index(diskann::Metric::L2, dimension, + number_of_label_points, false, false); + + auto index_build_timer = std::chrono::high_resolution_clock::now(); + index.build(curr_label_input_data_path.c_str(), number_of_label_points, + label_index_build_parameters); + std::chrono::duration current_indexing_time = + std::chrono::high_resolution_clock::now() - index_build_timer; + + total_indexing_time += current_indexing_time.count(); + indexing_percentage += (1 / (double) all_labels.size()); + print_progress(indexing_percentage); + + index.save(curr_label_index_path.c_str()); + } + std::cout.clear(); + diskann::cout.clear(); + + std::cout << "\nDone. Generated per-label indices in " << total_indexing_time + << " seconds\n" + << std::endl; +} + +/* + * Manually loads a graph index in from a given file. + * + * Returns both the graph index and the size of the file in bytes. 
+ */ +load_label_index_return_values load_label_index(path label_index_path, + _u32 label_number_of_points) { + std::ifstream label_index_stream; + label_index_stream.exceptions(std::ios::badbit | std::ios::failbit); + label_index_stream.open(label_index_path, std::ios::binary); + + _u64 index_file_size, index_num_frozen_points; + _u32 index_max_observed_degree, index_entry_point; + const size_t INDEX_METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); + label_index_stream.read((char *) &index_file_size, sizeof(_u64)); + label_index_stream.read((char *) &index_max_observed_degree, sizeof(_u32)); + label_index_stream.read((char *) &index_entry_point, sizeof(_u32)); + label_index_stream.read((char *) &index_num_frozen_points, sizeof(_u64)); + size_t bytes_read = INDEX_METADATA; + + std::vector> label_index(label_number_of_points); + _u32 nodes_read = 0; + while (bytes_read != index_file_size) { + _u32 current_node_num_neighbors; + label_index_stream.read((char *) ¤t_node_num_neighbors, sizeof(_u32)); + nodes_read++; + + std::vector<_u32> current_node_neighbors(current_node_num_neighbors); + label_index_stream.read((char *) current_node_neighbors.data(), + current_node_num_neighbors * sizeof(_u32)); + label_index[nodes_read - 1].swap(current_node_neighbors); + bytes_read += sizeof(_u32) * (current_node_num_neighbors + 1); + } + + return std::make_tuple(label_index, index_file_size); +} + +/* + * Custom index save to write the in-memory index to disk. + * Also writes required files for diskANN API - + * 1. labels_to_medoids + * 2. universal_label + * 3. data (redundant for static indices) + * 4. labels (redundant for static indices) + */ +void save_full_index(path final_index_path_prefix, path input_data_path, + _u64 final_index_size, + std::vector> stitched_graph, + tsl::robin_map entry_points, + std::string universal_label, path label_data_path) { + // aux. 
file 1 + auto saving_index_timer = std::chrono::high_resolution_clock::now(); + std::ifstream original_label_data_stream; + original_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + original_label_data_stream.open(label_data_path, std::ios::binary); + std::ofstream new_label_data_stream; + new_label_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + new_label_data_stream.open(final_index_path_prefix + "_labels.txt", + std::ios::binary); + new_label_data_stream << original_label_data_stream.rdbuf(); + original_label_data_stream.close(); + new_label_data_stream.close(); + + // aux. file 2 + std::ifstream original_input_data_stream; + original_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + original_input_data_stream.open(input_data_path, std::ios::binary); + std::ofstream new_input_data_stream; + new_input_data_stream.exceptions(std::ios::badbit | std::ios::failbit); + new_input_data_stream.open(final_index_path_prefix + ".data", + std::ios::binary); + new_input_data_stream << original_input_data_stream.rdbuf(); + original_input_data_stream.close(); + new_input_data_stream.close(); + + // aux. file 3 + std::ofstream labels_to_medoids_writer; + labels_to_medoids_writer.exceptions(std::ios::badbit | std::ios::failbit); + labels_to_medoids_writer.open(final_index_path_prefix + + "_labels_to_medoids.txt"); + for (auto iter : entry_points) + labels_to_medoids_writer << iter.first << ", " << iter.second << std::endl; + labels_to_medoids_writer.close(); + + // aux. 
file 4 (only if we're using a universal label) + if (universal_label != "") { + std::ofstream universal_label_writer; + universal_label_writer.exceptions(std::ios::badbit | std::ios::failbit); + universal_label_writer.open(final_index_path_prefix + + "_universal_label.txt"); + universal_label_writer << universal_label << std::endl; + universal_label_writer.close(); + } + + // main index + _u64 index_num_frozen_points = 0, index_num_edges = 0; + _u32 index_max_observed_degree = 0, index_entry_point = 0; + const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); + for (auto &point_neighbors : stitched_graph) { + index_max_observed_degree = + std::max(index_max_observed_degree, (_u32) point_neighbors.size()); + } + + std::ofstream stitched_graph_writer; + stitched_graph_writer.exceptions(std::ios::badbit | std::ios::failbit); + stitched_graph_writer.open(final_index_path_prefix, std::ios_base::binary); + + stitched_graph_writer.write((char *) &final_index_size, sizeof(_u64)); + stitched_graph_writer.write((char *) &index_max_observed_degree, + sizeof(_u32)); + stitched_graph_writer.write((char *) &index_entry_point, sizeof(_u32)); + stitched_graph_writer.write((char *) &index_num_frozen_points, sizeof(_u64)); + + size_t bytes_written = METADATA; + for (_u32 node_point = 0; node_point < stitched_graph.size(); node_point++) { + _u32 current_node_num_neighbors = stitched_graph[node_point].size(); + std::vector<_u32> current_node_neighbors = stitched_graph[node_point]; + stitched_graph_writer.write((char *) ¤t_node_num_neighbors, + sizeof(_u32)); + bytes_written += sizeof(_u32); + for (const auto ¤t_node_neighbor : current_node_neighbors) { + stitched_graph_writer.write((char *) ¤t_node_neighbor, + sizeof(_u32)); + bytes_written += sizeof(_u32); + } + index_num_edges += current_node_num_neighbors; + } + + if (bytes_written != final_index_size) { + std::cerr << "Error: written bytes does not match allocated space" + << std::endl; + throw; + } + + 
stitched_graph_writer.close(); + + std::chrono::duration saving_index_time = + std::chrono::high_resolution_clock::now() - saving_index_timer; + std::cout << "Stitched graph written in " << saving_index_time.count() + << " seconds" << std::endl; + std::cout << "Stitched graph average degree: " + << ((float) index_num_edges) / ((float) (stitched_graph.size())) + << std::endl; + std::cout << "Stitched graph max degree: " << index_max_observed_degree + << std::endl + << std::endl; +} + +/* + * Unions the per-label graph indices together via the following policy: + * - any two nodes can only have at most one edge between them - + * + * Returns the "stitched" graph and its expected file size. + */ +template +stitch_indices_return_values stitch_label_indices( + path final_index_path_prefix, _u32 total_number_of_points, + label_set all_labels, + tsl::robin_map labels_to_number_of_points, + tsl::robin_map &label_entry_points, + tsl::robin_map> label_id_to_orig_id_map) { + size_t final_index_size = 0; + std::vector> stitched_graph(total_number_of_points); + + auto stitching_index_timer = std::chrono::high_resolution_clock::now(); + for (const auto &lbl : all_labels) { + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + std::vector> curr_label_index; + _u64 curr_label_index_size; + _u32 curr_label_entry_point; + + std::tie(curr_label_index, curr_label_index_size) = load_label_index( + curr_label_index_path, labels_to_number_of_points[lbl]); + curr_label_entry_point = random(0, curr_label_index.size()); + label_entry_points[lbl] = + label_id_to_orig_id_map[lbl][curr_label_entry_point]; + + for (_u32 node_point = 0; node_point < curr_label_index.size(); + node_point++) { + _u32 original_point_id = label_id_to_orig_id_map[lbl][node_point]; + for (auto &node_neighbor : curr_label_index[node_point]) { + _u32 original_neighbor_id = label_id_to_orig_id_map[lbl][node_neighbor]; + std::vector<_u32> curr_point_neighbors = + stitched_graph[original_point_id]; + if 
(std::find(curr_point_neighbors.begin(), curr_point_neighbors.end(), + original_neighbor_id) == curr_point_neighbors.end()) { + stitched_graph[original_point_id].push_back(original_neighbor_id); + final_index_size += sizeof(_u32); + } + } + } + } + + const size_t METADATA = 2 * sizeof(_u64) + 2 * sizeof(_u32); + final_index_size += (total_number_of_points * sizeof(_u32) + METADATA); + + std::chrono::duration stitching_index_time = + std::chrono::high_resolution_clock::now() - stitching_index_timer; + std::cout << "stitched graph generated in memory in " + << stitching_index_time.count() << " seconds" << std::endl; + + return std::make_tuple(stitched_graph, final_index_size); +} + +/* + * Applies the prune_neighbors function from src/index.cpp to + * every node in the stitched graph. + * + * This is an optional step, hence the saving of both the full + * and pruned graph. + */ +template +void prune_and_save(path final_index_path_prefix, path full_index_path_prefix, + path input_data_path, + std::vector> stitched_graph, + unsigned stitched_R, + tsl::robin_map label_entry_points, + std::string universal_label, path label_data_path, + unsigned num_threads) { + size_t dimension, number_of_label_points; + auto diskann_cout_buffer = diskann::cout.rdbuf(nullptr); + auto std_cout_buffer = std::cout.rdbuf(nullptr); + auto pruning_index_timer = std::chrono::high_resolution_clock::now(); + + diskann::get_bin_metadata(input_data_path, number_of_label_points, dimension); + diskann::Index index(diskann::Metric::L2, dimension, + number_of_label_points, false, false); + + // not searching this index, set search_l to 0 + index.load(full_index_path_prefix.c_str(), num_threads, 1); + + diskann::Parameters paras; + paras.Set("R", stitched_R); + paras.Set( + "C", 750); // maximum candidate set size during pruning procedure + paras.Set("alpha", 1.2); + paras.Set("saturate_graph", 1); + std::cout << "parsing labels" << std::endl; + + index.prune_all_nbrs(paras); + 
index.save((final_index_path_prefix).c_str()); + + diskann::cout.rdbuf(diskann_cout_buffer); + std::cout.rdbuf(std_cout_buffer); + std::chrono::duration pruning_index_time = + std::chrono::high_resolution_clock::now() - pruning_index_timer; + std::cout << "pruning performed in " << pruning_index_time.count() + << " seconds\n" + << std::endl; +} + +/* + * Delete all temporary artifacts. + * In the process of creating the stitched index, some temporary artifacts are + * created: + * 1. the separate bin files for each labels' points + * 2. the separate diskANN indices built for each label + * 3. the '.data' file created while generating the indices + */ +void clean_up_artifacts(path input_data_path, path final_index_path_prefix, + label_set all_labels) { + for (const auto &lbl : all_labels) { + path curr_label_input_data_path(input_data_path + "_" + lbl); + path curr_label_index_path(final_index_path_prefix + "_" + lbl); + path curr_label_index_path_data(curr_label_index_path + ".data"); + + if (std::remove(curr_label_index_path.c_str()) != 0) + throw; + if (std::remove(curr_label_input_data_path.c_str()) != 0) + throw; + if (std::remove(curr_label_index_path_data.c_str()) != 0) + throw; + } +} + +int main(int argc, char **argv) { + // 1. handle cmdline inputs + std::string data_type; + path input_data_path, final_index_path_prefix, label_data_path; + std::string universal_label; + unsigned num_threads, R, L, stitched_R; + float alpha; + + auto index_timer = std::chrono::high_resolution_clock::now(); + handle_args(argc, argv, data_type, input_data_path, final_index_path_prefix, + label_data_path, universal_label, num_threads, R, L, stitched_R, + alpha); + + path labels_file_to_use = final_index_path_prefix + "_label_formatted.txt"; + path labels_map_file = final_index_path_prefix + "_labels_map.txt"; + + convert_labels_string_to_int(label_data_path, labels_file_to_use, + labels_map_file, universal_label); + + // 2. 
parse label file and create necessary data structures + std::vector point_ids_to_labels; + tsl::robin_map labels_to_number_of_points; + label_set all_labels; + + std::tie(point_ids_to_labels, labels_to_number_of_points, all_labels) = + parse_label_file(labels_file_to_use, universal_label); + + // 3. for each label, make a separate data file + tsl::robin_map> label_id_to_orig_id_map; + _u32 total_number_of_points = point_ids_to_labels.size(); + +#ifndef _WINDOWS + if (data_type == "uint8") + label_id_to_orig_id_map = generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = generate_label_specific_vector_files( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else + throw; +#else + if (data_type == "uint8") + label_id_to_orig_id_map = + generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "int8") + label_id_to_orig_id_map = + generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else if (data_type == "float") + label_id_to_orig_id_map = + generate_label_specific_vector_files_compat( + input_data_path, labels_to_number_of_points, point_ids_to_labels, + all_labels); + else + throw; +#endif + + // 4. 
for each created data file, create a vanilla diskANN index + if (data_type == "uint8") + generate_label_indices(input_data_path, final_index_path_prefix, + all_labels, R, L, alpha, num_threads); + else if (data_type == "int8") + generate_label_indices(input_data_path, final_index_path_prefix, + all_labels, R, L, alpha, num_threads); + else if (data_type == "float") + generate_label_indices(input_data_path, final_index_path_prefix, + all_labels, R, L, alpha, num_threads); + else + throw; + + // 5. "stitch" the indices together + std::vector> stitched_graph; + tsl::robin_map label_entry_points; + _u64 stitched_graph_size; + + if (data_type == "uint8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else if (data_type == "int8") + std::tie(stitched_graph, stitched_graph_size) = + stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else if (data_type == "float") + std::tie(stitched_graph, stitched_graph_size) = stitch_label_indices( + final_index_path_prefix, total_number_of_points, all_labels, + labels_to_number_of_points, label_entry_points, + label_id_to_orig_id_map); + else + throw; + path full_index_path_prefix = final_index_path_prefix + "_full"; + // 5a. save the stitched graph to disk + save_full_index(full_index_path_prefix, input_data_path, stitched_graph_size, + stitched_graph, label_entry_points, universal_label, + labels_file_to_use); + + // 6. 
run a prune on the stitched index, and save to disk + if (data_type == "uint8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else if (data_type == "int8") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else if (data_type == "float") + prune_and_save(final_index_path_prefix, full_index_path_prefix, + input_data_path, stitched_graph, stitched_R, + label_entry_points, universal_label, + labels_file_to_use, num_threads); + else + throw; + + std::chrono::duration index_time = + std::chrono::high_resolution_clock::now() - index_timer; + std::cout << "pruned/stitched graph generated in " << index_time.count() + << " seconds" << std::endl; + + clean_up_artifacts(input_data_path, final_index_path_prefix, all_labels); +} diff --git a/tests/range_search_disk_index.cpp b/tests/range_search_disk_index.cpp index e37706034..b13d8c65e 100644 --- a/tests/range_search_disk_index.cpp +++ b/tests/range_search_disk_index.cpp @@ -47,7 +47,7 @@ void print_stats(std::string category, std::vector percentiles, diskann::cout << std::endl; } -template +template int search_disk_index(diskann::Metric& metric, const std::string& index_path_prefix, const std::string& query_file, std::string& gt_file, @@ -99,8 +99,8 @@ int search_disk_index(diskann::Metric& metric, reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); @@ -202,8 +202,8 @@ int search_disk_index(diskann::Metric& metric, std::vector<_u64> indices; std::vector distances; _u32 res_count = _pFlashIndex->range_search( - query + (i * 
query_aligned_dim), search_range, L, max_list_size, - indices, distances, optimized_beamwidth, stats + i); + query + (i * query_aligned_dim), search_range, L, max_list_size, + indices, distances, optimized_beamwidth, stats + i); query_result_ids[test_id][i].reserve(res_count); query_result_ids[test_id][i].resize(res_count); for (_u32 idx = 0; idx < res_count; idx++) diff --git a/tests/restapi/client.cpp b/tests/restapi/client.cpp index 57574d26e..f716db8b8 100644 --- a/tests/restapi/client.cpp +++ b/tests/restapi/client.cpp @@ -12,7 +12,6 @@ #include #include - #include #include @@ -65,47 +64,40 @@ void query_loop(const std::string& ip_addr_port, const std::string& query_file, } int main(int argc, char* argv[]) { - std::string data_type, query_file, address; - uint32_t num_queries; - uint32_t l_search, k_value; + std::string data_type, query_file, address; + uint32_t num_queries; + uint32_t l_search, k_value; - po::options_description desc{ "Arguments" }; - try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("address", - po::value(&address)->required(), - "Web server address"); - desc.add_options()("query_file", - po::value(&query_file)->required(), - "File containing the queries to search"); - desc.add_options()( - "num_queries,Q", - po::value(&num_queries)->required(), - "Number of queries to search"); - desc.add_options()( - "l_search", - po::value(&l_search)->required(), - "Value of L"); - desc.add_options()( - "k_value,K", - po::value(&k_value)->default_value(10), - "Value of K (default 10)"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception& ex) { - std::cerr << ex.what() << std::endl; - return -1; + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", 
"Print information on arguments"); + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("address", po::value(&address)->required(), + "Web server address"); + desc.add_options()("query_file", + po::value(&query_file)->required(), + "File containing the queries to search"); + desc.add_options()("num_queries,Q", + po::value(&num_queries)->required(), + "Number of queries to search"); + desc.add_options()("l_search", po::value(&l_search)->required(), + "Value of L"); + desc.add_options()("k_value,K", + po::value(&k_value)->default_value(10), + "Value of K (default 10)"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - + po::notify(vm); + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + return -1; + } if (data_type == std::string("float")) { query_loop(address, query_file, num_queries, l_search, k_value); diff --git a/tests/restapi/inmem_server.cpp b/tests/restapi/inmem_server.cpp index f7b975994..ca3d43898 100644 --- a/tests/restapi/inmem_server.cpp +++ b/tests/restapi/inmem_server.cpp @@ -9,7 +9,6 @@ #include #include - #include using namespace diskann; @@ -37,99 +36,96 @@ void teardown(const utility::string_t& address) { } int main(int argc, char* argv[]) { - std::string data_type, index_file, data_file, address, dist_fn, tags_file; - uint32_t num_threads; - uint32_t l_search; + std::string data_type, index_file, data_file, address, dist_fn, tags_file; + uint32_t num_threads; + uint32_t l_search; - po::options_description desc{ "Arguments" }; - try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("address", - po::value(&address)->required(), - "Web server address"); - desc.add_options()("data_file", - po::value(&data_file)->required(), - "File containing the 
data found in the index"); - desc.add_options()("index_path_prefix", - po::value(&index_file)->required(), - "Path prefix for saving index file components"); - desc.add_options()( - "num_threads,T", - po::value(&num_threads)->required(), - "Number of threads used for building index"); - desc.add_options()( - "l_search", - po::value(&l_search)->required(), - "Value of L"); - desc.add_options()("dist_fn", po::value(&dist_fn)->default_value("l2"), - "distance function "); - desc.add_options()("tags_file", - po::value(&tags_file)->default_value(std::string()), - "Tags file location"); - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; - } - po::notify(vm); - } - catch (const std::exception& ex) { - std::cerr << ex.what() << std::endl; - return -1; - } - diskann::Metric metric; - if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; - else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; - else { - std::cout << "Error. 
Only l2 and mips distance functions are supported" - << std::endl; - return -1; + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("address", po::value(&address)->required(), + "Web server address"); + desc.add_options()("data_file", + po::value(&data_file)->required(), + "File containing the data found in the index"); + desc.add_options()("index_path_prefix", + po::value(&index_file)->required(), + "Path prefix for saving index file components"); + desc.add_options()("num_threads,T", + po::value(&num_threads)->required(), + "Number of threads used for building index"); + desc.add_options()("l_search", po::value(&l_search)->required(), + "Value of L"); + desc.add_options()("dist_fn", + po::value(&dist_fn)->default_value("l2"), + "distance function "); + desc.add_options()( + "tags_file", + po::value(&tags_file)->default_value(std::string()), + "Tags file location"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } + po::notify(vm); + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + return -1; + } + diskann::Metric metric; + if (dist_fn == std::string("l2")) + metric = diskann::Metric::L2; + else if (dist_fn == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + else { + std::cout << "Error. 
Only l2 and mips distance functions are supported" + << std::endl; + return -1; + } - if (data_type == std::string("float")) { - auto searcher = + if (data_type == std::string("float")) { + auto searcher = std::unique_ptr(new diskann::InMemorySearch( data_file, index_file, tags_file, metric, num_threads, l_search)); - g_inMemorySearch.push_back(std::move(searcher)); - } else if (data_type == std::string("int8")) { - auto searcher = std::unique_ptr( - new diskann::InMemorySearch(data_file, index_file, tags_file, - metric, num_threads, l_search)); - g_inMemorySearch.push_back(std::move(searcher)); - } else if (data_type == std::string("uint8")) { - auto searcher = std::unique_ptr( - new diskann::InMemorySearch(data_file, index_file, tags_file, - metric, num_threads, l_search)); - g_inMemorySearch.push_back(std::move(searcher)); - } else { - std::cerr << "Unsupported data type " << argv[2] << std::endl; - } + g_inMemorySearch.push_back(std::move(searcher)); + } else if (data_type == std::string("int8")) { + auto searcher = std::unique_ptr( + new diskann::InMemorySearch(data_file, index_file, tags_file, + metric, num_threads, l_search)); + g_inMemorySearch.push_back(std::move(searcher)); + } else if (data_type == std::string("uint8")) { + auto searcher = std::unique_ptr( + new diskann::InMemorySearch(data_file, index_file, tags_file, + metric, num_threads, l_search)); + g_inMemorySearch.push_back(std::move(searcher)); + } else { + std::cerr << "Unsupported data type " << argv[2] << std::endl; + } - while (1) { - try { - setup(address, data_type); - std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl; - std::string line; - std::getline(std::cin, line); - if (line == "exit") { - teardown(address); - g_httpServer->close().wait(); - exit(0); - } - } catch (const std::exception& ex) { - std::cerr << "Exception occurred: " << ex.what() << std::endl; - std::cerr << "Restarting HTTP server"; - teardown(address); - } catch (...) 
{ - std::cerr << "Unknown exception occurreed" << std::endl; - std::cerr << "Restarting HTTP server"; - teardown(address); - } + while (1) { + try { + setup(address, data_type); + std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl; + std::string line; + std::getline(std::cin, line); + if (line == "exit") { + teardown(address); + g_httpServer->close().wait(); + exit(0); + } + } catch (const std::exception& ex) { + std::cerr << "Exception occurred: " << ex.what() << std::endl; + std::cerr << "Restarting HTTP server"; + teardown(address); + } catch (...) { + std::cerr << "Unknown exception occurreed" << std::endl; + std::cerr << "Restarting HTTP server"; + teardown(address); } + } } diff --git a/tests/restapi/multiple_ssdindex_server.cpp b/tests/restapi/multiple_ssdindex_server.cpp index 571ab3aa6..2cc36fcd4 100644 --- a/tests/restapi/multiple_ssdindex_server.cpp +++ b/tests/restapi/multiple_ssdindex_server.cpp @@ -37,134 +37,135 @@ void teardown(const utility::string_t& address) { } int main(int argc, char* argv[]) { - std::string data_type, index_prefix_paths, address, dist_fn, tags_file; - uint32_t num_nodes_to_cache; - uint32_t num_threads; - - po::options_description desc{ "Arguments" }; - try { - desc.add_options()("help,h", "Print information on arguments"); - desc.add_options()("address", - po::value(&address)->required(), - "Web server address"); - desc.add_options()("data_type", - po::value(&data_type)->required(), - "data type "); - desc.add_options()("index_prefix_paths", - po::value(&index_prefix_paths)->required(), - "Path prefix for loading index file components"); - desc.add_options()( - "num_nodes_to_cache", - po::value(&num_nodes_to_cache)->default_value(0), - "Number of nodes to cache during search"); - desc.add_options()( - "num_threads,T", - po::value(&num_threads)->default_value(omp_get_num_procs()), - "Number of threads used for building index (defaults to " - "omp_get_num_procs())"); - desc.add_options()("dist_fn", 
po::value(&dist_fn)->default_value("l2"), - "distance function "); - desc.add_options()("tags_file", - po::value(&tags_file)->default_value(std::string()), - "Tags file location"); - - po::variables_map vm; - po::store(po::parse_command_line(argc, argv, desc), vm); - if (vm.count("help")) { - std::cout << desc; - return 0; - } - po::notify(vm); + std::string data_type, index_prefix_paths, address, dist_fn, tags_file; + uint32_t num_nodes_to_cache; + uint32_t num_threads; + + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("address", po::value(&address)->required(), + "Web server address"); + desc.add_options()("data_type", + po::value(&data_type)->required(), + "data type "); + desc.add_options()("index_prefix_paths", + po::value(&index_prefix_paths)->required(), + "Path prefix for loading index file components"); + desc.add_options()( + "num_nodes_to_cache", + po::value(&num_nodes_to_cache)->default_value(0), + "Number of nodes to cache during search"); + desc.add_options()( + "num_threads,T", + po::value(&num_threads)->default_value(omp_get_num_procs()), + "Number of threads used for building index (defaults to " + "omp_get_num_procs())"); + desc.add_options()("dist_fn", + po::value(&dist_fn)->default_value("l2"), + "distance function "); + desc.add_options()( + "tags_file", + po::value(&tags_file)->default_value(std::string()), + "Tags file location"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; } - catch (const std::exception& ex) { - std::cerr << ex.what() << std::endl; - return -1; + po::notify(vm); + } catch (const std::exception& ex) { + std::cerr << ex.what() << std::endl; + return -1; + } + + diskann::Metric metric; + if (dist_fn == std::string("l2")) + metric = diskann::Metric::L2; + else if (dist_fn == std::string("mips")) + metric = diskann::Metric::INNER_PRODUCT; + 
else { + std::cout << "Error. Only l2 and mips distance functions are supported" + << std::endl; + return -1; + } + + std::vector> index_tag_paths; + std::ifstream index_in(index_prefix_paths); + if (!index_in.is_open()) { + std::cerr << "Could not open " << index_prefix_paths << std::endl; + exit(-1); + } + std::ifstream tags_in(tags_file); + if (!tags_in.is_open()) { + std::cerr << "Could not open " << tags_file << std::endl; + exit(-1); + } + std::string prefix, tagfile; + while (std::getline(index_in, prefix)) { + if (std::getline(tags_in, tagfile)) { + index_tag_paths.push_back(std::make_pair(prefix, tagfile)); + } else { + std::cerr << "The number of tags specified does not match the number of " + "indices specified" + << std::endl; + exit(-1); } - - diskann::Metric metric; - if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; - else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; - else { - std::cout << "Error. Only l2 and mips distance functions are supported" - << std::endl; - return -1; + } + index_in.close(); + tags_in.close(); + + if (data_type == std::string("float")) { + for (auto& index_tag : index_tag_paths) { + auto searcher = std::unique_ptr( + new diskann::PQFlashSearch(index_tag.first.c_str(), + num_nodes_to_cache, num_threads, + index_tag.second.c_str(), metric)); + g_ssdSearch.push_back(std::move(searcher)); } - std::vector> index_tag_paths; - std::ifstream index_in(index_prefix_paths); - if (!index_in.is_open()) { - std::cerr << "Could not open " << index_prefix_paths << std::endl; - exit(-1); - } - std::ifstream tags_in(tags_file); - if (!tags_in.is_open()) { - std::cerr << "Could not open " << tags_file << std::endl; - exit(-1); + } else if (data_type == std::string("int8")) { + for (auto& index_tag : index_tag_paths) { + auto searcher = std::unique_ptr( + new diskann::PQFlashSearch(index_tag.first.c_str(), + num_nodes_to_cache, num_threads, + index_tag.second.c_str(), metric)); + 
g_ssdSearch.push_back(std::move(searcher)); } - std::string prefix, tagfile; - while (std::getline(index_in, prefix)) { - if (std::getline(tags_in, tagfile)) { - index_tag_paths.push_back(std::make_pair(prefix, tagfile)); - } else { - std::cerr << "The number of tags specified does not match the number of " - "indices specified" << std::endl; - exit(-1); - } - } - index_in.close(); - tags_in.close(); - - if (data_type == std::string("float")) { - for (auto& index_tag : index_tag_paths) { - auto searcher = std::unique_ptr( - new diskann::PQFlashSearch(index_tag.first.c_str(), - num_nodes_to_cache, num_threads, - index_tag.second.c_str(), metric)); - g_ssdSearch.push_back(std::move(searcher)); - } - - } else if (data_type == std::string("int8")) { - for (auto& index_tag : index_tag_paths) { - auto searcher = std::unique_ptr( - new diskann::PQFlashSearch(index_tag.first.c_str(), - num_nodes_to_cache, num_threads, - index_tag.second.c_str(), metric)); - g_ssdSearch.push_back(std::move(searcher)); - } - } else if (data_type == std::string("uint8")) { - for (auto& index_tag : index_tag_paths) { - auto searcher = std::unique_ptr( - new diskann::PQFlashSearch(index_tag.first.c_str(), - num_nodes_to_cache, num_threads, - index_tag.second.c_str(), metric)); - g_ssdSearch.push_back(std::move(searcher)); - } - } else { - std::cerr << "Unsupported data type " << data_type << std::endl; - exit(-1); + } else if (data_type == std::string("uint8")) { + for (auto& index_tag : index_tag_paths) { + auto searcher = std::unique_ptr( + new diskann::PQFlashSearch( + index_tag.first.c_str(), num_nodes_to_cache, num_threads, + index_tag.second.c_str(), metric)); + g_ssdSearch.push_back(std::move(searcher)); } + } else { + std::cerr << "Unsupported data type " << data_type << std::endl; + exit(-1); + } - while (1) { - try { - setup(address, data_type); - std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl; - std::string line; - std::getline(std::cin, line); - if (line == "exit") 
{ - teardown(address); - g_httpServer->close().wait(); - exit(0); - } - } catch (const std::exception& ex) { - std::cerr << "Exception occurred: " << ex.what() << std::endl; - std::cerr << "Restarting HTTP server"; - teardown(address); - } catch (...) { - std::cerr << "Unknown exception occurreed" << std::endl; - std::cerr << "Restarting HTTP server"; - teardown(address); - } + while (1) { + try { + setup(address, data_type); + std::cout << "Type 'exit' (case-sensitive) to exit" << std::endl; + std::string line; + std::getline(std::cin, line); + if (line == "exit") { + teardown(address); + g_httpServer->close().wait(); + exit(0); + } + } catch (const std::exception& ex) { + std::cerr << "Exception occurred: " << ex.what() << std::endl; + std::cerr << "Restarting HTTP server"; + teardown(address); + } catch (...) { + std::cerr << "Unknown exception occurreed" << std::endl; + std::cerr << "Restarting HTTP server"; + teardown(address); } + } } diff --git a/tests/restapi/ssd_server.cpp b/tests/restapi/ssd_server.cpp index 658b0cacb..e8e36f39e 100644 --- a/tests/restapi/ssd_server.cpp +++ b/tests/restapi/ssd_server.cpp @@ -10,8 +10,6 @@ #include #include - - #include using namespace diskann; @@ -22,7 +20,7 @@ std::vector> g_ssdSearch; void setup(const utility::string_t& address, const std::string& typestring) { web::http::uri_builder uriBldr(address); - auto uri = uriBldr.to_uri(); + auto uri = uriBldr.to_uri(); std::cout << "Attempting to start server on " << uri.to_string() << std::endl; @@ -40,21 +38,20 @@ void teardown(const utility::string_t& address) { int main(int argc, char* argv[]) { std::string data_type, index_path_prefix, address, dist_fn, tags_file; - uint32_t num_nodes_to_cache; - uint32_t num_threads; + uint32_t num_nodes_to_cache; + uint32_t num_threads; - po::options_description desc{ "Arguments" }; + po::options_description desc{"Arguments"}; try { desc.add_options()("help,h", "Print information on arguments"); desc.add_options()("data_type", - 
po::value(&data_type)->required(), - "data type "); - desc.add_options()("address", - po::value(&address)->required(), - "Web server address"); + po::value(&data_type)->required(), + "data type "); + desc.add_options()("address", po::value(&address)->required(), + "Web server address"); desc.add_options()("index_path_prefix", - po::value(&index_path_prefix)->required(), - "Path prefix for loading index file components"); + po::value(&index_path_prefix)->required(), + "Path prefix for loading index file components"); desc.add_options()( "num_nodes_to_cache", po::value(&num_nodes_to_cache)->default_value(0), @@ -64,52 +61,52 @@ int main(int argc, char* argv[]) { po::value(&num_threads)->default_value(omp_get_num_procs()), "Number of threads used for building index (defaults to " "omp_get_num_procs())"); - desc.add_options()("dist_fn", po::value(&dist_fn)->default_value("l2"), - "distance function "); - desc.add_options()("tags_file", + desc.add_options()("dist_fn", + po::value(&dist_fn)->default_value("l2"), + "distance function "); + desc.add_options()( + "tags_file", po::value(&tags_file)->default_value(std::string()), "Tags file location"); po::variables_map vm; po::store(po::parse_command_line(argc, argv, desc), vm); if (vm.count("help")) { - std::cout << desc; - return 0; + std::cout << desc; + return 0; } po::notify(vm); } catch (const std::exception& ex) { - std::cerr << ex.what() << std::endl; - return -1; + std::cerr << ex.what() << std::endl; + return -1; } diskann::Metric metric; if (dist_fn == std::string("l2")) - metric = diskann::Metric::L2; + metric = diskann::Metric::L2; else if (dist_fn == std::string("mips")) - metric = diskann::Metric::INNER_PRODUCT; + metric = diskann::Metric::INNER_PRODUCT; else { - std::cout << "Error. Only l2 and mips distance functions are supported" - << std::endl; - return -1; + std::cout << "Error. 
Only l2 and mips distance functions are supported" + << std::endl; + return -1; } - if (data_type == std::string("float")) { auto searcher = std::unique_ptr( - new diskann::PQFlashSearch( - index_path_prefix, num_nodes_to_cache, num_threads, - tags_file, metric)); + new diskann::PQFlashSearch(index_path_prefix, num_nodes_to_cache, + num_threads, tags_file, metric)); g_ssdSearch.push_back(std::move(searcher)); } else if (data_type == std::string("int8")) { auto searcher = std::unique_ptr(new diskann::PQFlashSearch( - index_path_prefix, num_nodes_to_cache, num_threads, - tags_file, metric)); + index_path_prefix, num_nodes_to_cache, num_threads, tags_file, + metric)); g_ssdSearch.push_back(std::move(searcher)); } else if (data_type == std::string("uint8")) { auto searcher = std::unique_ptr( - new diskann::PQFlashSearch( - index_path_prefix, num_nodes_to_cache, num_threads, - tags_file, metric)); + new diskann::PQFlashSearch(index_path_prefix, + num_nodes_to_cache, num_threads, + tags_file, metric)); g_ssdSearch.push_back(std::move(searcher)); } else { std::cerr << "Unsupported data type " << argv[2] << std::endl; diff --git a/tests/search_disk_index.cpp b/tests/search_disk_index.cpp index 40d8a952d..33339fdc2 100644 --- a/tests/search_disk_index.cpp +++ b/tests/search_disk_index.cpp @@ -44,14 +44,15 @@ void print_stats(std::string category, std::vector percentiles, diskann::cout << std::endl; } -template +template int search_disk_index( diskann::Metric& metric, const std::string& index_path_prefix, const std::string& result_output_prefix, const std::string& query_file, std::string& gt_file, const unsigned num_threads, const unsigned recall_at, const unsigned beamwidth, const unsigned num_nodes_to_cache, const _u32 search_io_limit, const std::vector& Lvec, - const bool use_reorder_data, const float fail_if_recall_below) { + const float fail_if_recall_below, const bool use_reorder_data = false, + const std::string& filter_label = "") { diskann::cout << "Search 
parameters: #threads: " << num_threads << ", "; if (beamwidth <= 0) diskann::cout << "beamwidth to be optimized for each L value" << std::flush; @@ -62,6 +63,10 @@ int search_disk_index( else diskann::cout << ", io_limit: " << search_io_limit << "." << std::endl; + bool filtered_search = false; + if (filter_label != "") + filtered_search = true; + std::string warmup_query_file = index_path_prefix + "_sample_data.bin"; // load query bin @@ -95,8 +100,8 @@ int search_disk_index( reader.reset(new LinuxAlignedFileReader()); #endif - std::unique_ptr> _pFlashIndex( - new diskann::PQFlashIndex(reader, metric)); + std::unique_ptr> _pFlashIndex( + new diskann::PQFlashIndex(reader, metric)); int res = _pFlashIndex->load(num_threads, index_path_prefix.c_str()); @@ -207,11 +212,22 @@ int search_disk_index( #pragma omp parallel for schedule(dynamic, 1) for (_s64 i = 0; i < (int64_t) query_num; i++) { - _pFlashIndex->cached_beam_search( - query + (i * query_aligned_dim), recall_at, L, - query_result_ids_64.data() + (i * recall_at), - query_result_dists[test_id].data() + (i * recall_at), - optimized_beamwidth, search_io_limit, use_reorder_data, stats + i); + if (!filtered_search) { + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, use_reorder_data, stats + i); + } else { + LabelT label_for_search = + _pFlashIndex->get_converted_label(filter_label); + _pFlashIndex->cached_beam_search( + query + (i * query_aligned_dim), recall_at, L, + query_result_ids_64.data() + (i * recall_at), + query_result_dists[test_id].data() + (i * recall_at), + optimized_beamwidth, true, label_for_search, use_reorder_data, + stats + i); + } } auto e = std::chrono::high_resolution_clock::now(); std::chrono::duration diff = e - s; @@ -282,7 +298,7 @@ int search_disk_index( int main(int argc, char** argv) { std::string data_type, dist_fn, 
index_path_prefix, result_path_prefix, - query_file, gt_file; + query_file, gt_file, filter_label, label_type; unsigned num_threads, K, W, num_nodes_to_cache, search_io_limit; std::vector Lvec; bool use_reorder_data = false; @@ -333,6 +349,15 @@ int main(int argc, char** argv) { po::bool_switch()->default_value(false), "Include full precision data in the index. Use only in " "conjuction with compressed data on SSD."); + desc.add_options()( + "filter_label", + po::value(&filter_label)->default_value(std::string("")), + "Filter Label for Filtered Search"); + desc.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); desc.add_options()( "fail_if_recall_below", po::value(&fail_if_recall_below)->default_value(0.0f), @@ -382,29 +407,52 @@ int main(int argc, char** argv) { } try { - if (data_type == std::string("float")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - use_reorder_data, fail_if_recall_below); - else if (data_type == std::string("int8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - use_reorder_data, fail_if_recall_below); - else if (data_type == std::string("uint8")) - return search_disk_index( - metric, index_path_prefix, result_path_prefix, query_file, gt_file, - num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, - use_reorder_data, fail_if_recall_below); - else { - std::cerr << "Unsupported data type. 
Use float or int8 or uint8" - << std::endl; - return -1; + if (filter_label != "" && label_type == "ushort") { + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else if (data_type == std::string("uint8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else { + std::cerr << "Unsupported data type. Use float or int8 or uint8" + << std::endl; + return -1; + } + } else { + if (data_type == std::string("float")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else if (data_type == std::string("int8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else if (data_type == std::string("uint8")) + return search_disk_index( + metric, index_path_prefix, result_path_prefix, query_file, gt_file, + num_threads, K, W, num_nodes_to_cache, search_io_limit, Lvec, + fail_if_recall_below, use_reorder_data, filter_label); + else { + std::cerr << "Unsupported data type. 
Use float or int8 or uint8" + << std::endl; + return -1; + } } } catch (const std::exception& e) { std::cout << std::string(e.what()) << std::endl; diskann::cerr << "Index search failed." << std::endl; return -1; } -} +} \ No newline at end of file diff --git a/tests/search_memory_index.cpp b/tests/search_memory_index.cpp index 04ce67320..cd441d32b 100644 --- a/tests/search_memory_index.cpp +++ b/tests/search_memory_index.cpp @@ -23,7 +23,7 @@ namespace po = boost::program_options; -template +template int search_memory_index(diskann::Metric& metric, const std::string& index_path, const std::string& result_path_prefix, const std::string& query_file, @@ -32,7 +32,8 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path, const bool print_all_recalls, const std::vector& Lvec, const bool dynamic, const bool tags, const bool show_qps_per_thread, - const float fail_if_recall_below) { + const std::string& filter_label, + const float fail_if_recall_below) { // Load the query file T* query = nullptr; unsigned* gt_ids = nullptr; @@ -54,8 +55,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path, diskann::cout << " Truthset file " << truthset_file << " not found. Not computing recall." 
<< std::endl; } + + bool filtered_search = false; + if (filter_label != "") { + filtered_search = true; + } + using TagT = uint32_t; - diskann::Index index(metric, query_dim, 0, dynamic, tags); + diskann::Index index(metric, query_dim, 0, dynamic, tags); std::cout << "Index class instantiated" << std::endl; index.load(index_path.c_str(), num_threads, *(std::max_element(Lvec.begin(), Lvec.end()))); @@ -117,6 +124,7 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path, } query_result_ids[test_id].resize(recall_at * query_num); + query_result_dists[test_id].resize(recall_at * query_num); std::vector res = std::vector(); auto s = std::chrono::high_resolution_clock::now(); @@ -124,7 +132,14 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path, #pragma omp parallel for schedule(dynamic, 1) for (int64_t i = 0; i < (int64_t) query_num; i++) { auto qs = std::chrono::high_resolution_clock::now(); - if (metric == diskann::FAST_L2) { + if (filtered_search) { + LabelT filter_label_as_num = index.get_converted_label(filter_label); + auto retval = index.search_with_filters( + query + i * query_aligned_dim, filter_label_as_num, recall_at, L, + query_result_ids[test_id].data() + i * recall_at, + query_result_dists[test_id].data() + i * recall_at); + cmp_stats[i] = retval.second; + } else if (metric == diskann::FAST_L2) { index.search_with_optimized_layout( query + i * query_aligned_dim, recall_at, L, query_result_ids[test_id].data() + i * recall_at); @@ -212,10 +227,9 @@ int search_memory_index(diskann::Metric& metric, const std::string& index_path, return best_recall >= fail_if_recall_below ? 
0 : -1; } - int main(int argc, char** argv) { std::string data_type, dist_fn, index_path_prefix, result_path, query_file, - gt_file; + gt_file, filter_label, label_type; unsigned num_threads, K; std::vector Lvec; bool print_all_recalls, dynamic, tags, show_qps_per_thread; @@ -238,6 +252,15 @@ int main(int argc, char** argv) { desc.add_options()("query_file", po::value(&query_file)->required(), "Query file in binary format"); + desc.add_options()( + "filter_label", + po::value(&filter_label)->default_value(std::string("")), + "Filter Label for Filtered Search"); + desc.add_options()( + "label_type", + po::value(&label_type)->default_value("uint"), + "Storage type of Labels , default value is uint which " + "will consume memory 4 bytes per filter"); desc.add_options()( "gt_file", po::value(>_file)->default_value(std::string("null")), @@ -313,26 +336,46 @@ int main(int argc, char** argv) { } try { - if (data_type == std::string("int8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, fail_if_recall_below); - } - - else if (data_type == std::string("uint8")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, fail_if_recall_below); - } else if (data_type == std::string("float")) { - return search_memory_index( - metric, index_path_prefix, result_path, query_file, gt_file, - num_threads, K, print_all_recalls, Lvec, dynamic, tags, - show_qps_per_thread, fail_if_recall_below); + if (filter_label != "" && label_type == "ushort") { + if (data_type == std::string("int8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else if (data_type == std::string("uint8")) 
{ + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else if (data_type == std::string("float")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else { + std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; + return -1; + } } else { - std::cout << "Unsupported type. Use float/int8/uint8" << std::endl; - return -1; + if (data_type == std::string("int8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else if (data_type == std::string("uint8")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else if (data_type == std::string("float")) { + return search_memory_index( + metric, index_path_prefix, result_path, query_file, gt_file, + num_threads, K, print_all_recalls, Lvec, dynamic, tags, + show_qps_per_thread, filter_label, fail_if_recall_below); + } else { + std::cout << "Unsupported type. 
Use float/int8/uint8" << std::endl; + return -1; + } } } catch (std::exception& e) { std::cout << std::string(e.what()) << std::endl; diff --git a/tests/test_insert_deletes_consolidate.cpp b/tests/test_insert_deletes_consolidate.cpp index 56178fa81..4b86cbcdf 100644 --- a/tests/test_insert_deletes_consolidate.cpp +++ b/tests/test_insert_deletes_consolidate.cpp @@ -160,6 +160,8 @@ void build_incremental_index( params.Set("saturate_graph", saturate_graph); params.Set("num_rnds", 1); params.Set("num_threads", thread_count); + params.Set("Lf", 0); // TODO: get this from params and default to some + // value to make it backward compatible. size_t dim, aligned_dim; size_t num_points; diff --git a/tests/test_streaming_scenario.cpp b/tests/test_streaming_scenario.cpp index 3e3e3899c..1647371fb 100644 --- a/tests/test_streaming_scenario.cpp +++ b/tests/test_streaming_scenario.cpp @@ -83,9 +83,10 @@ std::string get_save_filename(const std::string& save_path, return final_path; } -template -void insert_next_batch(diskann::Index& index, size_t start, size_t end, - size_t insert_threads, T* data, size_t aligned_dim) { +template +void insert_next_batch(diskann::Index& index, size_t start, + size_t end, size_t insert_threads, T* data, + size_t aligned_dim) { try { diskann::Timer insert_timer; std::cout << std::endl @@ -116,8 +117,8 @@ void insert_next_batch(diskann::Index& index, size_t start, size_t end, } } -template -void delete_and_consolidate(diskann::Index& index, +template +void delete_and_consolidate(diskann::Index& index, diskann::Parameters& delete_params, size_t start, size_t end) { try { @@ -189,6 +190,7 @@ void build_incremental_index(const std::string& data_path, const unsigned L, params.Set("saturate_graph", saturate_graph); params.Set("num_rnds", 1); params.Set("num_threads", insert_threads); + params.Set("Lf", 0); diskann::Parameters delete_params; delete_params.Set("L", L); delete_params.Set("R", R); @@ -222,6 +224,7 @@ void build_incremental_index(const 
std::string& data_path, const unsigned L, __FUNCSIG__, __FILE__, __LINE__); using TagT = uint32_t; + using LabelT = uint32_t; unsigned num_frozen = 1; const bool enable_tags = true; @@ -232,9 +235,9 @@ void build_incremental_index(const std::string& data_path, const unsigned L, std::cout << "Overriding num_frozen to" << num_frozen << std::endl; } - diskann::Index index(diskann::L2, dim, - active_window + 4 * consolidate_interval, true, - params, params, enable_tags, true); + diskann::Index index( + diskann::L2, dim, active_window + 4 * consolidate_interval, true, params, + params, enable_tags, true); index.set_start_point_at_random(static_cast(start_point_norm)); index.enable_delete(); diff --git a/tests/utils/CMakeLists.txt b/tests/utils/CMakeLists.txt index e360bb7da..40aa050c8 100644 --- a/tests/utils/CMakeLists.txt +++ b/tests/utils/CMakeLists.txt @@ -66,3 +66,8 @@ target_link_libraries(merge_shards ${PROJECT_NAME} ${DISKANN_TOOLS_TCMALLOC_LINK add_executable(create_disk_layout create_disk_layout.cpp) target_link_libraries(create_disk_layout ${PROJECT_NAME} ${DISKANN_ASYNC_LIB} ${DISKANN_TOOLS_TCMALLOC_LINK_OPTIONS}) +add_executable(generate_synthetic_labels generate_synthetic_labels.cpp) +target_link_libraries(generate_synthetic_labels ${PROJECT_NAME} Boost::program_options) + +add_executable(stats_label_data stats_label_data.cpp) +target_link_libraries(stats_label_data ${PROJECT_NAME} Boost::program_options) diff --git a/tests/utils/compute_groundtruth.cpp b/tests/utils/compute_groundtruth.cpp index aefee27e7..202e2c1e2 100644 --- a/tests/utils/compute_groundtruth.cpp +++ b/tests/utils/compute_groundtruth.cpp @@ -302,6 +302,66 @@ inline void load_bin_as_float(const char *filename, float *&data, std::cout << "Finished converting part data to float." 
<< std::endl; } +template +inline std::vector load_filtered_bin_as_float( + const char *filename, float *&data, size_t &npts, size_t &ndims, + int part_num, const char *label_file, const std::string &filter_label, + const std::string &universal_label, size_t &npoints_filt, + std::vector> &pts_to_labels) { + std::ifstream reader(filename, std::ios::binary); + if (reader.fail()) { + throw diskann::ANNException(std::string("Failed to open file ") + filename, + -1); + } + + std::cout << "Reading bin file " << filename << " ...\n"; + int npts_i32, ndims_i32; + std::vector rev_map; + reader.read((char *) &npts_i32, sizeof(int)); + reader.read((char *) &ndims_i32, sizeof(int)); + uint64_t start_id = part_num * PARTSIZE; + uint64_t end_id = (std::min)(start_id + PARTSIZE, (uint64_t) npts_i32); + npts = end_id - start_id; + ndims = (unsigned) ndims_i32; + uint64_t nptsuint64_t = (uint64_t) npts; + uint64_t ndimsuint64_t = (uint64_t) ndims; + npoints_filt = 0; + std::cout << "#pts in part = " << npts << ", #dims = " << ndims + << ", size = " << nptsuint64_t * ndimsuint64_t * sizeof(T) << "B" + << std::endl; + std::cout << "start and end ids: " << start_id << ", " << end_id << std::endl; + reader.seekg(start_id * ndims * sizeof(T) + 2 * sizeof(uint32_t), + std::ios::beg); + + T *data_T = new T[nptsuint64_t * ndimsuint64_t]; + reader.read((char *) data_T, sizeof(T) * nptsuint64_t * ndimsuint64_t); + std::cout << "Finished reading part of the bin file." 
<< std::endl; + reader.close(); + + data = aligned_malloc(nptsuint64_t * ndimsuint64_t, ALIGNMENT); + + for (int64_t i = 0; i < (int64_t) nptsuint64_t; i++) { + if (std::find(pts_to_labels[start_id + i].begin(), + pts_to_labels[start_id + i].end(), + filter_label) != pts_to_labels[start_id + i].end() || + std::find(pts_to_labels[start_id + i].begin(), + pts_to_labels[start_id + i].end(), + universal_label) != pts_to_labels[start_id + i].end()) { + rev_map.push_back(start_id + i); + for (int64_t j = 0; j < (int64_t) ndimsuint64_t; j++) { + float cur_val_float = (float) data_T[i * ndimsuint64_t + j]; + std::memcpy((char *) (data + npoints_filt * ndimsuint64_t + j), + (char *) &cur_val_float, sizeof(float)); + } + npoints_filt++; + } + } + delete[] data_T; + std::cout << "Finished converting part data to float.. identified " + << npoints_filt << " points matching the filter." << std::endl; + return rev_map; +} + template inline void save_bin(const std::string filename, T *data, size_t npts, size_t ndims) { @@ -334,19 +394,51 @@ inline void save_groundtruth_as_one_file(const std::string filename, << 2 * npts * ndims * sizeof(unsigned) + 2 * sizeof(int) << "B" << std::endl; - // data = new T[npts_u64 * ndims_u64]; writer.write((char *) data, npts * ndims * sizeof(uint32_t)); writer.write((char *) distances, npts * ndims * sizeof(float)); writer.close(); std::cout << "Finished writing truthset" << std::endl; } +inline void parse_label_file_into_vec( + size_t &line_cnt, const std::string &map_file, + std::vector> &pts_to_labels) { + std::ifstream infile(map_file); + std::string line, token; + std::set labels; + infile.clear(); + infile.seekg(0, std::ios::beg); + while (std::getline(infile, line)) { + std::istringstream iss(line); + std::vector lbls(0); + + getline(iss, token, '\t'); + std::istringstream new_iss(token); + while (getline(new_iss, token, ',')) { + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + 
token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + lbls.push_back(token); + labels.insert(token); + } + if (lbls.size() <= 0) { + std::cout << "No label found"; + exit(-1); + } + std::sort(lbls.begin(), lbls.end()); + pts_to_labels.push_back(lbls); + } + std::cout << "Identified " << labels.size() + << " distinct label(s), and populated labels for " + << pts_to_labels.size() << " points" << std::endl; +} + template -int aux_main(const std::string &base_file, const std::string &query_file, - const std::string >_file, size_t k, - const diskann::Metric &metric, - const std::string &tags_file = std::string("")) { - size_t npoints, nqueries, dim; +int aux_main(const std::string &base_file, const std::string &label_file, + const std::string &query_file, const std::string >_file, + size_t k, const std::string &filter_label, + const std::string &universal_label, const diskann::Metric &metric, + const std::string &tags_file = std::string("")) { + size_t npoints, nqueries, dim, npoints_filt; float *base_data; float *query_data; @@ -392,25 +484,50 @@ int aux_main(const std::string &base_file, const std::string &query_file, int *closest_points = new int[nqueries * k]; float *dist_closest_points = new float[nqueries * k]; + std::vector> pts_to_labels; + if (filter_label != "") + parse_label_file_into_vec(npoints, label_file, pts_to_labels); + std::vector rev_map; + for (int p = 0; p < num_parts; p++) { size_t start_id = p * PARTSIZE; - load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + if (filter_label == "") { + load_bin_as_float(base_file.c_str(), base_data, npoints, dim, p); + } else { + rev_map = load_filtered_bin_as_float( + base_file.c_str(), base_data, npoints, dim, p, label_file.c_str(), + filter_label, universal_label, npoints_filt, pts_to_labels); + } int *closest_points_part = new int[nqueries * k]; float *dist_closest_points_part = new float[nqueries * k]; - auto nr = std::min(npoints, k); - - exact_knn(dim, nr, 
closest_points_part, dist_closest_points_part, npoints, - base_data, nqueries, query_data, metric); + _u32 part_k; + if (filter_label == "") { + part_k = k < npoints ? k : npoints; + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, + npoints, base_data, nqueries, query_data, metric); + } else { + part_k = k < npoints_filt ? k : npoints_filt; + if (npoints_filt > 0) { + exact_knn(dim, part_k, closest_points_part, dist_closest_points_part, + npoints_filt, base_data, nqueries, query_data, metric); + } + } for (_u64 i = 0; i < nqueries; i++) { - for (_u64 j = 0; j < nr; j++) { + for (_u64 j = 0; j < part_k; j++) { if (tags_enabled) if (location_to_tag[closest_points_part[i * k + j] + start_id] == 0) continue; - results[i].push_back(std::make_pair( - (uint32_t) (closest_points_part[i * nr + j] + start_id), - dist_closest_points_part[i * nr + j])); + if (filter_label == "") { + results[i].push_back(std::make_pair( + (uint32_t) (closest_points_part[i * part_k + j] + start_id), + dist_closest_points_part[i * part_k + j])); + } else { + results[i].push_back(std::make_pair( + (uint32_t) (rev_map[closest_points_part[i * part_k + j]]), + dist_closest_points_part[i * part_k + j])); + } } } @@ -455,8 +572,9 @@ int aux_main(const std::string &base_file, const std::string &query_file, } int main(int argc, char **argv) { - std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file; - uint64_t K; + std::string data_type, dist_fn, base_file, query_file, gt_file, tags_file, + label_file, filter_label, universal_label; + uint64_t K; try { po::options_description desc{"Arguments"}; @@ -474,6 +592,17 @@ int main(int argc, char **argv) { desc.add_options()("query_file", po::value(&query_file)->required(), "File containing the query vectors in binary format"); + desc.add_options()("label_file", + po::value(&label_file)->default_value(""), + "Input labels file in txt format if present"); + desc.add_options()("filter_label", + 
po::value(&filter_label)->default_value(""), + "Input filter label if doing filtered groundtruth"); + desc.add_options()( + "universal_label", + po::value(&universal_label)->default_value(""), + "Universal label, if using it, only in conjunction with label_file"); + desc.add_options()( "gt_file", po::value(>_file)->required(), "File name for the writing ground truth in binary format"); @@ -518,11 +647,14 @@ int main(int argc, char **argv) { try { if (data_type == std::string("float")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); + aux_main(base_file, label_file, query_file, gt_file, K, + filter_label, universal_label, metric, tags_file); if (data_type == std::string("int8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); + aux_main(base_file, label_file, query_file, gt_file, K, + filter_label, universal_label, metric, tags_file); if (data_type == std::string("uint8")) - aux_main(base_file, query_file, gt_file, K, metric, tags_file); + aux_main(base_file, label_file, query_file, gt_file, K, + filter_label, universal_label, metric, tags_file); } catch (const std::exception &e) { std::cout << std::string(e.what()) << std::endl; diskann::cerr << "Compute GT failed." 
<< std::endl; diff --git a/tests/utils/count_bfs_levels.cpp b/tests/utils/count_bfs_levels.cpp index 97bf11e7b..21c0472f8 100644 --- a/tests/utils/count_bfs_levels.cpp +++ b/tests/utils/count_bfs_levels.cpp @@ -26,8 +26,9 @@ namespace po = boost::program_options; template void bfs_count(const std::string& index_path, unsigned data_dims) { using TagT = uint32_t; - diskann::Index index(diskann::Metric::L2, data_dims, 0, false, - false); + using LabelT = uint32_t; + diskann::Index index(diskann::Metric::L2, data_dims, 0, + false, false); std::cout << "Index class instantiated" << std::endl; index.load(index_path.c_str(), 1, 100); std::cout << "Index loaded" << std::endl; diff --git a/tests/utils/generate_synthetic_labels.cpp b/tests/utils/generate_synthetic_labels.cpp new file mode 100644 index 000000000..bff97b186 --- /dev/null +++ b/tests/utils/generate_synthetic_labels.cpp @@ -0,0 +1,169 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. 
+ +#include +#include +#include +#include +#include +#include "utils.h" + +namespace po = boost::program_options; +class ZipfDistribution { + public: + ZipfDistribution(int num_points, int num_labels) + : uniform_zero_to_one(std::uniform_real_distribution<>(0.0, 1.0)), + num_points(num_points), num_labels(num_labels) { + } + + std::unordered_map createDistributionMap() { + std::unordered_map map; + int primary_label_freq = ceil(num_points * distribution_factor); + for (int i{1}; i < num_labels + 1; i++) { + map[i] = ceil(primary_label_freq / i); + } + return map; + } + + int writeDistribution(std::ofstream& outfile) { + auto distribution_map = createDistributionMap(); + auto primary_label_frequency = num_points * distribution_factor; + for (int i{0}; i < num_points; i++) { + bool label_written = false; + for (auto it = distribution_map.cbegin(), next_it = it; + it != distribution_map.cend(); it = next_it) { + next_it++; + auto label_selection_probability = std::bernoulli_distribution( + distribution_factor / (double) it->first); + if (label_selection_probability(rand_engine)) { + if (label_written) { + outfile << ','; + } + outfile << it->first; + label_written = true; + // remove label from map if we have used all labels + distribution_map[it->first] -= 1; + if (distribution_map[it->first] == 0) { + distribution_map.erase(it); + } + } + } + if (!label_written) { + outfile << 0; + } + if (i < num_points - 1) { + outfile << '\n'; + } + } + return 0; + } + + int writeDistribution(std::string filename) { + std::ofstream outfile(filename); + if (!outfile.is_open()) { + std::cerr << "Error: could not open output file " << filename << '\n'; + return -1; + } + writeDistribution(outfile); + outfile.close(); + } + + private: + int num_labels; + const int num_points; + const double distribution_factor = 0.7; + std::knuth_b rand_engine; + const std::uniform_real_distribution uniform_zero_to_one; +}; + +int main(int argc, char** argv) { + std::string output_file, 
distribution_type; + _u64 num_labels, num_points; + + try { + po::options_description desc{"Arguments"}; + + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("output_file,O", + po::value(&output_file)->required(), + "Filename for saving the label file"); + desc.add_options()("num_points,N", + po::value(&num_points)->required(), + "Number of points in dataset"); + desc.add_options()("num_labels,L", + po::value(&num_labels)->required(), + "Number of unique labels, up to 5000"); + desc.add_options()( + "distribution_type,DT", + po::value(&distribution_type)->default_value("random"), + "Distribution function for labels defaults to random"); + + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; + } + po::notify(vm); + } catch (const std::exception& ex) { + std::cerr << ex.what() << '\n'; + return -1; + } + + if (num_labels > 5000) { + std::cerr << "Error: num_labels must be 5000 or less" << '\n'; + return -1; + } + + if (num_points <= 0) { + std::cerr << "Error: num_points must be greater than 0" << '\n'; + return -1; + } + + std::cout << "Generating synthetic labels for " << num_points + << " points with " << num_labels << " unique labels" << '\n'; + + try { + std::ofstream outfile(output_file); + if (!outfile.is_open()) { + std::cerr << "Error: could not open output file " << output_file << '\n'; + return -1; + } + + if (distribution_type == "zipf") { + ZipfDistribution zipf(num_points, num_labels); + zipf.writeDistribution(outfile); + } else if (distribution_type == "random") { + for (int i = 0; i < num_points; i++) { + bool label_written = false; + for (int j = 1; j <= num_labels; j++) { + // 50% chance to assign each label + if (rand() > (RAND_MAX / 2)) { + if (label_written) { + outfile << ','; + } + outfile << j; + label_written = true; + } + } + if (!label_written) { + outfile << 0; + } + if (i < num_points - 1) { + outfile << '\n'; + } + 
} + } + if (outfile.is_open()) { + outfile.close(); + } + + std::cout << "Labels written to " << output_file << '\n'; + + } catch (const std::exception& ex) { + std::cerr << "Label generation failed: " << ex.what() << '\n'; + return -1; + } + + return 0; +} \ No newline at end of file diff --git a/tests/utils/stats_label_data.cpp b/tests/utils/stats_label_data.cpp new file mode 100644 index 000000000..10eca0fc4 --- /dev/null +++ b/tests/utils/stats_label_data.cpp @@ -0,0 +1,150 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// Licensed under the MIT license. + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "utils.h" + +#ifndef _WINDOWS +#include +#include +#include +#include +#else +#include +#endif +namespace po = boost::program_options; + +void stats_analysis(const std::string labels_file, std::string univeral_label, + _u32 density = 10) { + std::string token, line; + std::ifstream labels_stream(labels_file); + std::unordered_map label_counts; + std::string label_with_max_points; + _u32 max_points = 0; + long long sum = 0; + long long point_cnt = 0; + float avg_labels_per_pt, avg_labels_per_pt_incl_0, mean_label_size, + mean_label_size_incl_0; + + std::vector<_u32> labels_per_point; + _u32 dense_pts = 0; + if (labels_stream.is_open()) { + while (getline(labels_stream, line)) { + point_cnt++; + std::stringstream iss(line); + _u32 lbl_cnt = 0; + while (getline(iss, token, ',')) { + lbl_cnt++; + token.erase(std::remove(token.begin(), token.end(), '\n'), token.end()); + token.erase(std::remove(token.begin(), token.end(), '\r'), token.end()); + if (label_counts.find(token) == label_counts.end()) + label_counts[token] = 0; + label_counts[token]++; + } + if (lbl_cnt >= density) { + dense_pts++; + } + labels_per_point.emplace_back(lbl_cnt); + } + } + + std::cout << "fraction of dense points with >= " << density << " labels = " + << (float) dense_pts 
/ (float) labels_per_point.size() << std::endl; + std::sort(labels_per_point.begin(), labels_per_point.end()); + + std::vector> label_count_vec; + + for (auto it = label_counts.begin(); it != label_counts.end(); it++) { + auto& lbl = *it; + label_count_vec.emplace_back(std::make_pair(lbl.first, lbl.second)); + if (lbl.second > max_points) { + max_points = lbl.second; + label_with_max_points = lbl.first; + } + sum += lbl.second; + } + + sort(label_count_vec.begin(), label_count_vec.end(), + [](const std::pair& lhs, + const std::pair& rhs) { + return lhs.second < rhs.second; + }); + + for (float p = 0; p < 1; p += 0.05) { + std::cout << "Percentile " << (100 * p) << "\t" + << label_count_vec[(_u32) (p * label_count_vec.size())].first + << " with count=" + << label_count_vec[(_u32) (p * label_count_vec.size())].second + << std::endl; + } + + std::cout << "Most common label " + << "\t" << label_count_vec[label_count_vec.size() - 1].first + << " with count=" + << label_count_vec[label_count_vec.size() - 1].second << std::endl; + if (label_count_vec.size() > 1) + std::cout << "Second common label " + << "\t" << label_count_vec[label_count_vec.size() - 2].first + << " with count=" + << label_count_vec[label_count_vec.size() - 2].second + << std::endl; + if (label_count_vec.size() > 2) + std::cout << "Third common label " + << "\t" << label_count_vec[label_count_vec.size() - 3].first + << " with count=" + << label_count_vec[label_count_vec.size() - 3].second + << std::endl; + avg_labels_per_pt = (sum) / (float) point_cnt; + mean_label_size = (sum) / label_counts.size(); + std::cout << "Total number of points = " << point_cnt + << ", number of labels = " << label_counts.size() << std::endl; + std::cout << "Average number of labels per point = " << avg_labels_per_pt + << std::endl; + std::cout << "Mean label size excluding 0 = " << mean_label_size << std::endl; + std::cout << "Most popular label is " << label_with_max_points << " with " + << max_points << " pts" << 
std::endl; +} + +int main(int argc, char** argv) { + std::string labels_file, universal_label; + _u32 density; + + po::options_description desc{"Arguments"}; + try { + desc.add_options()("help,h", "Print information on arguments"); + desc.add_options()("labels_file", + po::value(&labels_file)->required(), + "path to labels data file."); + desc.add_options()("universal_label", + po::value(&universal_label)->required(), + "Universal label used in labels file."); + desc.add_options()( + "density", po::value<_u32>(&density)->default_value(1), + "Number of labels each point in labels file, defaults to 1"); + po::variables_map vm; + po::store(po::parse_command_line(argc, argv, desc), vm); + if (vm.count("help")) { + std::cout << desc; + return 0; + } + po::notify(vm); + } catch (const std::exception& e) { + std::cerr << e.what() << '\n'; + return -1; + } + stats_analysis(labels_file, universal_label, density); +} diff --git a/workflows/filtered_in_memory.md b/workflows/filtered_in_memory.md new file mode 100644 index 000000000..c3b652685 --- /dev/null +++ b/workflows/filtered_in_memory.md @@ -0,0 +1,126 @@ +**Usage for filtered indices** +================================ +## Building a filtered Index +DiskANN provides two algorithms for building an index with filters support: filtered-vamana and stitched-vamana. Here, we describe the parameters for building both. `tests/build_memory_index.cpp` and `tests/build_stitched_index.cpp` are respectively used to build each kind of index. + +### 1. filtered-vamana + +1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported. +2. **`--dist_fn`**: There are two distance functions supported: minimum Euclidean distance (l2) and maximum inner product (mips). +3. **`--data_file`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. 
The following `n*d*sizeof(T)` bytes contain the contents of the data, one data point at a time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices. +4. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix. +5. **`-R (--max_degree)`** (default is 64): the degree of the graph index, typically between 32 and 150. Larger R will result in larger indices and longer indexing times, but might yield better search quality. +6. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during index building. Typical values are between 75 and 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity. Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality. Note that this is to be used only for building an unfiltered index. The corresponding search list parameter for a filtered index is managed by `--FilteredLbuild`. +7. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 and 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs. +8. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth). +9. **`--build_PQ_bytes`** (default is 0): Set to a positive value less than the dimensionality of the data to enable faster index build with PQ based distance comparisons. Defaults to using full precision vectors for distance comparisons. +10.
**`--use_opq`**: use the flag to use OPQ rather than PQ compression. OPQ is more space efficient for some high dimensional datasets, but also needs a bit more build time. +11. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_file`. +12. **`--universal_label`**: Optionally, the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point. +13. **`--FilteredLbuild`**: If building a filtered index, we maintain a separate search list from the one provided by `--Lbuild`. + +### 2. stitched-vamana +1. **`--data_type`**: The type of dataset you wish to build an index on. float(32 bit), signed int8 and unsigned uint8 are supported. +2. **`--data_path`**: The input data over which to build an index, in .bin format. The first 4 bytes represent number of points as integer. The next 4 bytes represent the dimension of data as integer. The following `n*d*sizeof(T)` bytes contain the contents of the data, one data point at a time. sizeof(T) is 1 for byte indices, and 4 for float indices. This will be read by the program as int8_t for signed indices, uint8_t for unsigned indices or float for float indices. +3. **`--index_path_prefix`**: The constructed index components will be saved to this path prefix. +4. **`-R (--max_degree)`** (default is 64): Recall that stitched-vamana first builds a sub-index for each filter. This parameter sets the max degree for each sub-index. +5. **`-L (--Lbuild)`** (default is 100): the size of search list we maintain during sub-index building. Typical values are between 75 and 400. Larger values will take more time to build but result in indices that provide higher recall for the same search complexity.
Ensure that value of L is at least that of R value unless you need to build indices really quickly and can somewhat compromise on quality. +6. **`--alpha`** (default is 1.2): A float value between 1.0 and 1.5 which determines the diameter of the graph, which will be approximately *log n* to the base alpha. Typical values are between 1 and 1.5. 1 will yield the sparsest graph, 1.5 will yield denser graphs. +7. **`-T (--num_threads)`** (default is to get_omp_num_procs()): number of threads used by the index build process. Since the code is highly parallel, the indexing time improves almost linearly with the number of threads (subject to the cores available on the machine and DRAM bandwidth). +8. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--data_path`. +9. **`--universal_label`**: Optionally, the filter data may contain a "wild-card" filter corresponding to all filters. This is referred to as a universal label. Note that if a point has the universal label, then the filter data must only have the universal label on the line corresponding to said point. +10. **`--Stitched_R`**: Once all sub-indices are "stitched" together, we prune the resulting graph down to the degree given by this parameter. + +## Computing a groundtruth file for a filtered index +In order to evaluate the performance of our algorithms, we can compare their results (i.e. the top `k` neighbors found for each query) against the results found by an exact nearest neighbor search. We provide the program `tests/utils/compute_groundtruth.cpp` to provide the results for the latter: + +1. **`--data_type`**: The type of dataset you built an index with. float(32 bit), signed int8 and unsigned uint8 are supported. +2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. +3. **`--base_file`**: The input data over which to build an index, in .bin format.
Corresponds to the `--data_path` argument from above. +4. **`--query_file`**: The queries to be searched on, which are stored in the same .bin format. +5. **`--label_file`**: Filter data for each point, in `.txt` format. Line `i` of the file consists of a comma-separated list of filters corresponding to point `i` in the file passed via `--base_file`. +6. **`--filter_label`**: Filter for each query. For each query, a search is performed with this filter. +7. **`--universal_label`**: Corresponds to the universal label passed when building an index with filter support. +8. **`--gt_file`**: File to output results to. The binary file starts with `n`, the number of queries (4 bytes), followed by `d`, the number of ground truth elements per query (4 bytes), followed by `n*d` entries representing the `d` closest IDs per query in integer format, followed by `n*d` entries representing the corresponding distances (float). Total file size is `8 + 4*n*d + 4*n*d` bytes. +9. **`--K`**: The number of nearest neighbors to compute for each query. + + + +## Searching a Filtered Index + +Searching a filtered index uses `tests/search_memory_index.cpp`: + +1. **`--data_type`**: The type of dataset you built the index on. float(32 bit), signed int8 and unsigned uint8 are supported. Use the same data type as in arg (1) above used in building the index. +2. **`--dist_fn`**: There are two distance functions supported: l2 and mips. There is an additional *fast_l2* implementation that could provide faster results for small (about a million-sized) indices. Use the same distance as in arg (2) above used in building the index. Note that stitched-vamana only supports l2. +3. **`--index_path_prefix`**: index built above in argument (4). +4. **`--result_path`**: search results will be stored in files, one per L value (see last arg), with specified prefix, in binary format. +5. **`-T (--num_threads)`**: The number of threads used for searching.
Threads run in parallel and one thread handles one query at a time. More threads will result in higher aggregate query throughput, but may lead to higher per-query latency, especially if the DRAM bandwidth is a bottleneck. So find the balance depending on throughput and latency required for your application. +6. **`--query_file`**: The queries to be searched on in the same binary file format as the data file (ii) above. The query file must be the same type as in argument (1). +7. **`--filter_label`**: The filter to be used when searching an index with filters. For each query, a search is performed with this filter. +8. **`--gt_file`**: The ground truth file for the queries and data file used in index construction. Use "null" if you do not have this file and if you do not want to compute recall. Note that if building a filtered index, a special groundtruth must be computed, as described above. +9. **`-K`**: search for *K* neighbors and measure *K*-recall@*K*, meaning the intersection between the retrieved top-*K* nearest neighbors and ground truth *K* nearest neighbors. +10. **`-L (--search_list)`**: A list of search_list sizes to perform search with. Larger parameters will result in higher latencies, but higher accuracies. Must be at least the value of *K* in (9). + +Example with SIFT10K: +-------------------- +We demonstrate how to work through this pipeline using the SIFT10K dataset (http://corpus-texmex.irisa.fr/).
Before starting, make sure you have compiled DiskANN according to the instructions in the README and can see the following binaries (paths with respect to repository root): +- `build/tests/utils/compute_groundtruth` +- `build/tests/utils/fvecs_to_bin` +- `build/tests/build_memory_index` +- `build/tests/build_stitched_index` +- `build/tests/search_memory_index` + +Now, download the base and query set and convert the data to binary format: +```bash +wget ftp://ftp.irisa.fr/local/texmex/corpus/siftsmall.tar.gz +tar -zxvf siftsmall.tar.gz +build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_base.fvecs siftsmall/siftsmall_base.bin +build/tests/utils/fvecs_to_bin float siftsmall/siftsmall_query.fvecs siftsmall/siftsmall_query.bin +``` + +We now need to make a label file for our vectors. For convenience, we've included a synthetic label generator through which we can generate a label file as follows: +```bash + build/tests/utils/generate_synthetic_labels --num_labels 50 --num_points 10000 --output_file ./rand_labels_50_10K.txt --distribution_type zipf +``` +Note: `distribution_type` can be `random` or `zipf`. + +This will generate a label file with 10000 data points with 50 distinct labels, ranging from 1 to 50 assigned using zipf distribution (0 is the universal label). + +Label count for each unique label in the generated label file can be printed with the help of the following command: +```bash + build/tests/utils/stats_label_data --labels_file ./rand_labels_50_10K.txt --universal_label 0 +``` + +Note that neither approach is designed for use with random synthetic labels, which will lead to unpredictable accuracy at search time. + +Now build and search the index and measure the recall using ground truth computed using bruteforce. We search for results with the filter 35.
+```bash +build/tests/utils/compute_groundtruth --data_type float --dist_fn l2 --base_file siftsmall/siftsmall_base.bin --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --K 100 --label_file ./rand_labels_50_10K.txt --filter_label 35 --universal_label 0 +build/tests/build_memory_index --data_type float --dist_fn l2 --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index -R 32 --FilteredLbuild 50 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 +build/tests/build_stitched_index --data_type float --data_path siftsmall/siftsmall_base.bin --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index -R 20 -L 40 --stitched_R 32 --alpha 1.2 --label_file ./rand_labels_50_10K.txt --universal_label 0 +build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix siftsmall/siftsmall_R32_L50_filtered_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/filtered_search_results +build/tests/search_memory_index --data_type float --dist_fn l2 --index_path_prefix siftsmall/siftsmall_R20_L40_SR32_stitched_index --query_file siftsmall/siftsmall_query.bin --gt_file siftsmall/siftsmall_gt_35.bin --filter_label 35 -K 10 -L 10 20 30 40 50 100 --result_path siftsmall/stitched_search_results +``` + + The output of both searches is listed below. The throughput (Queries/sec) as well as the mean and 99.9th percentile latency in microseconds are reported for each `L` parameter provided.
(Measured on a physical machine with a Intel(R) Xeon(R) W-2145 CPU and 64 GB RAM) + ``` + Stitched Index + Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10 +================================================================================= + 10 31324.39 37.33 116.79 311.90 17.80 + 20 91357.57 44.36 193.06 1042.30 17.90 + 30 69314.48 49.89 258.09 1398.00 18.20 + 40 61421.29 60.52 289.08 1515.00 18.60 + 50 54203.48 70.27 294.26 685.10 19.40 + 100 52904.45 79.00 336.26 1018.80 19.50 + +Filtered Index + Ls QPS Avg dist cmps Mean Latency (mus) 99.9 Latency Recall@10 +================================================================================= + 10 69671.84 21.48 45.25 146.20 11.60 + 20 168577.20 38.94 100.54 547.90 18.20 + 30 127129.41 52.95 126.83 768.40 19.70 + 40 106349.04 62.38 167.23 899.10 20.90 + 50 89952.33 70.95 189.12 1070.80 22.10 + 100 56899.00 112.26 304.67 636.60 23.80 + ```