[REVIEW] Add Fused L2 Expanded KNN kernel #339

Merged: 29 commits, Nov 23, 2021
Changes from 1 commit

Commits (29)
75c9f27
add fused L2 expanded kNN kernel, this is faster by at least 20-25% o…
mdoijade Sep 24, 2021
7a1e1e6
use lid > firsActiveLane instead of bitwise left shift and & for upda…
mdoijade Sep 27, 2021
e655cd4
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
mdoijade Sep 27, 2021
290d28d
fix incorrect output for NN >32 case when taking prod-cons knn merge …
mdoijade Sep 28, 2021
60d9201
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
mdoijade Sep 28, 2021
5f3cea1
fix clang format issues
mdoijade Sep 28, 2021
5b5f7a0
enable testing of cuml using this raft fork
mdoijade Sep 28, 2021
738c604
add custom atomicMax function which works fine if negative zeros are …
mdoijade Sep 29, 2021
15cbda8
merge branch-21.12 and test customAtomicMax without +0 addition
mdoijade Sep 29, 2021
352cc2d
fix hang in raft atomicMax of fp32 when the inputs are NaNs
mdoijade Sep 30, 2021
aa8ef09
remove redundant processing.hpp included in fused_l2_knn
mdoijade Oct 5, 2021
6072281
refactor fused L2 KNN main function to call both L2 expanded/unexpand…
mdoijade Oct 6, 2021
ae14f75
revert ball cover test to use brute_force_knn function instead of exp…
mdoijade Oct 6, 2021
53b6415
use isnan only if DeviceMax/Min operations in atomicCAS based functio…
mdoijade Oct 7, 2021
1d9ade3
fix clang format issues
mdoijade Oct 7, 2021
62cff7b
revert prtest.config changes, move fusedL2kNN launch/selection code t…
mdoijade Oct 11, 2021
9164a64
fix bug in updateSortedWarpQ for NN > 32, disable use of sqrt as it i…
mdoijade Oct 13, 2021
abc2b11
allocate workspace when resize is required for using prod-cons mutexes
mdoijade Oct 13, 2021
ec0cc32
add unit test for fused L2 KNN exp/unexp cases using faiss bfknn as g…
mdoijade Nov 2, 2021
700318d
merge branch-21.12 and update fused_l2_knn.cuh with those changes
mdoijade Nov 2, 2021
2b64775
move customAtomicMax to generic atomicMax specialization, and remove …
mdoijade Nov 2, 2021
ef9a898
fix clang format errors
mdoijade Nov 2, 2021
b317a12
call faiss before fusedL2knn kernel in the test
mdoijade Nov 3, 2021
9e2e19e
fix issues in verification function as it can happen that 2 vectors w…
mdoijade Nov 3, 2021
395beff
Merge branch 'branch-22.02' into fusedL2ExpandedKNN
mdoijade Nov 17, 2021
f0fd7b4
revert ball_cover test to use compute_bfknn which is wrapper for brut…
mdoijade Nov 17, 2021
bb099ca
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
cjnolet Nov 17, 2021
a2f1dee
Merge branch 'branch-22.02' into fusedL2ExpandedKNN
cjnolet Nov 23, 2021
bdce263
Adjusting rng.cuh
cjnolet Nov 23, 2021
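Several of the commits above (738c604, 15cbda8, 352cc2d, 2b64775) revolve around a CAS-based fp32 atomicMax that stays correct for negative zeros and does not hang on NaN inputs. Below is a hypothetical sketch of that technique for illustration only; atomicMaxF32 is not a symbol from this PR, and raft's actual implementation may differ:

```cuda
// Hypothetical sketch of a CAS-based fp32 atomicMax (not the PR's code).
// A naive retry loop that compares floats can spin forever when NaNs are
// involved, because NaN comparisons are always false; comparing the raw
// bits instead guarantees the loop terminates.
__device__ float atomicMaxF32(float* addr, float val) {
  int* addr_as_int = reinterpret_cast<int*>(addr);
  int old          = *addr_as_int;
  int assumed;
  do {
    assumed = old;
    // Negated comparison also bails out when val is NaN.
    if (!(val > __int_as_float(assumed))) break;
    old = atomicCAS(addr_as_int, assumed, __float_as_int(val));
  } while (assumed != old);  // bitwise compare: safe even for NaN payloads
  return __int_as_float(old);
}
```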
refactor fused L2 KNN main function to call both L2 expanded/unexpanded. make function namings consistent
mdoijade committed Oct 6, 2021
commit 6072281734b7a0882b97d2c698c7ec5a01a2ba6a
184 changes: 86 additions & 98 deletions cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
@@ -514,10 +514,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(

template <typename DataT, typename AccT, typename OutT, typename IdxT,
int VecLen, bool usePrevTopKs, bool isRowMajor>
void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
void fusedL2UnexpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
IdxT lda, IdxT ldb, IdxT ldd, bool sqrt,
OutT *out_dists, IdxT *out_inds, IdxT numOfNN,
cudaStream_t stream, void *workspace,
size_t &worksize) {
typedef typename raft::linalg::Policy2x8<DataT, 1>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

@@ -537,25 +538,25 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
typedef cub::KeyValuePair<uint32_t, AccT> Pair;

if (isRowMajor) {
constexpr auto fusedL2kNN32RowMajor =
constexpr auto fusedL2UnexpKnn32RowMajor =
fusedL2kNN<false, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 32, 2, usePrevTopKs, true>;
constexpr auto fusedL2kNN64RowMajor =
constexpr auto fusedL2UnexpKnn64RowMajor =
fusedL2kNN<false, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 64, 3, usePrevTopKs, true>;

auto fusedL2kNNRowMajor = fusedL2kNN32RowMajor;
auto fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor;
if (numOfNN <= 32) {
fusedL2kNNRowMajor = fusedL2kNN32RowMajor;
fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn32RowMajor;
} else if (numOfNN <= 64) {
fusedL2kNNRowMajor = fusedL2kNN64RowMajor;
fusedL2UnexpKnnRowMajor = fusedL2UnexpKnn64RowMajor;
} else {
ASSERT(numOfNN <= 64,
"fusedL2kNN: num of nearest neighbors must be <= 64");
}

dim3 grid = raft::distance::launchConfigGenerator<KPolicy>(
m, n, KPolicy::SmemSize, fusedL2kNNRowMajor);
m, n, KPolicy::SmemSize, fusedL2UnexpKnnRowMajor);
if (grid.x > 1) {
const auto numMutexes = raft::ceildiv<int>(m, KPolicy::Mblk);
if (workspace == nullptr || worksize < (sizeof(int32_t) * numMutexes)) {
@@ -570,7 +571,7 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
const auto sharedMemSize =
KPolicy::SmemSize + (KPolicy::Mblk * numOfNN * sizeof(Pair));

fusedL2kNNRowMajor<<<grid, blk, sharedMemSize, stream>>>(
fusedL2UnexpKnnRowMajor<<<grid, blk, sharedMemSize, stream>>>(
x, y, nullptr, nullptr, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt,
(uint32_t)numOfNN, (int *)workspace, out_dists, out_inds);
} else {
@@ -581,78 +582,32 @@ void fusedL2kNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,

template <typename DataT, typename AccT, typename OutT, typename IdxT,
bool usePrevTopKs, bool isRowMajor>
void fusedL2kNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
const DataT *x, const DataT *y, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
void fusedL2UnexpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
const DataT *x, const DataT *y, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
size_t bytesA = sizeof(DataT) * lda;
size_t bytesB = sizeof(DataT) * ldb;
if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
fusedL2kNNImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace, worksize);
fusedL2UnexpKnnImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT),
usePrevTopKs, isRowMajor>(
x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream,
workspace, worksize);
} else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
fusedL2kNNImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace, worksize);
} else {
fusedL2kNNImpl<DataT, AccT, OutT, IdxT, 1, usePrevTopKs, isRowMajor>(
fusedL2UnexpKnnImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT),
usePrevTopKs, isRowMajor>(
x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream,
workspace, worksize);
}
}

/**
* Compute the k-nearest neighbors using L2 unexpanded distance.

* @tparam value_idx
* @tparam value_t
* @param[out] out_inds output indices array on device (size n_query_rows * k)
* @param[out] out_dists output dists array on device (size n_query_rows * k)
* @param[in] index input index array on device (size n_index_rows * D)
* @param[in] query input query array on device (size n_query_rows * D)
* @param[in] n_index_rows number of rows in index array
* @param[in] n_query_rows number of rows in query array
* @param[in] k number of closest neighbors to return
* @param[in] rowMajorIndex are the index arrays in row-major layout?
* @param[in] rowMajorQuery are the query array in row-major layout?
* @param[in] stream stream to order kernel launch
*/
template <raft::distance::DistanceType distanceType, typename value_idx,
typename value_t, bool usePrevTopKs>
void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query,
size_t n_index_rows, size_t n_query_rows, int k,
bool rowMajorIndex, bool rowMajorQuery,
cudaStream_t stream, void *workspace, size_t &worksize) {
// Validate the input data
ASSERT(k > 0, "l2Knn: k must be > 0");
ASSERT(D > 0, "l2Knn: D must be > 0");
ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0");
ASSERT(index, "l2Knn: index must be provided (passed null)");
ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0");
ASSERT(query, "l2Knn: query must be provided (passed null)");
ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)");
ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)");
// Currently we only support same layout for x & y inputs.
ASSERT(rowMajorIndex == rowMajorQuery,
"l2Knn: rowMajorIndex and rowMajorQuery should have same layout");

bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtUnexpanded);

if (rowMajorIndex) {
value_idx lda = D, ldb = D, ldd = n_index_rows;
fusedL2kNN<value_t, value_t, value_t, value_idx, usePrevTopKs, true>(
n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt,
out_dists, out_inds, k, stream, workspace, worksize);
} else {
// TODO: Add support for column major layout
fusedL2UnexpKnnImpl<DataT, AccT, OutT, IdxT, 1, usePrevTopKs, isRowMajor>(
x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream,
workspace, worksize);
}
}

template <typename DataT, typename AccT, typename OutT, typename IdxT,
int VecLen, bool usePrevTopKs, bool isRowMajor>
void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
void fusedL2ExpKnnImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
@@ -679,18 +634,18 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
typedef cub::KeyValuePair<uint32_t, AccT> Pair;

if (isRowMajor) {
constexpr auto fusedL2ExpkNN32RowMajor =
constexpr auto fusedL2ExpKnn32RowMajor =
fusedL2kNN<true, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 32, 2, usePrevTopKs, true>;
constexpr auto fusedL2ExpkNN64RowMajor =
constexpr auto fusedL2ExpKnn64RowMajor =
fusedL2kNN<true, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 64, 3, usePrevTopKs, true>;

auto fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor;
auto fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor;
if (numOfNN <= 32) {
fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor;
fusedL2ExpKnnRowMajor = fusedL2ExpKnn32RowMajor;
} else if (numOfNN <= 64) {
fusedL2ExpKNNRowMajor = fusedL2ExpkNN64RowMajor;
fusedL2ExpKnnRowMajor = fusedL2ExpKnn64RowMajor;
} else {
ASSERT(numOfNN <= 64,
"fusedL2kNN: num of nearest neighbors must be <= 64");
@@ -700,7 +655,7 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) +
(KPolicy::Mblk * numOfNN * sizeof(Pair));
dim3 grid = raft::distance::launchConfigGenerator<KPolicy>(
m, n, sharedMemSize, fusedL2ExpKNNRowMajor);
m, n, sharedMemSize, fusedL2ExpKnnRowMajor);
int32_t *mutexes = nullptr;
if (grid.x > 1) {
const auto numMutexes = raft::ceildiv<int>(m, KPolicy::Mblk);
@@ -732,7 +687,7 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
}
fusedL2ExpKNNRowMajor<<<grid, blk, sharedMemSize, stream>>>(
fusedL2ExpKnnRowMajor<<<grid, blk, sharedMemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt,
(uint32_t)numOfNN, mutexes, out_dists, out_inds);
} else {
@@ -743,36 +698,52 @@ void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,

template <typename DataT, typename AccT, typename OutT, typename IdxT,
bool usePrevTopKs, bool isRowMajor>
void fusedL2ExpkNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
void fusedL2ExpKnn(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
const DataT *x, const DataT *y, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
size_t bytesA = sizeof(DataT) * lda;
size_t bytesB = sizeof(DataT) * ldb;
if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), usePrevTopKs,
fusedL2ExpKnnImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace,
worksize);
} else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), usePrevTopKs,
fusedL2ExpKnnImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace,
worksize);
} else {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 1, usePrevTopKs, isRowMajor>(
fusedL2ExpKnnImpl<DataT, AccT, OutT, IdxT, 1, usePrevTopKs, isRowMajor>(
x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream,
workspace, worksize);
}
}

/**
* Compute the k-nearest neighbors using L2 expanded/unexpanded distance.

* @tparam value_idx
* @tparam value_t
* @param[out] out_inds output indices array on device (size n_query_rows * k)
* @param[out] out_dists output dists array on device (size n_query_rows * k)
* @param[in] index input index array on device (size n_index_rows * D)
* @param[in] query input query array on device (size n_query_rows * D)
* @param[in] n_index_rows number of rows in index array
* @param[in] n_query_rows number of rows in query array
* @param[in] k number of closest neighbors to return
* @param[in] rowMajorIndex are the index arrays in row-major layout?
* @param[in] rowMajorQuery are the query array in row-major layout?
* @param[in] stream stream to order kernel launch
*/
template <raft::distance::DistanceType distanceType, typename value_idx,
typename value_t, bool usePrevTopKs>
void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query,
size_t n_index_rows, size_t n_query_rows, int k,
bool rowMajorIndex, bool rowMajorQuery,
cudaStream_t stream, void *workspace, size_t &worksize) {
void fusedL2Knn(size_t D, value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query, size_t n_index_rows,
size_t n_query_rows, int k, bool rowMajorIndex,
bool rowMajorQuery, cudaStream_t stream, void *workspace,
size_t &worksize) {
// Validate the input data
ASSERT(k > 0, "l2Knn: k must be > 0");
ASSERT(D > 0, "l2Knn: D must be > 0");
@@ -785,17 +756,34 @@ void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists,
// Currently we only support same layout for x & y inputs.
ASSERT(rowMajorIndex == rowMajorQuery,
"l2Knn: rowMajorIndex and rowMajorQuery should have same layout");

bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtExpanded);

if (rowMajorIndex) {
value_idx lda = D, ldb = D, ldd = n_index_rows;
fusedL2ExpkNN<value_t, value_t, value_t, value_idx, usePrevTopKs, true>(
n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt,
out_dists, out_inds, k, stream, workspace, worksize);
} else {
// TODO: Add support for column major layout
}
// TODO: Add support for column major layout
ASSERT(rowMajorIndex == true,
"l2Knn: only rowMajor inputs are supported for now.");

// Even for L2 Sqrt distance case we use non-sqrt version as FAISS bfKNN only support
// non-sqrt metric & some tests in RAFT/cuML (like Linkage) fails if we use L2 sqrt.
bool sqrt =
(distanceType == raft::distance::DistanceType::L2SqrtUnexpanded) ||
(distanceType == raft::distance::DistanceType::L2SqrtExpanded);

value_idx lda = D, ldb = D, ldd = n_index_rows;
switch (distanceType) {
case raft::distance::DistanceType::L2SqrtExpanded:
case raft::distance::DistanceType::L2Expanded:
fusedL2ExpKnn<value_t, value_t, value_t, value_idx, usePrevTopKs, true>(
n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt,
out_dists, out_inds, k, stream, workspace, worksize);
break;
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtUnexpanded:
fusedL2UnexpKnn<value_t, value_t, value_t, value_idx, usePrevTopKs, true>(
n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt,
out_dists, out_inds, k, stream, workspace, worksize);
break;
default:
printf("only L2 distance metric is supported\n");
break;
};
}

} // namespace detail
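Aside: both fusedL2UnexpKnn and fusedL2ExpKnn above choose the vectorized-load width from the byte alignment of the two leading dimensions. A standalone sketch of that dispatch follows, under the assumption it is factored into a helper; pick_vec_len is a hypothetical name, and the PR keeps this logic inline:

```cpp
#include <cstddef>

// Hypothetical helper mirroring the VecLen dispatch in fusedL2UnexpKnn /
// fusedL2ExpKnn: prefer 16-byte loads, fall back to 8-byte, else scalar.
// The wider widths are usable only when both row strides keep every row
// aligned to that width in bytes.
template <typename DataT>
int pick_vec_len(std::size_t lda, std::size_t ldb) {
  std::size_t bytesA = sizeof(DataT) * lda;
  std::size_t bytesB = sizeof(DataT) * ldb;
  if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0)
    return 16 / sizeof(DataT);  // e.g. float4-style accesses for fp32
  if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0)
    return 8 / sizeof(DataT);
  return 1;  // unaligned data: scalar loads
}
```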
31 changes: 14 additions & 17 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
@@ -287,34 +287,31 @@ void brute_force_knn_impl(std::vector<float *> &input,
IntType>(search_items, input[i], n, sizes[i], D);
worksize = tempWorksize;
workspace.resize(worksize, stream);
l2_expanded_knn<raft::distance::DistanceType::L2Expanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
fusedL2Knn<raft::distance::DistanceType::L2Expanded, int64_t, float,
false>(D, out_i_ptr, out_d_ptr, input[i], search_items,
sizes[i], n, k, rowMajorIndex, rowMajorQuery,
stream, workspace.data(), worksize);
if (worksize > tempWorksize) {
workspace.resize(worksize, stream);
l2_expanded_knn<raft::distance::DistanceType::L2Expanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
fusedL2Knn<raft::distance::DistanceType::L2Expanded, int64_t, float,
false>(D, out_i_ptr, out_d_ptr, input[i], search_items,
sizes[i], n, k, rowMajorIndex, rowMajorQuery,
stream, workspace.data(), worksize);
}
break;
case raft::distance::DistanceType::L2Unexpanded:
// Even for L2 Sqrt distance case we use non-sqrt version
// as FAISS bfKNN only support non-sqrt metric & some tests
// in RAFT/cuML (like Linkage) fails if we use L2 sqrt.
// Even for L2 Sqrt distance case we use non-sqrt version
// as FAISS bfKNN only support non-sqrt metric & some tests
// in RAFT/cuML (like Linkage) fails if we use L2 sqrt.
case raft::distance::DistanceType::L2SqrtUnexpanded:
l2_unexpanded_knn<raft::distance::DistanceType::L2Unexpanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
fusedL2Knn<raft::distance::DistanceType::L2Unexpanded, int64_t, float,
false>(D, out_i_ptr, out_d_ptr, input[i], search_items,
sizes[i], n, k, rowMajorIndex, rowMajorQuery,
stream, workspace.data(), worksize);
if (worksize) {
workspace.resize(worksize, stream);
l2_unexpanded_knn<raft::distance::DistanceType::L2Unexpanded,
int64_t, float, false>(
fusedL2Knn<raft::distance::DistanceType::L2Unexpanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
}
4 changes: 2 additions & 2 deletions cpp/test/spatial/ball_cover.cu
@@ -98,14 +98,14 @@ void compute_bfknn(const raft::handle_t &handle, const value_t *X1,
} else {
size_t worksize = 0;
void *workspace = nullptr;
raft::spatial::knn::detail::l2_unexpanded_knn<
raft::spatial::knn::detail::fusedL2Knn<
raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t, false>(
(size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0], (size_t)n,
(int)k, true, true, handle.get_stream(), workspace, worksize);
if (worksize) {
rmm::device_uvector<int> d_mutexes(worksize, handle.get_stream());
workspace = d_mutexes.data();
raft::spatial::knn::detail::l2_unexpanded_knn<
raft::spatial::knn::detail::fusedL2Knn<
raft::distance::DistanceType::L2SqrtUnexpanded, int64_t, value_t,
false>((size_t)d, inds, dists, input_vec[0], X2, (size_t)sizes_vec[0],
(size_t)n, (int)k, true, true, handle.get_stream(), workspace,
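Taken together, the call sites in this PR (brute_force_knn_impl and compute_bfknn) use fusedL2Knn in two phases: a first call that may only report the workspace size needed for inter-CTA mutexes, then a second call with a buffer of that size. Below is a condensed, hypothetical sketch of that convention; run_fused_l2_knn and its arguments are placeholders, not part of the PR:

```cpp
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <rmm/device_uvector.hpp>

// Placeholder wrapper showing the two-phase workspace convention used by
// this PR's call sites. All pointers are device pointers; inputs are
// assumed row-major (the only layout supported so far).
void run_fused_l2_knn(const float* index, const float* query, size_t D,
                      size_t n_index_rows, size_t n_query_rows, int k,
                      int64_t* out_inds, float* out_dists,
                      cudaStream_t stream) {
  size_t worksize = 0;
  void* workspace = nullptr;
  // First call: if the launch needs more than one CTA per output tile,
  // this only records the required mutex workspace size in `worksize`.
  raft::spatial::knn::detail::fusedL2Knn<
    raft::distance::DistanceType::L2Expanded, int64_t, float, false>(
    D, out_inds, out_dists, index, query, n_index_rows, n_query_rows, k,
    /*rowMajorIndex=*/true, /*rowMajorQuery=*/true, stream, workspace,
    worksize);
  if (worksize) {
    // Second call: hand over a buffer of the requested size.
    rmm::device_uvector<int> mutexes(worksize, stream);
    workspace = mutexes.data();
    raft::spatial::knn::detail::fusedL2Knn<
      raft::distance::DistanceType::L2Expanded, int64_t, float, false>(
      D, out_inds, out_dists, index, query, n_index_rows, n_query_rows, k,
      true, true, stream, workspace, worksize);
  }
}
```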