[Feat] Expose C API for CAGRA merge #860

Open: wants to merge 5 commits into base branch-25.06
90 changes: 90 additions & 0 deletions cpp/include/cuvs/neighbors/cagra.h
@@ -358,6 +358,38 @@ cuvsError_t cuvsCagraIndexDestroy(cuvsCagraIndex_t index);
*/
cuvsError_t cuvsCagraIndexGetDims(cuvsCagraIndex_t index, int* dim);

/**
* @}
*/

/**
* @defgroup cagra_c_merge_params C API for CAGRA index merge parameters
* @{
*/

/** @brief Strategy for merging CAGRA indices. */
typedef enum {
Member

Can we reuse this in the C++ version so we don't have to duplicate?

Member Author

@rhdong rhdong Apr 30, 2025

I suppose we can't, because C doesn't support namespaces, `::`, or `enum class`. It's not the first time we have redefined a type in C (which should work well), e.g. hash_mode vs. cuvsCagraHashMode.

Member

Can we forward it? Something like "using cuvsCagraMergeStrategy as cuvs::neighbors::cagra::merge_strategy"?

Member

We do this with the `DistanceType` in pairwise distances. Maybe we can do the same thing here?

Member Author

@rhdong rhdong May 3, 2025

> We do this with the `DistanceType` in pairwise distances. Maybe we can do the same thing here?

Sorry, I found that we define a separate DistanceType and convert it to the C++ type via static_cast. Maybe I was looking in the wrong place. Could you point me to the right code location? Many thanks!

Member

@rhdong, this is where we map the C++ type to the C type. That static cast shouldn't be needed, and I suspect it's an artifact from before we formally combined them. You are free to fix it if you have time.

Member Author

Oh, I understood! Let C++ be compatible with C 😃
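
The approach agreed on in this thread, declaring the enum once in the C header and aliasing it from C++, can be sketched roughly as below. This is only an illustration and not the cuVS source: the `demo` namespace is a hypothetical stand-in for `cuvs::neighbors::cagra`, the helper functions are invented for the example, and the C header is simulated inline with an `extern "C"` block.

```cpp
#include <cassert>
#include <type_traits>

// Simulated C header: a plain C enum that is valid in both C and C++.
extern "C" {
typedef enum {
  PHYSICAL = 0,  ///< Merge indices physically
  LOGICAL  = 1   ///< Merge indices logically (if supported)
} cuvsCagraMergeStrategy;
}

// Hypothetical namespace standing in for cuvs::neighbors::cagra (assumption).
namespace demo {
// The alias introduces no new type, so no static_cast is needed at the C/C++ boundary.
using MergeStrategy = cuvsCagraMergeStrategy;

// A C++ function written against the aliased name (invented for this sketch).
inline bool is_physical(MergeStrategy s) { return s == PHYSICAL; }
}  // namespace demo

static_assert(std::is_same<demo::MergeStrategy, cuvsCagraMergeStrategy>::value,
              "alias and C enum must be the same type");

// A C-facing entry point can pass its enum straight through (invented for this sketch).
inline bool c_entry_is_physical(cuvsCagraMergeStrategy s) { return demo::is_physical(s); }
```

Because the alias and the C enum are the same type, adding an enumerator in the C header makes it visible to C++ callers without a second definition to keep in sync.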

PHYSICAL = 0, ///< Merge indices physically
LOGICAL = 1 ///< Merge indices logically (if supported)
} cuvsCagraMergeStrategy;

/** @brief Supplemental parameters to merge CAGRA indices. */
struct cuvsCagraMergeParams {
cuvsCagraIndexParams_t output_index_params;
cuvsCagraMergeStrategy strategy;
};

typedef struct cuvsCagraMergeParams* cuvsCagraMergeParams_t;

/** Allocate CAGRA merge params with default values */
cuvsError_t cuvsCagraMergeParamsCreate(cuvsCagraMergeParams_t* params);

/** De-allocate CAGRA merge params */
cuvsError_t cuvsCagraMergeParamsDestroy(cuvsCagraMergeParams_t params);

/**
* @}
*/
@@ -584,6 +616,64 @@ cuvsError_t cuvsCagraSerializeToHnswlib(cuvsResources_t res,
* @param[out] index CAGRA index loaded from disk
*/
cuvsError_t cuvsCagraDeserialize(cuvsResources_t res, const char* filename, cuvsCagraIndex_t index);

/**
* @brief Merge multiple CAGRA indices into a single CAGRA index.
*
* All input indices must have been built with the same data type (`index.dtype`) and
* have the same dimensionality (`index.dims`). The merged index uses the output
* parameters specified in `cuvsCagraMergeParams`.
*
* Input indices must have:
* - `index.dtype.code` and `index.dtype.bits` matching across all indices.
* - Supported data types for indices:
* a. `kDLFloat` with `bits = 32`
* b. `kDLFloat` with `bits = 16`
* c. `kDLInt` with `bits = 8`
* d. `kDLUInt` with `bits = 8`
*
* The resulting output index will have the same data type as the input indices.
*
* Example:
* @code{.c}
* #include <cuvs/core/c_api.h>
* #include <cuvs/neighbors/cagra.h>
*
* cuvsResources_t res;
* cuvsError_t res_create_status = cuvsResourcesCreate(&res);
*
* cuvsCagraIndex_t index1, index2, merged_index;
* cuvsCagraIndexCreate(&index1);
* cuvsCagraIndexCreate(&index2);
* cuvsCagraIndexCreate(&merged_index);
*
* // Assume index1 and index2 have been built using cuvsCagraBuild
*
* cuvsCagraMergeParams_t merge_params;
* cuvsError_t params_create_status = cuvsCagraMergeParamsCreate(&merge_params);
*
* cuvsCagraIndex_t indices[] = {index1, index2};
* cuvsError_t merge_status = cuvsCagraMerge(res, merge_params, indices, 2, merged_index);
*
* // Use merged_index for search operations
*
* cuvsError_t params_destroy_status = cuvsCagraMergeParamsDestroy(merge_params);
* cuvsCagraIndexDestroy(index1);
* cuvsCagraIndexDestroy(index2);
* cuvsCagraIndexDestroy(merged_index);
* cuvsError_t res_destroy_status = cuvsResourcesDestroy(res);
* @endcode
*
* @param[in] res cuvsResources_t opaque C handle
* @param[in] params cuvsCagraMergeParams_t parameters controlling merge behavior
* @param[in] indices Array of input cuvsCagraIndex_t handles to merge
* @param[in] num_indices Number of input indices
* @param[out] output_index Output handle that will store the merged index.
* Must be initialized using `cuvsCagraIndexCreate` before use.
*/
cuvsError_t cuvsCagraMerge(cuvsResources_t res,
cuvsCagraMergeParams_t params,
cuvsCagraIndex_t* indices,
size_t num_indices,
cuvsCagraIndex_t output_index);

/**
* @}
*/
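
The dtype precondition documented for `cuvsCagraMerge` (all inputs share `code` and `bits`, and the pair is one of the four listed types) could be checked along the lines below. This is a minimal sketch under stated assumptions: `DemoTypeCode`, `DemoDType`, `is_supported` and `can_merge` are hypothetical stand-ins invented here, not DLPack or cuVS symbols; the numeric codes mirror DLPack's `kDLInt`, `kDLUInt` and `kDLFloat`.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-ins for DLPack's type codes and DLDataType (not the real headers).
enum DemoTypeCode : uint8_t { kDemoInt = 0, kDemoUInt = 1, kDemoFloat = 2 };
struct DemoDType {
  uint8_t code;
  uint8_t bits;
};

// True for the four element types listed in the merge documentation.
inline bool is_supported(DemoDType t)
{
  return (t.code == kDemoFloat && (t.bits == 32 || t.bits == 16)) ||
         (t.code == kDemoInt && t.bits == 8) || (t.code == kDemoUInt && t.bits == 8);
}

// True when the list is non-empty, the first dtype is supported,
// and every other dtype matches the first in both code and bits.
inline bool can_merge(const std::vector<DemoDType>& dtypes)
{
  if (dtypes.empty() || !is_supported(dtypes.front())) { return false; }
  for (const auto& t : dtypes) {
    if (t.code != dtypes.front().code || t.bits != dtypes.front().bits) { return false; }
  }
  return true;
}
```

Running a check like this before dispatching keeps a mismatch from surfacing later as an unsupported-type failure.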
14 changes: 2 additions & 12 deletions cpp/include/cuvs/neighbors/cagra.hpp
@@ -18,6 +18,7 @@

 #include "common.hpp"
 #include <cuvs/distance/distance.hpp>
+#include <cuvs/neighbors/cagra.h>
 #include <cuvs/neighbors/common.hpp>
 #include <cuvs/neighbors/ivf_pq.hpp>
 #include <cuvs/neighbors/nn_descent.hpp>
@@ -278,18 +279,7 @@ struct extend_params {
  *
  * @note Currently, only the PHYSICAL strategy is supported.
  */
-enum MergeStrategy {
-  /**
-   * @brief Physical merge: Builds a new CAGRA graph from the union of dataset points
-   * in existing CAGRA graphs.
-   *
-   * This is expensive to build but does not impact search latency or quality.
-   * Preferred for many smaller CAGRA graphs.
-   *
-   * @note Currently, this is the only supported strategy.
-   */
-  PHYSICAL
-};
+using MergeStrategy = cuvsCagraMergeStrategy;

 /**
  * @brief Parameters for merging CAGRA indexes.
216 changes: 165 additions & 51 deletions cpp/src/neighbors/cagra_c.cpp
@@ -33,70 +33,82 @@

 namespace {

-template <typename T>
-void* _build(cuvsResources_t res, cuvsCagraIndexParams params, DLManagedTensor* dataset_tensor)
-{
-  auto dataset = dataset_tensor->dl_tensor;
-
-  auto res_ptr = reinterpret_cast<raft::resources*>(res);
-  auto index = new cuvs::neighbors::cagra::index<T, uint32_t>(*res_ptr);
-
-  auto index_params = cuvs::neighbors::cagra::index_params();
-  index_params.metric = static_cast<cuvs::distance::DistanceType>((int)params.metric),
-  index_params.intermediate_graph_degree = params.intermediate_graph_degree;
-  index_params.graph_degree = params.graph_degree;
+static void _set_graph_build_params(
+  std::variant<std::monostate,
+               cuvs::neighbors::cagra::graph_build_params::ivf_pq_params,
+               cuvs::neighbors::cagra::graph_build_params::nn_descent_params,
+               cuvs::neighbors::cagra::graph_build_params::iterative_search_params>& out_params,
+  cuvsCagraIndexParams& params,
+  cuvsCagraGraphBuildAlgo algo,
+  int64_t n_rows,
+  int64_t dim)

-  switch (params.build_algo) {
+{
+  switch (algo) {
     case cuvsCagraGraphBuildAlgo::AUTO_SELECT: break;
     case cuvsCagraGraphBuildAlgo::IVF_PQ: {
-      auto dataset_extent = raft::matrix_extent<int64_t>(dataset.shape[0], dataset.shape[1]);
-      auto pq_params = cuvs::neighbors::cagra::graph_build_params::ivf_pq_params(dataset_extent);
-      auto ivf_pq_build_params = params.graph_build_params->ivf_pq_build_params;
-      auto ivf_pq_search_params = params.graph_build_params->ivf_pq_search_params;
-      if (ivf_pq_build_params) {
-        pq_params.build_params.add_data_on_build = ivf_pq_build_params->add_data_on_build;
-        pq_params.build_params.n_lists = ivf_pq_build_params->n_lists;
-        pq_params.build_params.kmeans_n_iters = ivf_pq_build_params->kmeans_n_iters;
-        pq_params.build_params.kmeans_trainset_fraction =
-          ivf_pq_build_params->kmeans_trainset_fraction;
-        pq_params.build_params.pq_bits = ivf_pq_build_params->pq_bits;
-        pq_params.build_params.pq_dim = ivf_pq_build_params->pq_dim;
-        pq_params.build_params.codebook_kind =
-          static_cast<cuvs::neighbors::ivf_pq::codebook_gen>(ivf_pq_build_params->codebook_kind);
-        pq_params.build_params.force_random_rotation = ivf_pq_build_params->force_random_rotation;
-        pq_params.build_params.conservative_memory_allocation =
-          ivf_pq_build_params->conservative_memory_allocation;
-        pq_params.build_params.max_train_points_per_pq_code =
-          ivf_pq_build_params->max_train_points_per_pq_code;
-      }
-      if (ivf_pq_search_params) {
-        pq_params.search_params.n_probes = ivf_pq_search_params->n_probes;
-        pq_params.search_params.lut_dtype = ivf_pq_search_params->lut_dtype;
-        pq_params.search_params.internal_distance_dtype =
-          ivf_pq_search_params->internal_distance_dtype;
-        pq_params.search_params.preferred_shmem_carveout =
-          ivf_pq_search_params->preferred_shmem_carveout;
+      auto pq_params = cuvs::neighbors::cagra::graph_build_params::ivf_pq_params(
+        raft::matrix_extent<int64_t>(n_rows, dim));
+      if (params.graph_build_params) {
+        auto ivf_params = static_cast<cuvsIvfPqParams*>(params.graph_build_params);
+        if (ivf_params->ivf_pq_build_params) {
+          auto bp = ivf_params->ivf_pq_build_params;
+          pq_params.build_params.add_data_on_build = bp->add_data_on_build;
+          pq_params.build_params.n_lists = bp->n_lists;
+          pq_params.build_params.kmeans_n_iters = bp->kmeans_n_iters;
+          pq_params.build_params.kmeans_trainset_fraction = bp->kmeans_trainset_fraction;
+          pq_params.build_params.pq_bits = bp->pq_bits;
+          pq_params.build_params.pq_dim = bp->pq_dim;
+          pq_params.build_params.codebook_kind =
+            static_cast<cuvs::neighbors::ivf_pq::codebook_gen>(bp->codebook_kind);
+          pq_params.build_params.force_random_rotation = bp->force_random_rotation;
+          pq_params.build_params.conservative_memory_allocation =
+            bp->conservative_memory_allocation;
+          pq_params.build_params.max_train_points_per_pq_code = bp->max_train_points_per_pq_code;
+        }
+        if (ivf_params->ivf_pq_search_params) {
+          auto sp = ivf_params->ivf_pq_search_params;
+          pq_params.search_params.n_probes = sp->n_probes;
+          pq_params.search_params.lut_dtype = sp->lut_dtype;
+          pq_params.search_params.internal_distance_dtype = sp->internal_distance_dtype;
+          pq_params.search_params.preferred_shmem_carveout = sp->preferred_shmem_carveout;
+        }
+        if (ivf_params->refinement_rate > 1.0f) {
+          pq_params.refinement_rate = ivf_params->refinement_rate;
+        }
       }
-      if (params.graph_build_params->refinement_rate > 1) {
-        pq_params.refinement_rate = params.graph_build_params->refinement_rate;
-      }
-      index_params.graph_build_params = pq_params;
+      out_params = pq_params;
       break;
     }
     case cuvsCagraGraphBuildAlgo::NN_DESCENT: {
-      cuvs::neighbors::cagra::graph_build_params::nn_descent_params nn_descent_params{};
-      nn_descent_params =
-        cuvs::neighbors::nn_descent::index_params(index_params.intermediate_graph_degree);
-      nn_descent_params.max_iterations = params.nn_descent_niter;
-      index_params.graph_build_params = nn_descent_params;
+      auto nn_params = cuvs::neighbors::nn_descent::index_params(params.intermediate_graph_degree);
+      nn_params.max_iterations = params.nn_descent_niter;
+      out_params = nn_params;
       break;
     }
     case cuvsCagraGraphBuildAlgo::ITERATIVE_CAGRA_SEARCH: {
       cuvs::neighbors::cagra::graph_build_params::iterative_search_params p;
-      index_params.graph_build_params = p;
+      out_params = p;
       break;
     }
-  };
+  }
+}
+
+template <typename T>
+void* _build(cuvsResources_t res, cuvsCagraIndexParams params, DLManagedTensor* dataset_tensor)
+{
+  auto dataset = dataset_tensor->dl_tensor;
+
+  auto res_ptr = reinterpret_cast<raft::resources*>(res);
+  auto index = new cuvs::neighbors::cagra::index<T, uint32_t>(*res_ptr);
+
+  auto index_params = cuvs::neighbors::cagra::index_params();
+  index_params.metric = static_cast<cuvs::distance::DistanceType>((int)params.metric),
+  index_params.intermediate_graph_degree = params.intermediate_graph_degree;
+  index_params.graph_degree = params.graph_degree;
+
+  _set_graph_build_params(
+    index_params.graph_build_params, params, params.build_algo, dataset.shape[0], dataset.shape[1]);

   if (auto* cparams = params.compression; cparams != nullptr) {
     auto compression_params = cuvs::neighbors::vpq_params();
@@ -266,6 +278,54 @@ void* _deserialize(cuvsResources_t res, const char* filename)
return index;
}

template <typename T>
void* _merge(cuvsResources_t res,
cuvsCagraMergeParams params,
cuvsCagraIndex_t* indices,
size_t num_indices)
{
auto res_ptr = reinterpret_cast<raft::resources*>(res);
cuvs::neighbors::cagra::merge_params merge_params_cpp;
auto& out_idx_params = *params.output_index_params;

merge_params_cpp.output_index_params.metric =
static_cast<cuvs::distance::DistanceType>((int)out_idx_params.metric);
merge_params_cpp.output_index_params.intermediate_graph_degree =
out_idx_params.intermediate_graph_degree;
merge_params_cpp.output_index_params.graph_degree = out_idx_params.graph_degree;

int64_t total_size = 0;
int64_t dim = 0;
if (out_idx_params.build_algo == cuvsCagraGraphBuildAlgo::IVF_PQ) {
auto first_idx_ptr =
reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(indices[0]->addr);
dim = first_idx_ptr->dim();
for (size_t i = 0; i < num_indices; ++i) {
auto idx_ptr =
reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(indices[i]->addr);
total_size += idx_ptr->size();
}
}

_set_graph_build_params(merge_params_cpp.output_index_params.graph_build_params,
out_idx_params,
out_idx_params.build_algo,
total_size,
dim);

std::vector<cuvs::neighbors::cagra::index<T, uint32_t>*> index_ptrs;
index_ptrs.reserve(num_indices);
for (size_t i = 0; i < num_indices; ++i) {
auto idx_ptr = reinterpret_cast<cuvs::neighbors::cagra::index<T, uint32_t>*>(indices[i]->addr);
index_ptrs.push_back(idx_ptr);
}

auto merged_index = new cuvs::neighbors::cagra::index<T, uint32_t>(
cuvs::neighbors::cagra::merge(*res_ptr, merge_params_cpp, index_ptrs));

return merged_index;
}
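
For the IVF-PQ path, `_merge` sizes the build parameters from the summed row count of the inputs and the dimensionality of the first index. A rough sketch of that computation follows, with one addition that is not in the PR: it also rejects inputs whose dimensionality differs, which the header documentation requires but `_merge` does not check. `DemoIndex`, `MergedExtent` and `merged_extent` are hypothetical names invented for this example.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical stand-in for the size()/dim() pair of a CAGRA index (assumption).
struct DemoIndex {
  int64_t rows;
  int64_t dim;
};

struct MergedExtent {
  int64_t rows;
  int64_t dim;
};

// Sums the row counts and takes dim from the first index.
// Returns false (leaving *out untouched) on empty input or a dim mismatch.
inline bool merged_extent(const std::vector<DemoIndex>& indices, MergedExtent* out)
{
  if (indices.empty()) { return false; }
  MergedExtent result{0, indices.front().dim};
  for (const auto& idx : indices) {
    if (idx.dim != result.dim) { return false; }
    result.rows += idx.rows;
  }
  *out = result;
  return true;
}
```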

} // namespace

extern "C" cuvsError_t cuvsCagraIndexCreate(cuvsCagraIndex_t* index)
@@ -398,6 +458,43 @@ extern "C" cuvsError_t cuvsCagraSearch(cuvsResources_t res,
});
}

extern "C" cuvsError_t cuvsCagraMerge(cuvsResources_t res,
cuvsCagraMergeParams_t params,
cuvsCagraIndex_t* indices,
size_t num_indices,
cuvsCagraIndex_t output_index)
{
return cuvs::core::translate_exceptions([=] {
// Basic checks on inputs
RAFT_EXPECTS(indices != nullptr && num_indices > 0, "indices array cannot be null or empty");
// Use the first index's dtype as the reference
auto dtype = (*indices[0]).dtype;
// Start at 0 so that the first index is also checked for having been built
for (size_t i = 0; i < num_indices; ++i) {
RAFT_EXPECTS((*indices[i]).dtype.code == dtype.code && (*indices[i]).dtype.bits == dtype.bits,
"All input indices must have the same data type");
RAFT_EXPECTS((*indices[i]).addr != 0, "All input indices must be built (non-empty)");
}
RAFT_EXPECTS(output_index != nullptr, "Output index pointer must not be null");
output_index->dtype = dtype; // output index type matches inputs
// Dispatch based on data type
if (dtype.code == kDLFloat && dtype.bits == 32) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<float>(res, *params, indices, num_indices));
} else if (dtype.code == kDLFloat && dtype.bits == 16) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<half>(res, *params, indices, num_indices));
} else if (dtype.code == kDLInt && dtype.bits == 8) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<int8_t>(res, *params, indices, num_indices));
} else if (dtype.code == kDLUInt && dtype.bits == 8) {
output_index->addr =
reinterpret_cast<uintptr_t>(_merge<uint8_t>(res, *params, indices, num_indices));
} else {
RAFT_FAIL("Unsupported index data type: code=%d, bits=%d", dtype.code, dtype.bits);
}
});
}

extern "C" cuvsError_t cuvsCagraIndexParamsCreate(cuvsCagraIndexParams_t* params)
{
return cuvs::core::translate_exceptions([=] {
@@ -469,6 +566,23 @@ extern "C" cuvsError_t cuvsCagraSearchParamsDestroy(cuvsCagraSearchParams_t para
return cuvs::core::translate_exceptions([=] { delete params; });
}

extern "C" cuvsError_t cuvsCagraMergeParamsCreate(cuvsCagraMergeParams_t* params)
{
return cuvs::core::translate_exceptions([=] {
cuvsCagraIndexParams_t idx_params;
cuvsCagraIndexParamsCreate(&idx_params);
*params = new cuvsCagraMergeParams{.output_index_params = idx_params, .strategy = PHYSICAL};
});
}

extern "C" cuvsError_t cuvsCagraMergeParamsDestroy(cuvsCagraMergeParams_t params)
{
return cuvs::core::translate_exceptions([=] {
cuvsCagraIndexParamsDestroy(params->output_index_params);
delete params;
});
}

extern "C" cuvsError_t cuvsCagraDeserialize(cuvsResources_t res,
const char* filename,
cuvsCagraIndex_t index)