Wrapper for all-neighbors knn graph building #785
Open
jinsolp
wants to merge
53
commits into
rapidsai:branch-25.06
from
jinsolp:snmg-batching
Commits
0a30e16
MG ann batch build
jinsolp b1ef22c
Merge branch 'rapidsai:branch-25.04' into snmg-batching
jinsolp cc555d8
organize headers
jinsolp 302334b
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp 79ca58a
Merge branch 'branch-25.04' into snmg-batching
jinsolp 7d516c0
makefile stylecheck
jinsolp 5eb7856
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp ec52b65
cleanup metric and add build
jinsolp a699b06
Merge branch 'branch-25.04' into snmg-batching
jinsolp 0bd0a0d
adding doc for build api
jinsolp e348808
cleanup headers
jinsolp 2993c54
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp 017d217
use raft kvp
jinsolp 7f5adcd
Merge branch 'branch-25.04' into snmg-batching
jinsolp 2bbb436
cleanup headers
jinsolp 0c5753b
change namespace and files to all_neighbors
jinsolp 3052994
WIP generalizing all-neighbors
jinsolp c72bfa8
split single build and batch build
jinsolp fd7b9fc
gtest nccl clique issues
jinsolp 4bdf741
WIP test cleanup
jinsolp 94b9210
splitting tests
jinsolp c2892ae
cleanup
jinsolp 1416ec4
cleanup and add raft_expects
jinsolp 3831691
Merge branch 'branch-25.04' into snmg-batching
jinsolp 481445d
using nn_descent_gnnd.hpp
jinsolp f9f9fb3
default n_clusters=1
jinsolp 4d8df17
build config output graph degree
jinsolp d76c220
Merge branch 'branch-25.04' into snmg-batching
jinsolp ce6c1a2
Merge branch 'rapidsai:branch-25.06' into snmg-batching
jinsolp 870153d
Merge branch 'rapidsai:branch-25.06' into snmg-batching
jinsolp c8766ec
WIP cleanup
jinsolp 4cbfd1e
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp 70fd440
adding support for device data when n_cluster=1
jinsolp 318f79d
adding device suppor test
jinsolp 8314cdc
test
jinsolp 2b5236a
handle metrics
jinsolp e5a1167
metric warning
jinsolp 53a9ab0
proper memory dealloc in builder and remove nccl
jinsolp 3eaae76
fix test name
jinsolp 74898f1
print statement
jinsolp a713ef7
multigpu assign clusters
jinsolp 1504e81
cleanup
jinsolp d0979f1
remove prints
jinsolp 1257816
remove print
jinsolp 92b7241
Merge branch 'branch-25.06' into snmg-batching
jinsolp 480779e
Merge branch 'branch-25.06' into snmg-batching
jinsolp 96cbc7c
diff max clusters size per worker
jinsolp e15b84b
Merge branch 'snmg-batching' of https://github.com/jinsolp/cuvs into …
jinsolp bd50a5a
Merge branch 'branch-25.06' into snmg-batching
jinsolp 132284c
Merge branch 'branch-25.06' into snmg-batching
jinsolp 792d110
doc and innerproduct
jinsolp cda1e4d
WIP removing index and cleanup
jinsolp b84a4a6
cleanup
jinsolp
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,149 @@ | ||
/* | ||
* Copyright (c) 2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
|
||
#include <cuvs/neighbors/ivf_pq.hpp> | ||
#include <cuvs/neighbors/nn_descent.hpp> | ||
|
||
#include <optional> | ||
#include <variant> | ||
|
||
namespace cuvs::neighbors::all_neighbors { | ||
|
||
/** | ||
* @brief Parameters used to build an all-neighbors knn graph | ||
*/ | ||
namespace graph_build_params { | ||
|
||
/** Specialized parameters utilizing IVF-PQ to build knn graph */ | ||
struct ivf_pq_params { | ||
cuvs::neighbors::ivf_pq::index_params build_params; | ||
cuvs::neighbors::ivf_pq::search_params search_params; | ||
float refinement_rate = 2.0; | ||
}; | ||
|
||
using nn_descent_params = cuvs::neighbors::nn_descent::index_params; | ||
} // namespace graph_build_params | ||
|
||
using GraphBuildParams = | ||
std::variant<graph_build_params::ivf_pq_params, graph_build_params::nn_descent_params>; | ||
|
||
/** | ||
* @brief Parameters used to build an all-neighbors graph | ||
* | ||
* graph_build_params: parameters for the selected graph building algorithm. Defaults | ||
* to IVF-PQ. | ||
* n_nearest_clusters: number of nearest clusters each data point will be assigned to in | ||
* the batching algorithm | ||
* n_clusters: number of total clusters (aka batches) to split the data into. If set to 1, algorithm | ||
* creates an all-neighbors graph without batching | ||
* metric: metric type | ||
* | ||
*/ | ||
struct all_neighbors_params { | ||
/** Parameters for knn graph building algorithm | ||
* | ||
* Set ivf_pq_params or nn_descent_params to select the graph build | ||
* algorithm and control its parameters. | ||
* | ||
* @code{.cpp} | ||
* all_neighbors::all_neighbors_params params; | ||
* // 1. Choose IVF-PQ algorithm | ||
* params.graph_build_params = all_neighbors::graph_build_params::ivf_pq_params{}; | ||
* | ||
* // 2. Choose NN Descent algorithm for kNN graph construction | ||
* params.graph_build_params = all_neighbors::graph_build_params::nn_descent_params{}; | ||
* | ||
* @endcode | ||
*/ | ||
GraphBuildParams graph_build_params; | ||
|
||
/** | ||
* Usage of n_nearest_clusters and n_clusters | ||
* | ||
* Hint1: the ratio of n_nearest_clusters / n_clusters determines device memory usage. | ||
* Approximately (n_nearest_clusters / n_clusters) * num_rows_in_entire_data rows will | ||
* be placed in device memory at once. | ||
* E.g. between (n_nearest_clusters / n_clusters) = 2/10 and 2/20, the latter will use less device | ||
* memory. | ||
* | ||
* Hint2: larger n_nearest_clusters results in better accuracy of the final all-neighbors knn | ||
* graph. E.g. with similar device memory usage, (n_nearest_clusters / n_clusters) = 4/20 | ||
* will have better accuracy than 2/10 at the cost of performance. | ||
*/ | ||
size_t n_nearest_clusters = 2; | ||
size_t n_clusters = 1; // defaults to not batching | ||
cuvs::distance::DistanceType metric = cuvs::distance::DistanceType::L2Expanded; | ||
}; | ||
|
||
/** | ||
* @brief Builds an approximate all-neighbors knn graph. | ||
* | ||
* Usage example: | ||
* @code{.cpp} | ||
* using namespace cuvs::neighbors; | ||
* // use default parameters | ||
* all_neighbors::all_neighbors_params params; | ||
* auto indices = raft::make_device_matrix<int64_t, int64_t>(handle, n_row, k); | ||
* auto distances = raft::make_device_matrix<float, int64_t>(handle, n_row, k); | ||
* all_neighbors::build(handle, params, dataset, indices.view(), distances.view()); | ||
* @endcode | ||
* | ||
* @param[in] handle raft::resources is an object managing resources | ||
* @param[in] params an instance of all_neighbors::all_neighbors_params holding the parameters | ||
* used to build the all-neighbors knn graph | ||
* @param[in] dataset raft::host_matrix_view input dataset expected to be located | ||
* in host memory | ||
* @param[out] indices nearest neighbor indices of shape [n_row x k] | ||
* @param[out] distances optional nearest neighbor distances of shape [n_row x k] | ||
*/ | ||
void build( | ||
const raft::resources& handle, | ||
const all_neighbors_params& params, | ||
raft::host_matrix_view<const float, int64_t, row_major> dataset, | ||
raft::device_matrix_view<int64_t, int64_t, row_major> indices, | ||
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances = std::nullopt); | ||
|
||
/** | ||
* @brief Builds an approximate all-neighbors knn graph. | ||
* params.n_clusters must be 1 for data on device (a larger value raises an error). To use a | ||
* larger params.n_clusters for efficient device memory usage, put data on host RAM. | ||
* | ||
* Usage example: | ||
* @code{.cpp} | ||
* using namespace cuvs::neighbors; | ||
* // use default parameters | ||
* all_neighbors::all_neighbors_params params; | ||
* auto indices = raft::make_device_matrix<int64_t, int64_t>(handle, n_row, k); | ||
* auto distances = raft::make_device_matrix<float, int64_t>(handle, n_row, k); | ||
* all_neighbors::build(handle, params, dataset, indices.view(), distances.view()); | ||
* @endcode | ||
* | ||
* @param[in] handle raft::resources is an object managing resources | ||
* @param[in] params an instance of all_neighbors::all_neighbors_params holding the parameters | ||
* used to build the all-neighbors knn graph | ||
* @param[in] dataset raft::device_matrix_view input dataset expected to be located | ||
* in device memory | ||
* @param[out] indices nearest neighbor indices of shape [n_row x k] | ||
* @param[out] distances optional nearest neighbor distances of shape [n_row x k] | ||
*/ | ||
void build( | ||
const raft::resources& handle, | ||
const all_neighbors_params& params, | ||
raft::device_matrix_view<const float, int64_t, row_major> dataset, | ||
raft::device_matrix_view<int64_t, int64_t, row_major> indices, | ||
std::optional<raft::device_matrix_view<float, int64_t, row_major>> distances = std::nullopt); | ||
} // namespace cuvs::neighbors::all_neighbors |
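The Hint1/Hint2 comments in the header above describe how the ratio n_nearest_clusters / n_clusters drives device memory usage. A minimal sketch of that arithmetic follows. The helper `approx_rows_on_device` is hypothetical and not part of cuVS; it only evaluates the approximation stated in the doc comment, and treating n_clusters == 1 as "the whole dataset at once" is an assumption made for illustration.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helper (not a cuVS API): rough number of dataset rows resident
// on the device at once, following the approximation in the doc comment:
//   (n_nearest_clusters / n_clusters) * num_rows
inline std::size_t approx_rows_on_device(std::size_t num_rows,
                                         std::size_t n_nearest_clusters,
                                         std::size_t n_clusters)
{
  assert(n_clusters > 0);
  // Assumption: n_clusters == 1 means no batching, so count every row.
  if (n_clusters == 1) { return num_rows; }
  // Multiply before dividing so integer division does not truncate the ratio to 0.
  return num_rows * n_nearest_clusters / n_clusters;
}
```

Under this approximation, a dataset of 1,000,000 rows gives 200,000 rows for 2/10 and 100,000 rows for 2/20, which matches Hint1. The ratios 2/10 and 4/20 give the same estimate, which is why Hint2 frames 4/20 as a trade of performance for accuracy at similar memory usage.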
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,44 @@ | ||
/* | ||
* Copyright (c) 2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#include "all_neighbors.cuh" | ||
|
||
namespace cuvs::neighbors::all_neighbors { | ||
|
||
#define CUVS_INST_ALL_NEIGHBORS(T, IdxT) \ | ||
void build(const raft::resources& handle, \ | ||
const all_neighbors_params& params, \ | ||
raft::host_matrix_view<const T, IdxT, row_major> dataset, \ | ||
raft::device_matrix_view<IdxT, IdxT, row_major> indices, \ | ||
std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances) \ | ||
{ \ | ||
return all_neighbors::detail::build<T, IdxT>(handle, params, dataset, indices, distances); \ | ||
} \ | ||
\ | ||
void build(const raft::resources& handle, \ | ||
const all_neighbors_params& params, \ | ||
raft::device_matrix_view<const T, IdxT, row_major> dataset, \ | ||
raft::device_matrix_view<IdxT, IdxT, row_major> indices, \ | ||
std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances) \ | ||
{ \ | ||
return all_neighbors::detail::build<T, IdxT>(handle, params, dataset, indices, distances); \ | ||
} | ||
|
||
CUVS_INST_ALL_NEIGHBORS(float, int64_t); | ||
|
||
#undef CUVS_INST_ALL_NEIGHBORS | ||
|
||
} // namespace cuvs::neighbors::all_neighbors |
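The source file above uses a macro to stamp out non-template `build` overloads that forward to a templated `detail::build`. The same pattern can be sketched in isolation as below. Everything in the `inst_sketch` namespace is hypothetical and chosen for illustration (a row-counting function stands in for the real build); only the shape of the macro mirrors the diff.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace inst_sketch {
namespace detail {
// Stand-in for a templated implementation that would live in a .cuh file.
template <typename T, typename IdxT>
IdxT count_rows(const std::vector<T>& data, IdxT n_cols)
{
  assert(n_cols > 0);
  return static_cast<IdxT>(data.size()) / n_cols;
}
}  // namespace detail

// Declaration as it would appear in the public header: concrete types only.
std::int64_t count_rows(const std::vector<float>& data, std::int64_t n_cols);

// The macro defines one concrete overload per (T, IdxT) pair and forwards to
// the template, so the template is instantiated in this translation unit only.
#define SKETCH_INST_COUNT_ROWS(T, IdxT)                    \
  IdxT count_rows(const std::vector<T>& data, IdxT n_cols) \
  {                                                        \
    return detail::count_rows<T, IdxT>(data, n_cols);      \
  }

SKETCH_INST_COUNT_ROWS(float, std::int64_t)

#undef SKETCH_INST_COUNT_ROWS
}  // namespace inst_sketch
```

Callers then see only the plain overload, for example `inst_sketch::count_rows(data, std::int64_t{3})`. Supporting another element type would be one more macro invocation plus a matching declaration in the header.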
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
/* | ||
* Copyright (c) 2025, NVIDIA CORPORATION. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); | ||
* you may not use this file except in compliance with the License. | ||
* You may obtain a copy of the License at | ||
* | ||
* http://www.apache.org/licenses/LICENSE-2.0 | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, | ||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
* See the License for the specific language governing permissions and | ||
* limitations under the License. | ||
*/ | ||
|
||
#pragma once | ||
#include "all_neighbors_batched.cuh" | ||
#include <cuvs/neighbors/all_neighbors.hpp> | ||
#include <raft/util/cudart_utils.hpp> | ||
|
||
namespace cuvs::neighbors::all_neighbors::detail { | ||
using namespace cuvs::neighbors; | ||
|
||
inline void check_metric(const all_neighbors_params& params) | ||
{ | ||
if (std::holds_alternative<graph_build_params::nn_descent_params>(params.graph_build_params)) { | ||
auto allowed_metrics = params.metric == cuvs::distance::DistanceType::L2Expanded || | ||
params.metric == cuvs::distance::DistanceType::L2SqrtExpanded || | ||
params.metric == cuvs::distance::DistanceType::CosineExpanded || | ||
params.metric == cuvs::distance::DistanceType::InnerProduct; | ||
RAFT_EXPECTS(allowed_metrics, | ||
"Distance metric for all-neighbors build with NN Descent should be L2Expanded, " | ||
"L2SqrtExpanded, CosineExpanded, or InnerProduct"); | ||
} else if (std::holds_alternative<graph_build_params::ivf_pq_params>(params.graph_build_params)) { | ||
RAFT_EXPECTS(params.metric == cuvs::distance::DistanceType::L2Expanded, | ||
"Distance metric for all-neighbors build with IVFPQ should be L2Expanded"); | ||
} else { | ||
RAFT_FAIL("Invalid all-neighbors build algo"); | ||
} | ||
} | ||
|
||
// Single build (i.e. no batching) supports both host and device datasets | ||
template <typename T, typename IdxT, typename Accessor> | ||
void single_build( | ||
const raft::resources& handle, | ||
const all_neighbors_params& params, | ||
mdspan<const T, matrix_extent<IdxT>, row_major, Accessor> dataset, | ||
raft::device_matrix_view<IdxT, IdxT, row_major> indices, | ||
std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt) | ||
{ | ||
size_t num_rows = static_cast<size_t>(dataset.extent(0)); | ||
size_t num_cols = static_cast<size_t>(dataset.extent(1)); | ||
|
||
auto knn_builder = get_knn_builder<T, IdxT>(handle, | ||
params.n_clusters, | ||
num_rows, | ||
num_rows, | ||
indices.extent(1), | ||
params.graph_build_params, | ||
params.metric, | ||
indices, | ||
distances); | ||
|
||
knn_builder->prepare_build(dataset); | ||
knn_builder->build_knn(dataset); | ||
} | ||
|
||
template <typename T, typename IdxT> | ||
void build(const raft::resources& handle, | ||
const all_neighbors_params& params, | ||
raft::host_matrix_view<const T, IdxT, row_major> dataset, | ||
raft::device_matrix_view<IdxT, IdxT, row_major> indices, | ||
std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt) | ||
{ | ||
check_metric(params); | ||
|
||
RAFT_EXPECTS(dataset.extent(0) == indices.extent(0), | ||
"number of rows in dataset should be the same as number of rows in indices matrix"); | ||
|
||
if (distances.has_value()) { | ||
RAFT_EXPECTS(indices.extent(0) == distances.value().extent(0) && | ||
indices.extent(1) == distances.value().extent(1), | ||
"indices matrix and distances matrix must have the same shape."); | ||
} | ||
|
||
if (params.n_clusters == 1) { | ||
single_build(handle, params, dataset, indices, distances); | ||
} else { | ||
batch_build(handle, params, dataset, indices, distances); | ||
} | ||
} | ||
|
||
template <typename T, typename IdxT> | ||
void build(const raft::resources& handle, | ||
const all_neighbors_params& params, | ||
raft::device_matrix_view<const T, IdxT, row_major> dataset, | ||
raft::device_matrix_view<IdxT, IdxT, row_major> indices, | ||
std::optional<raft::device_matrix_view<T, IdxT, row_major>> distances = std::nullopt) | ||
{ | ||
check_metric(params); | ||
|
||
RAFT_EXPECTS(dataset.extent(0) == indices.extent(0), | ||
"number of rows in dataset should be the same as number of rows in indices matrix"); | ||
|
||
if (distances.has_value()) { | ||
RAFT_EXPECTS(indices.extent(0) == distances.value().extent(0) && | ||
indices.extent(1) == distances.value().extent(1), | ||
"indices matrix and distances matrix must have the same shape."); | ||
} | ||
|
||
if (params.n_clusters > 1) { | ||
RAFT_FAIL( | ||
"Batched all-neighbors build is not supported with data on device. Put data on host for " | ||
"batch build."); | ||
} else { | ||
single_build(handle, params, dataset, indices, distances); | ||
} | ||
} | ||
} // namespace cuvs::neighbors::all_neighbors::detail |
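The validation and routing logic in the file above can be summarized in a small standalone sketch: the metric check depends on which alternative the `std::variant` holds, and the build path depends on `n_clusters` and on where the data lives. All types and functions in `dispatch_sketch` are hypothetical stand-ins (the real code uses RAFT_EXPECTS / RAFT_FAIL and cuVS types); exceptions are used here only to make the sketch self-contained.

```cpp
#include <cassert>
#include <cstddef>
#include <stdexcept>
#include <string>
#include <variant>

namespace dispatch_sketch {

// Hypothetical stand-ins for the real metric and parameter types.
enum class metric_t { L2Expanded, L2SqrtExpanded, CosineExpanded, InnerProduct, L1 };
struct ivf_pq_params {};
struct nn_descent_params {};
using graph_build_params_t = std::variant<ivf_pq_params, nn_descent_params>;

struct params_t {
  graph_build_params_t graph_build_params;  // default-constructs to ivf_pq_params
  std::size_t n_clusters = 1;               // 1 means no batching
  metric_t metric        = metric_t::L2Expanded;
};

enum class memory_t { host, device };

// Mirrors the metric rules of check_metric() as they appear in the diff.
inline bool metric_is_supported(const params_t& p)
{
  if (std::holds_alternative<nn_descent_params>(p.graph_build_params)) {
    return p.metric == metric_t::L2Expanded || p.metric == metric_t::L2SqrtExpanded ||
           p.metric == metric_t::CosineExpanded || p.metric == metric_t::InnerProduct;
  }
  return p.metric == metric_t::L2Expanded;  // IVF-PQ branch
}

// Returns "single" or "batch"; throws where the diff would fail.
inline std::string choose_build_path(const params_t& p, memory_t where)
{
  if (!metric_is_supported(p)) { throw std::invalid_argument("unsupported metric"); }
  if (where == memory_t::device) {
    if (p.n_clusters > 1) { throw std::invalid_argument("batched build needs host data"); }
    return "single";
  }
  return p.n_clusters == 1 ? "single" : "batch";
}

}  // namespace dispatch_sketch
```

With default parameters both host and device inputs take the single-build path. Setting `n_clusters` above 1 selects the batched path for host data and is rejected for device data, which is the behavior the two `build` overloads implement.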
Don't we allow inner product as well?
Now depending on this PR