[REVIEW] Add Fused L2 Expanded KNN kernel #339

Merged · 29 commits · Nov 23, 2021
Changes from 7 commits (the diff below covers only this subset of the 29 commits)
Commits (29)
75c9f27
add fused L2 expanded kNN kernel, this is faster by at least 20-25% o…
mdoijade Sep 24, 2021
7a1e1e6
use lid > firsActiveLane instead of bitwise left shift and & for upda…
mdoijade Sep 27, 2021
e655cd4
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
mdoijade Sep 27, 2021
290d28d
fix incorrect output for NN >32 case when taking prod-cons knn merge …
mdoijade Sep 28, 2021
60d9201
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
mdoijade Sep 28, 2021
5f3cea1
fix clang format issues
mdoijade Sep 28, 2021
5b5f7a0
enable testing of cuml using this raft fork
mdoijade Sep 28, 2021
738c604
add custom atomicMax function which works fine if negative zeros are …
mdoijade Sep 29, 2021
15cbda8
merge branch-21.12 and test customAtomicMax without +0 addition
mdoijade Sep 29, 2021
352cc2d
fix hang in raft atomicMax of fp32 when the inputs are NaNs
mdoijade Sep 30, 2021
aa8ef09
remove redundant processing.hpp included in fused_l2_knn
mdoijade Oct 5, 2021
6072281
refactor fused L2 KNN main function to call both L2 expanded/unexpand…
mdoijade Oct 6, 2021
ae14f75
revert ball cover test to use brute_force_knn function instead of exp…
mdoijade Oct 6, 2021
53b6415
use isnan only if DeviceMax/Min operations in atomicCAS based functio…
mdoijade Oct 7, 2021
1d9ade3
fix clang format issues
mdoijade Oct 7, 2021
62cff7b
revert prtest.config changes, move fusedL2kNN launch/selection code t…
mdoijade Oct 11, 2021
9164a64
fix bug in updateSortedWarpQ for NN > 32, disable use of sqrt as it i…
mdoijade Oct 13, 2021
abc2b11
allocate workspace when resize is required for using prod-cons mutexes
mdoijade Oct 13, 2021
ec0cc32
add unit test for fused L2 KNN exp/unexp cases using faiss bfknn as g…
mdoijade Nov 2, 2021
700318d
merge branch-21.12 and update fused_l2_knn.cuh with those changes
mdoijade Nov 2, 2021
2b64775
move customAtomicMax to generic atomicMax specialization, and remove …
mdoijade Nov 2, 2021
ef9a898
fix clang format errors
mdoijade Nov 2, 2021
b317a12
call faiss before fusedL2knn kernel in the test
mdoijade Nov 3, 2021
9e2e19e
fix issues in verification function as it can happen that 2 vectors w…
mdoijade Nov 3, 2021
395beff
Merge branch 'branch-22.02' into fusedL2ExpandedKNN
mdoijade Nov 17, 2021
f0fd7b4
revert ball_cover test to use compute_bfknn which is wrapper for brut…
mdoijade Nov 17, 2021
bb099ca
Merge branch 'branch-21.12' into fusedL2ExpandedKNN
cjnolet Nov 17, 2021
a2f1dee
Merge branch 'branch-22.02' into fusedL2ExpandedKNN
cjnolet Nov 23, 2021
bdce263
Adjusting rng.cuh
cjnolet Nov 23, 2021
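Several commits above (738c604, 352cc2d, 53b6415, 2b64775) deal with a CAS-based floating-point atomicMax and a hang that occurred when NaN values reached its retry loop. The snippet below is a generic, hypothetical sketch of that pattern with an early NaN bail-out; it is not the code merged in this PR, only an illustration of the technique the commit messages refer to.

__device__ float atomic_max_float(float *addr, float val) {
  // A NaN candidate can never win a max, and an unordered compare would
  // keep the CAS retry loop below spinning, so bail out early.
  if (isnan(val)) return *addr;
  int *addr_as_int = reinterpret_cast<int *>(addr);
  int old = *addr_as_int;
  int assumed;
  do {
    assumed = old;
    if (__int_as_float(assumed) >= val) break;  // stored value is already large enough
    old = atomicCAS(addr_as_int, assumed, __float_as_int(val));
  } while (assumed != old);
  return __int_as_float(old);
}

A caller would use it like any other atomic, e.g. atomic_max_float(&out_dists[row], candidate) (hypothetical names).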
6 changes: 3 additions & 3 deletions ci/prtest.config
@@ -1,6 +1,6 @@
RUN_CUGRAPH_LIBCUGRAPH_TESTS=OFF
RUN_CUGRAPH_PYTHON_TESTS=OFF

RUN_CUML_LIBCUML_TESTS=OFF
RUN_CUML_PRIMS_TESTS=OFF
RUN_CUML_PYTHON_TESTS=OFF
RUN_CUML_LIBCUML_TESTS=ON
RUN_CUML_PRIMS_TESTS=ON
RUN_CUML_PYTHON_TESTS=ON
225 changes: 209 additions & 16 deletions cpp/include/raft/spatial/knn/detail/fused_l2_knn.cuh
@@ -18,6 +18,7 @@
#include <faiss/gpu/utils/Select.cuh>
#include <limits>
#include <raft/distance/pairwise_distance_base.cuh>
#include <raft/linalg/norm.cuh>
#include "processing.hpp"

namespace raft {
@@ -143,11 +144,11 @@ DI void updateSortedWarpQ(myWarpSelect &heapArr, Pair *allWarpTopKs, int rowId,
Pair tempKV;
tempKV.value = raft::shfl(heapArr->warpK[i], srcLane);
tempKV.key = raft::shfl(heapArr->warpV[i], srcLane);
const auto firstActiveLane = __ffs(activeLanes);
if (firstActiveLane == (lid + 1)) {
const auto firstActiveLane = __ffs(activeLanes) - 1;
if (firstActiveLane == lid) {
heapArr->warpK[i] = KVPair.value;
heapArr->warpV[i] = KVPair.key;
} else if (activeLanes & ((uint32_t)1 << lid)) {
} else if (lid > firstActiveLane) {
heapArr->warpK[i] = tempKV.value;
heapArr->warpV[i] = tempKV.key;
}
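Commit 7a1e1e6 replaces the per-lane bitmask test in the hunk above with a comparison against the first active lane: __ffs(activeLanes) - 1 is the lowest lane whose queue entry beats the candidate, that lane takes the candidate, and every higher lane shifts its entry along by one. Below is a minimal standalone sketch of the same insertion idea, using a hypothetical helper that keeps one ascending-ordered queue element per lane and assumes a full warp is active; it is a simplification, not the PR's updateSortedWarpQ.

__device__ float warp_sorted_insert(float key, float candidate) {
  const unsigned fullMask = 0xffffffffu;   // assumes all 32 lanes are active
  const int lid = threadIdx.x % warpSize;
  // Lanes whose current entry is larger than the candidate have to move.
  const unsigned activeLanes = __ballot_sync(fullMask, key > candidate);
  // Entry held by the lane to the left, read before anything is overwritten.
  const float fromLeft = __shfl_up_sync(fullMask, key, 1);
  if (activeLanes) {
    const int firstActiveLane = __ffs(activeLanes) - 1;  // 0-based lane index
    if (lid == firstActiveLane) {
      key = candidate;   // candidate drops into the first larger slot
    } else if (lid > firstActiveLane) {
      key = fromLeft;    // all higher lanes shift one slot to the right
    }
  }
  return key;            // the warp-wide sequence stays ascending
}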
@@ -191,7 +192,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
}

volatile int *mutex = mutexes;
Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]);

Pair *shDumpKV = nullptr;
if (useNorms) {
shDumpKV =
(Pair *)(&smem[Policy::SmemSize +
((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]);
} else {
shDumpKV = (Pair *)(&smem[Policy::SmemSize]);
}

const int lid = threadIdx.x % warpSize;
const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols);
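The conditional shDumpKV offset above (and its twin later in the epilogue) exists because the expanded-distance (useNorms) path reserves (Policy::Mblk + Policy::Nblk) * sizeof(DataT) bytes of shared memory for the cached x/y row norms directly after the pairwise-distance tiles, so the key/value dump region has to start after both. A small, hypothetical host-side helper that mirrors the byte layout assumed here:

#include <cstddef>

// Illustration only; names mirror the Policy members used in the kernel.
struct SmemLayout {
  std::size_t norms_offset;    // start of the (Mblk + Nblk) cached row norms
  std::size_t dump_kv_offset;  // start of the Mblk * numOfNN key/value dump
  std::size_t total_bytes;     // dynamic shared-memory size to launch with
};

inline SmemLayout make_smem_layout(std::size_t smemSize, std::size_t mblk,
                                   std::size_t nblk, std::size_t numOfNN,
                                   std::size_t dataBytes, std::size_t pairBytes,
                                   bool useNorms) {
  SmemLayout layout;
  layout.norms_offset = smemSize;
  layout.dump_kv_offset =
    useNorms ? smemSize + (mblk + nblk) * dataBytes : smemSize;
  layout.total_bytes = layout.dump_kv_offset + mblk * numOfNN * pairBytes;
  return layout;
}

With the expanded path's own numbers this reproduces the sharedMemSize expression used later in fusedL2ExpKNNImpl: KPolicy::SmemSize + (Mblk + Nblk) * sizeof(DataT) + Mblk * numOfNN * sizeof(Pair).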

@@ -204,13 +214,11 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
myWarpSelect heapArr1(identity, keyMax, numOfNN);
myWarpSelect heapArr2(identity, keyMax, numOfNN);
myWarpSelect *heapArr[] = {&heapArr1, &heapArr2};
__syncthreads();
__syncwarp();

loadAllWarpQShmem<Policy, Pair>(heapArr, &shDumpKV[0], m, numOfNN);

while (cta_processed < gridDim.x - 1) {
Pair otherKV[Policy::AccRowsPerTh];

if (threadIdx.x == 0) {
int32_t old = -3;
while (old != -1) {
@@ -223,12 +231,19 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
otherKV[i].value = identity;
otherKV[i].key = keyMax;

if (lid < numOfNN && rowId < m) {
otherKV[i].value = out_dists[rowId * numOfNN + lid];
otherKV[i].key = (uint32_t)out_inds[rowId * numOfNN + lid];
const auto shMemRowId =
(threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN && rowId < m) {
otherKV.value = out_dists[rowId * numOfNN + idx];
otherKV.key = (uint32_t)out_inds[rowId * numOfNN + idx];
shDumpKV[shMemRowId * numOfNN + idx] = otherKV;
}
}
}
__threadfence();
@@ -239,14 +254,27 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
}

// Perform merging of otherKV with topk's across warp.
__syncwarp();

#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
const auto rowId = starty + i * Policy::AccThRows;
const auto shMemRowId =
(threadIdx.x / Policy::AccThCols) + i * Policy::AccThRows;
if (rowId < m) {
heapArr[i]->add(otherKV[i].value, otherKV[i].key);
#pragma unroll
for (int j = 0; j < heapArr[i]->kNumWarpQRegisters; ++j) {
Pair otherKV;
otherKV.value = identity;
otherKV.key = keyMax;
const auto idx = j * warpSize + lid;
if (idx < numOfNN) {
otherKV = shDumpKV[shMemRowId * numOfNN + idx];
}
heapArr[i]->add(otherKV.value, otherKV.key);
}
}
}

cta_processed++;
}
#pragma unroll
@@ -300,6 +328,16 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
[numOfNN, sqrt, m, n, ldd, out_dists, out_inds] __device__(
AccT acc[Policy::AccRowsPerTh][Policy::AccColsPerTh], DataT * regxn,
DataT * regyn, IdxT gridStrideX, IdxT gridStrideY) {
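// When norms are cached (the expanded-L2 path), acc[i][j] holds the dot
// product x·y at this point and regxn/regyn hold the squared row norms,
// so the loop below recovers ||x - y||^2 = ||x||^2 + ||y||^2 - 2 * (x·y).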
if (useNorms) {
#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
#pragma unroll
for (int j = 0; j < Policy::AccColsPerTh; ++j) {
acc[i][j] = regxn[i] + regyn[j] - (DataT)2.0 * acc[i][j];
}
}
}

if (sqrt) {
#pragma unroll
for (int i = 0; i < Policy::AccRowsPerTh; ++i) {
@@ -309,7 +347,14 @@ __global__ __launch_bounds__(Policy::Nthreads, 2) void fusedL2kNN(
}
}
}
Pair *shDumpKV = (Pair *)(&smem[Policy::SmemSize]);
Pair *shDumpKV = nullptr;
if (useNorms) {
shDumpKV =
(Pair *)(&smem[Policy::SmemSize +
((Policy::Mblk + Policy::Nblk) * sizeof(DataT))]);
} else {
shDumpKV = (Pair *)(&smem[Policy::SmemSize]);
}

constexpr uint32_t mask = 0xffffffffu;
const IdxT starty = gridStrideY + (threadIdx.x / Policy::AccThCols);
@@ -606,6 +651,154 @@ void l2_unexpanded_knn(size_t D, value_idx *out_inds, value_t *out_dists,
}
}

template <typename DataT, typename AccT, typename OutT, typename IdxT,
int VecLen, bool usePrevTopKs, bool isRowMajor>
void fusedL2ExpKNNImpl(const DataT *x, const DataT *y, IdxT m, IdxT n, IdxT k,
IdxT lda, IdxT ldb, IdxT ldd, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
typedef typename raft::linalg::Policy2x8<DataT, 1>::Policy RowPolicy;
typedef typename raft::linalg::Policy4x4<DataT, VecLen>::ColPolicy ColPolicy;

typedef typename std::conditional<true, RowPolicy, ColPolicy>::type KPolicy;

ASSERT(isRowMajor, "Only Row major inputs are allowed");

ASSERT(!(((x != y) && (worksize < (m + n) * sizeof(AccT))) ||
(worksize < m * sizeof(AccT))),
"workspace size error");
ASSERT(workspace != nullptr, "workspace is null");

dim3 blk(KPolicy::Nthreads);
// Accumulation operation lambda
auto core_lambda = [] __device__(AccT & acc, DataT & x, DataT & y) {
acc += x * y;
};

auto fin_op = [] __device__(AccT d_val, int g_d_idx) { return d_val; };

typedef cub::KeyValuePair<uint32_t, AccT> Pair;

if (isRowMajor) {
constexpr auto fusedL2ExpkNN32RowMajor =
fusedL2kNN<true, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 32, 2, usePrevTopKs, true>;
constexpr auto fusedL2ExpkNN64RowMajor =
fusedL2kNN<true, DataT, AccT, OutT, IdxT, KPolicy, decltype(core_lambda),
decltype(fin_op), 64, 3, usePrevTopKs, true>;

auto fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor;
if (numOfNN <= 32) {
fusedL2ExpKNNRowMajor = fusedL2ExpkNN32RowMajor;
} else if (numOfNN <= 64) {
fusedL2ExpKNNRowMajor = fusedL2ExpkNN64RowMajor;
} else {
ASSERT(numOfNN <= 64,
"fusedL2kNN: num of nearest neighbors must be <= 64");
}

const auto sharedMemSize =
KPolicy::SmemSize + ((KPolicy::Mblk + KPolicy::Nblk) * sizeof(DataT)) +
(KPolicy::Mblk * numOfNN * sizeof(Pair));
dim3 grid = raft::distance::launchConfigGenerator<KPolicy>(
m, n, sharedMemSize, fusedL2ExpKNNRowMajor);
int32_t *mutexes = nullptr;
if (grid.x > 1) {
const auto numMutexes = raft::ceildiv<int>(m, KPolicy::Mblk);
const auto normsSize =
(x != y) ? (m + n) * sizeof(DataT) : n * sizeof(DataT);
const auto requiredSize = sizeof(int32_t) * numMutexes + normsSize;
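// If the caller's workspace cannot hold the row norms plus the inter-CTA
// mutexes, report the required size back through `worksize` and return;
// the caller is expected to resize its workspace and call again.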
if (worksize < requiredSize) {
worksize = requiredSize;
return;
} else {
mutexes = (int32_t *)((char *)workspace + normsSize);
CUDA_CHECK(
cudaMemsetAsync(mutexes, 0, sizeof(int32_t) * numMutexes, stream));
}
}

DataT *xn = (DataT *)workspace;
DataT *yn = (DataT *)workspace;
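// xn/yn point into the caller-provided workspace: x row norms come first
// (m values), followed by y row norms (n values) when x != y; the mutexes
// set up above live after the norms.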

auto norm_op = [] __device__(DataT in) { return in; };

if (x != y) {
yn += m;
raft::linalg::rowNorm(xn, x, k, m, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
raft::linalg::rowNorm(yn, y, k, n, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
} else {
raft::linalg::rowNorm(xn, x, k, n, raft::linalg::L2Norm, isRowMajor,
stream, norm_op);
}
fusedL2ExpKNNRowMajor<<<grid, blk, sharedMemSize, stream>>>(
x, y, xn, yn, m, n, k, lda, ldb, ldd, core_lambda, fin_op, sqrt,
(uint32_t)numOfNN, mutexes, out_dists, out_inds);
} else {
}

CUDA_CHECK(cudaGetLastError());
}

template <typename DataT, typename AccT, typename OutT, typename IdxT,
bool usePrevTopKs, bool isRowMajor>
void fusedL2ExpkNN(IdxT m, IdxT n, IdxT k, IdxT lda, IdxT ldb, IdxT ldd,
const DataT *x, const DataT *y, bool sqrt, OutT *out_dists,
IdxT *out_inds, IdxT numOfNN, cudaStream_t stream,
void *workspace, size_t &worksize) {
size_t bytesA = sizeof(DataT) * lda;
size_t bytesB = sizeof(DataT) * ldb;
if (16 % sizeof(DataT) == 0 && bytesA % 16 == 0 && bytesB % 16 == 0) {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 16 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace,
worksize);
} else if (8 % sizeof(DataT) == 0 && bytesA % 8 == 0 && bytesB % 8 == 0) {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 8 / sizeof(DataT), usePrevTopKs,
isRowMajor>(x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists,
out_inds, numOfNN, stream, workspace,
worksize);
} else {
fusedL2ExpKNNImpl<DataT, AccT, OutT, IdxT, 1, usePrevTopKs, isRowMajor>(
x, y, m, n, k, lda, ldb, ldd, sqrt, out_dists, out_inds, numOfNN, stream,
workspace, worksize);
}
}

template <raft::distance::DistanceType distanceType, typename value_idx,
typename value_t, bool usePrevTopKs>
void l2_expanded_knn(size_t D, value_idx *out_inds, value_t *out_dists,
const value_t *index, const value_t *query,
size_t n_index_rows, size_t n_query_rows, int k,
bool rowMajorIndex, bool rowMajorQuery,
cudaStream_t stream, void *workspace, size_t &worksize) {
// Validate the input data
ASSERT(k > 0, "l2Knn: k must be > 0");
ASSERT(D > 0, "l2Knn: D must be > 0");
ASSERT(n_index_rows > 0, "l2Knn: n_index_rows must be > 0");
ASSERT(index, "l2Knn: index must be provided (passed null)");
ASSERT(n_query_rows > 0, "l2Knn: n_query_rows must be > 0");
ASSERT(query, "l2Knn: query must be provided (passed null)");
ASSERT(out_dists, "l2Knn: out_dists must be provided (passed null)");
ASSERT(out_inds, "l2Knn: out_inds must be provided (passed null)");
// Currently we only support same layout for x & y inputs.
ASSERT(rowMajorIndex == rowMajorQuery,
"l2Knn: rowMajorIndex and rowMajorQuery should have same layout");

bool sqrt = (distanceType == raft::distance::DistanceType::L2SqrtExpanded);

if (rowMajorIndex) {
value_idx lda = D, ldb = D, ldd = n_index_rows;
fusedL2ExpkNN<value_t, value_t, value_t, value_idx, usePrevTopKs, true>(
n_query_rows, n_index_rows, D, lda, ldb, ldd, query, index, sqrt,
out_dists, out_inds, k, stream, workspace, worksize);
} else {
// TODO: Add support for column major layout
}
}

} // namespace detail
} // namespace knn
} // namespace spatial
32 changes: 24 additions & 8 deletions cpp/include/raft/spatial/knn/detail/knn_brute_force_faiss.cuh
@@ -35,6 +35,7 @@
#include <raft/handle.hpp>
#include <set>

#include <raft/distance/distance.cuh>
#include "fused_l2_knn.cuh"
#include "haversine_distance.cuh"
#include "processing.hpp"
@@ -276,13 +277,29 @@ void brute_force_knn_impl(std::vector<float *> &input,
metric == raft::distance::DistanceType::L2SqrtUnexpanded ||
metric == raft::distance::DistanceType::L2Expanded ||
metric == raft::distance::DistanceType::L2SqrtExpanded)) {
size_t worksize = 0;
void *workspace = nullptr;

size_t worksize = 0, tempWorksize = 0;
rmm::device_uvector<char> workspace(worksize, stream);
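// The fused L2 kernels report their true workspace requirement back through
// `worksize`: run once with the initial estimate, then resize and run again
// only if more space was requested than provided.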
switch (metric) {
case raft::distance::DistanceType::L2Expanded:
case raft::distance::DistanceType::L2Unexpanded:
case raft::distance::DistanceType::L2SqrtExpanded:
tempWorksize = raft::distance::getWorkspaceSize<
raft::distance::DistanceType::L2Expanded, float, float, float,
IntType>(search_items, input[i], n, sizes[i], D);
worksize = tempWorksize;
workspace.resize(worksize, stream);
l2_expanded_knn<raft::distance::DistanceType::L2Expanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
if (worksize > tempWorksize) {
workspace.resize(worksize, stream);
l2_expanded_knn<raft::distance::DistanceType::L2Expanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
}
break;
case raft::distance::DistanceType::L2Unexpanded:
// Even for the L2 Sqrt distance case we use the non-sqrt version,
// as FAISS bfKNN only supports the non-sqrt metric and some tests
// in RAFT/cuML (like Linkage) fail if we use L2 sqrt.
@@ -293,14 +310,13 @@
l2_unexpanded_knn<raft::distance::DistanceType::L2Unexpanded, int64_t,
float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace, worksize);
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
if (worksize) {
rmm::device_uvector<int> d_mutexes(worksize, stream);
workspace = d_mutexes.data();
workspace.resize(worksize, stream);
l2_unexpanded_knn<raft::distance::DistanceType::L2Unexpanded,
int64_t, float, false>(
D, out_i_ptr, out_d_ptr, input[i], search_items, sizes[i], n, k,
rowMajorIndex, rowMajorQuery, stream, workspace, worksize);
rowMajorIndex, rowMajorQuery, stream, workspace.data(), worksize);
}
break;
default:
2 changes: 1 addition & 1 deletion cpp/include/raft/spatial/knn/knn.hpp
@@ -116,7 +116,7 @@ inline void brute_force_knn(
std::vector<int> &sizes, int D, float *search_items, int n, int64_t *res_I,
float *res_D, int k, bool rowMajorIndex = true, bool rowMajorQuery = true,
std::vector<int64_t> *translations = nullptr,
distance::DistanceType metric = distance::DistanceType::L2Unexpanded,
distance::DistanceType metric = distance::DistanceType::L2Expanded,
float metric_arg = 2.0f) {
ASSERT(input.size() == sizes.size(),
"input and sizes vectors must be the same size");