Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions diskann-benchmark-core/src/search/graph/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,14 @@

pub mod knn;
pub mod multihop;
pub mod rag;
pub mod range;

pub mod strategy;

pub use knn::KNN;
pub use multihop::MultiHop;
pub use rag::RAG;
pub use range::Range;

pub use strategy::Strategy;
Expand Down
216 changes: 216 additions & 0 deletions diskann-benchmark-core/src/search/graph/rag.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
/*
* Copyright (c) Microsoft Corporation.
* Licensed under the MIT license.
*/

//! A built-in helper for benchmarking RAG (Retrieval-Augmented Generation) search.
//!
//! This mirrors [`super::KNN`] but uses [`graph::search::RagSearch`] as the search
//! parameters, applying diversity-maximizing reranking in post-processing.

use std::sync::Arc;

use diskann::{
ANNResult,
graph::{self, glue},
provider,
};
use diskann_benchmark_runner::utils::{MicroSeconds, percentiles};
use diskann_utils::{future::AsyncFriendly, views::Matrix};

use crate::{
search::{self, Search, graph::Strategy},
utils,
};

/// A built-in helper for benchmarking the RAG search method
/// [`graph::DiskANNIndex::search`] with [`graph::search::RagSearch`].
///
/// This is identical to [`super::KNN`] in structure but uses [`graph::search::RagSearch`]
/// as the search parameters, which applies diversity-maximizing reranking via greedy
/// orthogonalization in post-processing.
///
/// The provided implementation of [`Search`] accepts [`graph::search::RagSearch`]
/// and returns [`super::knn::Metrics`] as additional output.
#[derive(Debug)]
pub struct RAG<DP, T, S>
where
DP: provider::DataProvider,
{
index: Arc<graph::DiskANNIndex<DP>>,
queries: Arc<Matrix<T>>,
strategy: Strategy<S>,
}

impl<DP, T, S> RAG<DP, T, S>
where
DP: provider::DataProvider,
{
/// Construct a new [`RAG`] searcher.
///
/// If `strategy` is one of the container variants of [`Strategy`], its length
/// must match the number of rows in `queries`.
///
/// # Errors
///
/// Returns an error if the number of elements in `strategy` is not compatible with
/// the number of rows in `queries`.
pub fn new(
index: Arc<graph::DiskANNIndex<DP>>,
queries: Arc<Matrix<T>>,
strategy: Strategy<S>,
) -> anyhow::Result<Arc<Self>> {
strategy.length_compatible(queries.nrows())?;

Ok(Arc::new(Self {
index,
queries,
strategy,
}))
}
}

impl<DP, T, S> Search for RAG<DP, T, S>
where
DP: provider::DataProvider<Context: Default, ExternalId: search::Id>,
S: glue::SearchStrategy<DP, [T], DP::ExternalId>
+ glue::PostProcess<graph::search::RagSearchParams, DP, [T], DP::ExternalId>
+ Clone
+ AsyncFriendly,
T: AsyncFriendly + Clone,
{
type Id = DP::ExternalId;
type Parameters = graph::search::RagSearch;
type Output = super::knn::Metrics;

fn num_queries(&self) -> usize {
self.queries.nrows()
}

fn id_count(&self, parameters: &Self::Parameters) -> search::IdCount {
search::IdCount::Fixed(parameters.k_value())
}

async fn search<O>(
&self,
parameters: &Self::Parameters,
buffer: &mut O,
index: usize,
) -> ANNResult<Self::Output>
where
O: graph::SearchOutputBuffer<DP::ExternalId> + Send,
{
let context = DP::Context::default();
let rag_search = parameters.clone();
let stats = self
.index
.search(
rag_search,
self.strategy.get(index)?,
&context,
self.queries.row(index),
buffer,
)
.await?;

Ok(super::knn::Metrics {
comparisons: stats.cmps,
hops: stats.hops,
})
}
}

/// An [`search::Aggregate`]d summary of multiple [`RAG`] search runs.
///
/// This reuses [`super::knn::Summary`] since the output format is identical —
/// the only difference is the search parameters type.
pub struct Aggregator<'a, I> {
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
}

impl<'a, I> Aggregator<'a, I> {
/// Construct a new [`Aggregator`] using `groundtruth` for recall computation.
pub fn new(
groundtruth: &'a dyn crate::recall::Rows<I>,
recall_k: usize,
recall_n: usize,
) -> Self {
Self {
groundtruth,
recall_k,
recall_n,
}
}
}

impl<I> search::Aggregate<graph::search::RagSearch, I, super::knn::Metrics> for Aggregator<'_, I>
where
I: crate::recall::RecallCompatible,
{
type Output = super::knn::Summary;

fn aggregate(
&mut self,
run: search::Run<graph::search::RagSearch>,
mut results: Vec<search::SearchResults<I, super::knn::Metrics>>,
) -> anyhow::Result<super::knn::Summary> {
// Compute the recall using just the first result.
let recall = match results.first() {
Some(first) => crate::recall::knn(
self.groundtruth,
None,
first.ids().as_rows(),
self.recall_k,
self.recall_n,
true,
)?,
None => anyhow::bail!("Results must be non-empty"),
};

let mut mean_latencies = Vec::with_capacity(results.len());
let mut p90_latencies = Vec::with_capacity(results.len());
let mut p99_latencies = Vec::with_capacity(results.len());

results.iter_mut().for_each(|r| {
match percentiles::compute_percentiles(r.latencies_mut()) {
Ok(values) => {
let percentiles::Percentiles { mean, p90, p99, .. } = values;
mean_latencies.push(mean);
p90_latencies.push(p90);
p99_latencies.push(p99);
}
Err(_) => {
let zero = MicroSeconds::new(0);
mean_latencies.push(0.0);
p90_latencies.push(zero);
p99_latencies.push(zero);
}
}
});

// Extract the inner Knn parameters so we can produce a knn::Summary.
let knn_params = *run.parameters().knn();

Ok(super::knn::Summary {
setup: run.setup().clone(),
parameters: knn_params,
end_to_end_latencies: results.iter().map(|r| r.end_to_end_latency()).collect(),
recall,
mean_latencies,
p90_latencies,
p99_latencies,
mean_cmps: utils::average_all(
results
.iter()
.flat_map(|r| r.output().iter().map(|o| o.comparisons)),
),
mean_hops: utils::average_all(
results
.iter()
.flat_map(|r| r.output().iter().map(|o| o.hops)),
),
})
}
}
31 changes: 31 additions & 0 deletions diskann-benchmark/example/mimir-search-rag.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"search_directories": [
"C:/data/mimir"
],
"jobs": [
{
"type": "disk-index",
"content": {
"source": {
"disk-index-source": "Load",
"data_type": "float16",
"load_path": "C:/data/mimir/mimir_new"
},
"search_phase": {
"queries": "mimir_query.bin",
"groundtruth": "mimir_gt_1000.bin",
"search_list": [2000],
"beam_width": 4,
"recall_at": 1000,
"num_threads": 1,
"is_flat_search": false,
"distance": "squared_l2",
"vector_filters_file": null,
"is_rag_search": true,
"rag_eta": 0.01,
"rag_power": 2.0
}
}
}
]
}
31 changes: 31 additions & 0 deletions diskann-benchmark/example/mimir-search-standard-knn.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
{
"search_directories": [
"C:/data/mimir"
],
"jobs": [
{
"type": "disk-index",
"content": {
"source": {
"disk-index-source": "Load",
"data_type": "float16",
"load_path": "C:/data/mimir/mimir_new"
},
"search_phase": {
"queries": "mimir_query.bin",
"groundtruth": "mimir_gt_1000.bin",
"search_list": [2000],
"beam_width": 4,
"recall_at": 1000,
"num_threads": 1,
"is_flat_search": false,
"distance": "squared_l2",
"vector_filters_file": null,
"is_rag_search": false,
"rag_eta": null,
"rag_power": null
}
}
}
]
}
59 changes: 59 additions & 0 deletions diskann-benchmark/example/rag-search.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
{
"search_directories": [
"test_data/disk_index_search"
],
"jobs": [
{
"type": "async-index-build",
"content": {
"source": {
"index-source": "Load",
"data_type": "float32",
"distance": "squared_l2",
"load_path": "disk_index_siftsmall_learn_256pts_saved_index"
},
"search_phase": {
"search-type": "topk",
"queries": "disk_index_sample_query_10pts.fbin",
"groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin",
"reps": 3,
"num_threads": [1],
"runs": [
{
"search_n": 10,
"search_l": [10, 20, 50, 100],
"recall_k": 10
}
]
}
}
},
{
"type": "async-index-build",
"content": {
"source": {
"index-source": "Load",
"data_type": "float32",
"distance": "squared_l2",
"load_path": "disk_index_siftsmall_learn_256pts_saved_index"
},
"search_phase": {
"search-type": "topk-rag",
"queries": "disk_index_sample_query_10pts.fbin",
"groundtruth": "disk_index_10pts_idx_uint32_truth_search_res.bin",
"reps": 3,
"num_threads": [1],
"runs": [
{
"search_n": 10,
"search_l": [10, 20, 50, 100],
"recall_k": 10
}
],
"rag_eta": 0.01,
"rag_power": 2.0
}
}
}
]
}
Loading
Loading