Skip to content

Wrapper for all-neighbors knn graph building #785

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 53 commits into
base: branch-25.06
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
0a30e16
MG ann batch build
jinsolp Mar 25, 2025
b1ef22c
Merge branch 'rapidsai:branch-25.04' into snmg-batching
jinsolp Mar 25, 2025
cc555d8
organize headers
jinsolp Mar 25, 2025
302334b
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp Mar 25, 2025
79ca58a
Merge branch 'branch-25.04' into snmg-batching
jinsolp Mar 25, 2025
7d516c0
makefile stylecheck
jinsolp Mar 25, 2025
5eb7856
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp Mar 25, 2025
ec52b65
cleanup metric and add build
jinsolp Mar 31, 2025
a699b06
Merge branch 'branch-25.04' into snmg-batching
jinsolp Apr 1, 2025
0bd0a0d
adding doc for build api
jinsolp Apr 1, 2025
e348808
cleanup headers
jinsolp Apr 1, 2025
2993c54
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp Apr 1, 2025
017d217
use raft kvp
jinsolp Apr 1, 2025
7f5adcd
Merge branch 'branch-25.04' into snmg-batching
jinsolp Apr 1, 2025
2bbb436
cleanup headers
jinsolp Apr 2, 2025
0c5753b
change namespace and files to all_neighbors
jinsolp Apr 2, 2025
3052994
WIP generalizing all-neighbors
jinsolp Apr 2, 2025
c72bfa8
split single build and batch build
jinsolp Apr 2, 2025
fd7b9fc
gtest nccl clique issues
jinsolp Apr 2, 2025
4bdf741
WIP test cleanup
jinsolp Apr 2, 2025
94b9210
splitting tests
jinsolp Apr 3, 2025
c2892ae
cleanup
jinsolp Apr 3, 2025
1416ec4
cleanup and add raft_expects
jinsolp Apr 3, 2025
3831691
Merge branch 'branch-25.04' into snmg-batching
jinsolp Apr 3, 2025
481445d
using nn_descent_gnnd.hpp
jinsolp Apr 3, 2025
f9f9fb3
default n_clusters=1
jinsolp Apr 3, 2025
4d8df17
build config output graph degree
jinsolp Apr 9, 2025
d76c220
Merge branch 'branch-25.04' into snmg-batching
jinsolp Apr 9, 2025
ce6c1a2
Merge branch 'rapidsai:branch-25.06' into snmg-batching
jinsolp Apr 14, 2025
870153d
Merge branch 'rapidsai:branch-25.06' into snmg-batching
jinsolp Apr 16, 2025
c8766ec
WIP cleanup
jinsolp Apr 16, 2025
4cbfd1e
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp Apr 16, 2025
70fd440
adding support for device data when n_cluster=1
jinsolp Apr 16, 2025
318f79d
adding device suppor test
jinsolp Apr 17, 2025
8314cdc
test
jinsolp Apr 17, 2025
2b5236a
handle metrics
jinsolp Apr 17, 2025
e5a1167
metric warning
jinsolp Apr 17, 2025
53a9ab0
proper memory dealloc in builder and remove nccl
jinsolp Apr 18, 2025
3eaae76
fix test name
jinsolp Apr 18, 2025
74898f1
print statement
jinsolp Apr 18, 2025
a713ef7
multigpu assign clusters
jinsolp Apr 22, 2025
1504e81
cleanup
jinsolp Apr 22, 2025
d0979f1
remove prints
jinsolp Apr 22, 2025
1257816
remove print
jinsolp Apr 23, 2025
92b7241
Merge branch 'branch-25.06' into snmg-batching
jinsolp Apr 23, 2025
480779e
Merge branch 'branch-25.06' into snmg-batching
jinsolp Apr 23, 2025
96cbc7c
diff max clusters size per worker
jinsolp Apr 24, 2025
e15b84b
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp Apr 24, 2025
bd50a5a
Merge branch 'branch-25.06' into snmg-batching
jinsolp Apr 24, 2025
132284c
Merge branch 'branch-25.06' into snmg-batching
jinsolp Apr 25, 2025
792d110
doc and innerproduct
jinsolp May 3, 2025
cda1e4d
WIP removing index and cleanup
jinsolp May 6, 2025
b84a4a6
cleanup
jinsolp May 6, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ if(BUILD_SHARED_LIBS)
src/distance/pairwise_distance.cu
src/distance/sparse_distance.cu
src/embed/spectral.cu
src/neighbors/all_neighbors/all_neighbors.cu
src/neighbors/brute_force.cu
src/neighbors/brute_force_serialize.cu
src/neighbors/cagra_build_float.cu
Expand Down
149 changes: 149 additions & 0 deletions cpp/include/cuvs/neighbors/all_neighbors.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once

#include <cuvs/neighbors/ivf_pq.hpp>
#include <cuvs/neighbors/nn_descent.hpp>

#include <variant>

namespace cuvs::neighbors::all_neighbors {

/**
* @brief Per-algorithm parameter types that can be used to build an all-neighbors knn graph.
*
* Exactly one of these types is stored in all_neighbors_params::graph_build_params
* (through the GraphBuildParams variant) to select the graph building algorithm.
*/
namespace graph_build_params {

/**
* Specialized parameters utilizing IVF-PQ to build knn graph.
*
* build_params / search_params are the regular IVF-PQ index build and search parameters.
* refinement_rate defaults to 2.0.
* NOTE(review): refinement_rate presumably controls how many extra candidates are fetched
* from IVF-PQ before refinement -- confirm against the IVF-PQ knn builder.
*/
struct ivf_pq_params {
cuvs::neighbors::ivf_pq::index_params build_params;
cuvs::neighbors::ivf_pq::search_params search_params;
float refinement_rate = 2.0;
};

/** Parameters for building the knn graph with NN Descent (alias of nn_descent::index_params). */
using nn_descent_params = cuvs::neighbors::nn_descent::index_params;
} // namespace graph_build_params

/**
* Tagged union selecting the graph building algorithm and carrying its parameters.
* ivf_pq_params is the first alternative, so a default-constructed GraphBuildParams
* holds (default) IVF-PQ parameters.
*/
using GraphBuildParams =
std::variant<graph_build_params::ivf_pq_params, graph_build_params::nn_descent_params>;

/**
* @brief Parameters used to build an all-neighbors graph
*
* graph_build_params: graph building parameters for the given graph building algorithm. defaults
* to ivfpq.
* n_nearest_clusters: number of nearest clusters each data point will be assigned to in
* the batching algorithm
* n_clusters: number of total clusters (aka batches) to split the data into. If set to 1, algorithm
* creates an all-neighbors graph without batching
* metric: metric type
*
*/
struct all_neighbors_params {
/** Parameters for knn graph building algorithm
*
* Set ivf_pq_params, or nn_descent_params to select the graph build
* algorithm and control their parameters.
*
* @code{.cpp}
* all_neighbors::all_neighbors_params params;
* // 1. Choose IVF-PQ algorithm
* params.graph_build_params = all_neighbors::graph_build_params::ivf_pq_params{};
*
* // 2. Choose NN Descent algorithm for kNN graph construction
* params.graph_build_params = all_neighbors::graph_build_params::nn_descent_params{};
*
* @endcode
*/
GraphBuildParams graph_build_params;

/**
* Usage of n_nearest_clusters and n_clusters
*
* Hint1: the ratio of n_nearest_clusters / n_clusters determines device memory usage.
* Approximately (n_nearest_clusters / n_clusters) * num_rows_in_entire_data number of rows will
* be put on device memory at once.
* E.g. between (n_nearest_clusters / n_clusters) = 2/10 and 2/20, the latter will use less device
* memory.
*
* Hint2: larger n_nearest_clusters results in better accuracy of the final all-neighbors knn
* graph. E.g. With the similar device memory usages, (n_nearest_clusters / n_clusters) = 4/20
* will have better accuracy than 2/10 at the cost of performance.
*
* Note: the device-dataset overload of build() rejects n_clusters > 1; batching requires the
* dataset to be in host memory.
*/
size_t n_nearest_clusters = 2;
size_t n_clusters = 1; // defaults to not batching
/**
* Distance metric. The accepted values depend on the selected algorithm and are validated
* when build() is called: NN Descent accepts L2Expanded, L2SqrtExpanded, CosineExpanded and
* InnerProduct; IVF-PQ accepts L2Expanded only.
*/
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded;
};

/**
* @brief Builds an approximate all-neighbors knn graph from a dataset in host memory.
*
* With params.n_clusters == 1 the graph is built in a single pass; any other value uses the
* batched build.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // use default index parameters
* all_neighbors::all_neighbors_params params;
* auto indices = raft::make_device_matrix<int64_t, int64_t>(handle, n_row, k);
* auto distances = raft::make_device_matrix<float, int64_t>(handle, n_row, k);
* all_neighbors::build(handle, params, dataset, indices.view(), distances.view());
* @endcode
*
* @param[in] handle raft::resources is an object managing resources
* @param[in] params an instance of all_neighbors::all_neighbors_params that are parameters
* to build all-neighbors knn graph
* @param[in] dataset raft::host_matrix_view input dataset expected to be located
* in host memory
* @param[out] indices nearest neighbor indices of shape [n_row x k]
* @param[out] distances optional nearest neighbor distances [n_row x k]; when provided it must
* have the same shape as indices. Defaults to std::nullopt.
*/
void build(
const raft::resources& handle,
const all_neighbors_params& params,
raft::host_matrix_view<const float, int64_t, row_major> dataset,
raft::device_matrix_view<int64_t, int64_t, row_major> indices,
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances = std::nullopt);

/**
* @brief Builds an approximate all-neighbors knn graph from a dataset in device memory.
* params.n_clusters should be 1 for data on device (a value larger than 1 is rejected with an
* error). To use a larger params.n_clusters for efficient device memory usage, put data on
* host RAM.
*
* Usage example:
* @code{.cpp}
* using namespace cuvs::neighbors;
* // use default index parameters
* all_neighbors::all_neighbors_params params;
* auto indices = raft::make_device_matrix<int64_t, int64_t>(handle, n_row, k);
* auto distances = raft::make_device_matrix<float, int64_t>(handle, n_row, k);
* all_neighbors::build(handle, params, dataset, indices.view(), distances.view());
* @endcode
*
* @param[in] handle raft::resources is an object managing resources
* @param[in] params an instance of all_neighbors::all_neighbors_params that are parameters
* to build all-neighbors knn graph
* @param[in] dataset raft::device_matrix_view input dataset expected to be located
* in device memory
* @param[out] indices nearest neighbor indices of shape [n_row x k]
* @param[out] distances optional nearest neighbor distances [n_row x k]; when provided it must
* have the same shape as indices. Defaults to std::nullopt.
*/
void build(
const raft::resources& handle,
const all_neighbors_params& params,
raft::device_matrix_view<const float, int64_t, row_major> dataset,
raft::device_matrix_view<int64_t, int64_t, row_major> indices,
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances = std::nullopt);
} // namespace cuvs::neighbors::all_neighbors
44 changes: 44 additions & 0 deletions cpp/src/neighbors/all_neighbors/all_neighbors.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#include "all_neighbors.cuh"

namespace cuvs::neighbors::all_neighbors {

// Emits the two public, non-template build() entry points (host dataset and device dataset)
// for one (T, IdxT) combination. Each one simply forwards to the templated implementation in
// the detail namespace; the dataset view type selects the matching detail overload.
#define CUVS_INST_ALL_NEIGHBORS(T, IdxT)                                              \
  /* dataset in host memory */                                                        \
  void build(const raft::resources& handle,                                           \
             const all_neighbors_params& params,                                      \
             raft::host_matrix_view<const T, IdxT, row_major> dataset,                \
             raft::device_matrix_view<IdxT, IdxT, row_major> indices,                 \
             std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances)   \
  {                                                                                   \
    detail::build<T, IdxT>(handle, params, dataset, indices, distances);              \
  }                                                                                   \
                                                                                      \
  /* dataset in device memory */                                                      \
  void build(const raft::resources& handle,                                           \
             const all_neighbors_params& params,                                      \
             raft::device_matrix_view<const T, IdxT, row_major> dataset,              \
             raft::device_matrix_view<IdxT, IdxT, row_major> indices,                 \
             std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances)   \
  {                                                                                   \
    detail::build<T, IdxT>(handle, params, dataset, indices, distances);              \
  }

// Only float data with int64_t indices is exposed through the public header.
CUVS_INST_ALL_NEIGHBORS(float, int64_t);

#undef CUVS_INST_ALL_NEIGHBORS

} // namespace cuvs::neighbors::all_neighbors
120 changes: 120 additions & 0 deletions cpp/src/neighbors/all_neighbors/all_neighbors.cuh
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
/*
* Copyright (c) 2025, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

#pragma once
#include "all_neighbors_batched.cuh"
#include <cuvs/neighbors/all_neighbors.hpp>
#include <raft/util/cudart_utils.hpp>

namespace cuvs::neighbors::all_neighbors::detail {
using namespace cuvs::neighbors;

/**
 * @brief Validate that params.metric is supported by the selected graph build algorithm.
 *
 * - NN Descent: L2Expanded, L2SqrtExpanded, CosineExpanded or InnerProduct.
 * - IVF-PQ: L2Expanded only.
 *
 * Declared `inline` because this is a non-template function defined in a header; without it,
 * including this header from more than one translation unit produces multiple definitions.
 *
 * @param[in] params all-neighbors build parameters (algorithm selection and metric)
 *
 * Fails through RAFT_EXPECTS / RAFT_FAIL when the metric is not supported for the chosen
 * algorithm or when graph_build_params holds neither known alternative.
 */
inline void check_metric(const all_neighbors_params& params)
{
  if (std::holds_alternative<graph_build_params::nn_descent_params>(params.graph_build_params)) {
    // NOTE(review): InnerProduct was added following review feedback ("Don't we allow inner
    // product as well?"); per that discussion it depends on NN Descent inner-product support
    // from a separate PR -- confirm that dependency has landed.
    const bool allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded ||
                                 params.metric == cuvs::distance::DistanceType::L2SqrtExpanded ||
                                 params.metric == cuvs::distance::DistanceType::CosineExpanded ||
                                 params.metric == cuvs::distance::DistanceType::InnerProduct;
    RAFT_EXPECTS(allowed_metrics,
                 "Distance metric for all-neighbors build with NN Descent should be L2Expanded, "
                 "L2SqrtExpanded, CosineExpanded, or InnerProduct");
  } else if (std::holds_alternative<graph_build_params::ivf_pq_params>(params.graph_build_params)) {
    RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded,
                 "Distance metric for all-neighbors build with IVFPQ should be L2Expanded");
  } else {
    // Only reachable if the variant holds neither known alternative.
    RAFT_FAIL("Invalid all-neighbors build algo");
  }
}

/**
 * @brief Build the all-neighbors knn graph in a single pass (i.e. no batching).
 *
 * Supports both host and device datasets: the dataset accessor is a template parameter.
 *
 * @param[in]  handle    raft resources
 * @param[in]  params    build parameters (n_clusters, graph_build_params and metric are used)
 * @param[in]  dataset   input dataset, row-major
 * @param[out] indices   neighbor indices; its second extent is used as the graph degree k
 * @param[out] distances optional neighbor distances, forwarded to the knn builder
 */
template <typename T, typename IdxT, typename Accessor>
void single_build(
  const raft::resources& handle,
  const all_neighbors_params& params,
  mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset,
  raft::device_matrix_view<IdxT, IdxT, row_major> indices,
  std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt)
{
  const size_t num_rows = static_cast<size_t>(dataset.extent(0));

  // The whole dataset is handled as one unit, so num_rows is passed for both size arguments.
  // NOTE(review): presumably these are the min/max cluster sizes expected by get_knn_builder --
  // confirm against its signature.
  auto knn_builder = get_knn_builder<T, IdxT>(handle,
                                              params.n_clusters,
                                              num_rows,
                                              num_rows,
                                              indices.extent(1),
                                              params.graph_build_params,
                                              params.metric,
                                              indices,
                                              distances);

  knn_builder->prepare_build(dataset);
  knn_builder->build_knn(dataset);
}

/**
 * @brief Build an all-neighbors knn graph from a dataset in host memory.
 *
 * Validates the metric and the output shapes, then dispatches to the single-pass build when
 * params.n_clusters == 1 and to the batched build otherwise.
 *
 * @param[in]  handle    raft resources
 * @param[in]  params    all-neighbors build parameters
 * @param[in]  dataset   input dataset in host memory
 * @param[out] indices   neighbor indices; must have as many rows as dataset
 * @param[out] distances optional neighbor distances; must match the shape of indices
 */
template <typename T, typename IdxT>
void build(const raft::resources& handle,
           const all_neighbors_params& params,
           raft::host_matrix_view<const T, IdxT, row_major> dataset,
           raft::device_matrix_view<IdxT, IdxT, row_major> indices,
           std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt)
{
  check_metric(params);

  const bool row_counts_agree = dataset.extent(0) == indices.extent(0);
  RAFT_EXPECTS(row_counts_agree,
               "number of rows in dataset should be the same as number of rows in indices matrix");

  if (distances) {
    const bool shapes_agree = indices.extent(0) == distances->extent(0) &&
                              indices.extent(1) == distances->extent(1);
    RAFT_EXPECTS(shapes_agree, "indices matrix and distances matrix has to be the same shape.");
  }

  // Anything other than exactly one cluster goes through the batched path.
  if (params.n_clusters != 1) {
    batch_build(handle, params, dataset, indices, distances);
    return;
  }
  single_build(handle, params, dataset, indices, distances);
}

/**
 * @brief Build an all-neighbors knn graph from a dataset in device memory.
 *
 * Validates the metric and the output shapes. Batching is not available for device data, so
 * params.n_clusters > 1 is rejected; otherwise the single-pass build is run.
 *
 * @param[in]  handle    raft resources
 * @param[in]  params    all-neighbors build parameters
 * @param[in]  dataset   input dataset in device memory
 * @param[out] indices   neighbor indices; must have as many rows as dataset
 * @param[out] distances optional neighbor distances; must match the shape of indices
 */
template <typename T, typename IdxT>
void build(const raft::resources& handle,
           const all_neighbors_params& params,
           raft::device_matrix_view<const T, IdxT, row_major> dataset,
           raft::device_matrix_view<IdxT, IdxT, row_major> indices,
           std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt)
{
  check_metric(params);

  const bool row_counts_agree = dataset.extent(0) == indices.extent(0);
  RAFT_EXPECTS(row_counts_agree,
               "number of rows in dataset should be the same as number of rows in indices matrix");

  if (distances) {
    const bool shapes_agree = indices.extent(0) == distances->extent(0) &&
                              indices.extent(1) == distances->extent(1);
    RAFT_EXPECTS(shapes_agree, "indices matrix and distances matrix has to be the same shape.");
  }

  // Guard: the batched algorithm needs the dataset on host.
  if (params.n_clusters > 1) {
    RAFT_FAIL(
      "Batched all-neighbors build is not supported with data on device. Put data on host for "
      "batch build.");
  }
  single_build(handle, params, dataset, indices, distances);
}
} // namespace cuvs::neighbors::all_neighbors::detail
Loading
Loading