Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[FEA] support of prefiltered brute force #2294

Merged
merged 29 commits into from
May 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
d85fcf0
[FEA] support of prefiltered brute force based on cuSparseSDDMM
rhdong May 7, 2024
e3ef7bc
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 8, 2024
c7e4e7a
Improve the performance in classic scenarios by replace the cuSparseS…
rhdong May 8, 2024
ed07f60
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 13, 2024
5c5aa9b
optimize and remove used.
rhdong May 13, 2024
b4971c2
Update cpp/include/raft/sparse/distance/detail/utils.cuh
achirkin May 13, 2024
b1c1bb8
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 14, 2024
77ee4a6
Test cases adjustment
rhdong May 14, 2024
ea8420a
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 15, 2024
cc2b228
Merge SDDMM with customized kernel, optimize bitset count
rhdong May 15, 2024
68731b0
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 17, 2024
3a81f19
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 20, 2024
2684afe
Optimize by dense bfknn
rhdong May 20, 2024
57193a5
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 20, 2024
8e1217c
Optimize the test cases
rhdong May 20, 2024
56f00cd
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 20, 2024
71bd24b
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 21, 2024
96f4e83
Splitting(revert) the cuVS part
rhdong May 21, 2024
4f1aa17
Fix CI issue
rhdong May 21, 2024
b718673
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 22, 2024
9e24c5a
Move sparse distance API utils to cuvs and split the bitmap
rhdong May 22, 2024
18cb672
Optimize by review comments
rhdong May 23, 2024
e393af9
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 23, 2024
72c71f5
Merge remote-tracking branch 'origin/branch-24.06' into rhdong/prefil…
rhdong May 23, 2024
97a0e74
Remove the sparse select_k instantiations
rhdong May 23, 2024
de49e0c
Merge branch 'rhdong/prefiltered-bf' of https://github.com/rhdong/raf…
rhdong May 23, 2024
7d08443
Fix CI issue
rhdong May 23, 2024
18ba927
Merge branch 'branch-24.06' into rhdong/prefiltered-bf
cjnolet May 23, 2024
f38642f
Fix docs issue.
rhdong May 24, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Improve the performance in classic scenarios by replacing the cuSparseSDDMM with faster_dot_on_csr
  • Loading branch information
rhdong committed May 8, 2024
commit c7e4e7a4262466172ed17d5fef6b801781320436
19 changes: 12 additions & 7 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,13 @@ struct params {

inline auto operator<<(std::ostream& os, const params& p) -> std::ostream&
{
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k << "#"
<< p.removed_ratio;
os << p.n_samples << "#" << p.n_dims << "#" << p.n_queries << "#" << p.k;
if (p.removed_ratio > 0.0) {
os << "#" << p.removed_ratio;
} else {
os << "#"
<< "[No filtered]";
}
switch (p.metric) {
case raft::distance::DistanceType::InnerProduct: os << "#InnerProduct"; break;
case raft::distance::DistanceType::L2Expanded: os << "#L2Expanded"; break;
Expand Down Expand Up @@ -595,11 +600,11 @@ const std::vector<params> kInputsFilter =
);

const std::vector<params> kInputsBruteForceFilter = raft::util::itertools::product<params>(
{size_t(1000000)}, // n_samples
{size_t(128)}, // n_dim
{size_t(1000)}, // n_queries
{size_t(255)}, // k
{0.0, 0.8, 0.9}, // removed_ratio
{size_t(1000000)}, // n_samples
{size_t(4096), size_t(512), size_t(128)}, // n_dim
{size_t(1), size_t(10), size_t(1000)}, // n_queries
{size_t(255)}, // k
{0.0, 0.8, 0.9, 0.99}, // removed_ratio
{raft::distance::DistanceType::InnerProduct, raft::distance::DistanceType::L2Expanded});

inline const std::vector<TransferStrategy> kAllStrategies{
Expand Down
36 changes: 16 additions & 20 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -606,24 +606,21 @@ void brute_force_search(

// create filter csr view
auto compressed_csr_view = csr.structure_view();
auto csr_view = make_device_csr_matrix_view<T, IdxT, IdxT, IdxT>(csr.get_elements().data(),
compressed_csr_view);

// create dataset view
auto dataset_view = raft::make_device_matrix_view<const T, IdxT, raft::col_major>(
idx.dataset().data_handle(), dim, n_dataset);

// calc dot
T alpha = static_cast<T>(1.0f);
T beta = static_cast<T>(0.0f);
raft::sparse::linalg::sddmm(res,
queries,
dataset_view,
csr_view,
raft::linalg::Operation::NON_TRANSPOSE,
raft::linalg::Operation::NON_TRANSPOSE,
raft::make_host_scalar_view<T>(&alpha),
raft::make_host_scalar_view<T>(&beta));
rmm::device_uvector<IdxT> rows(compressed_csr_view.get_nnz(), stream);
raft::sparse::convert::csr_to_coo(compressed_csr_view.get_indptr().data(),
compressed_csr_view.get_n_rows(),
rows.data(),
compressed_csr_view.get_nnz(),
stream);

raft::sparse::distance::detail::faster_dot_on_csr(res,
csr.get_elements().data(),
compressed_csr_view.get_nnz(),
rows.data(),
compressed_csr_view.get_indices().data(),
queries.data_handle(),
idx.dataset().data_handle(),
dim);

// post process
std::optional<device_vector<T, IdxT>> query_norms_;
Expand Down Expand Up @@ -658,9 +655,8 @@ void brute_force_search(
raft::sparse::distance::detail::epilogue_on_csr(
res,
csr.get_elements().data(),
compressed_csr_view.get_indptr().data(),
compressed_csr_view.get_nnz(),
compressed_csr_view.get_n_rows(),
rows.data(),
compressed_csr_view.get_indices().data(),
query_norms ? query_norms->data_handle() : query_norms_->data_handle(),
idx.norms().data_handle(),
Expand Down
88 changes: 80 additions & 8 deletions cpp/include/raft/sparse/distance/detail/utils.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -61,24 +61,20 @@ RAFT_KERNEL epilogue_on_csr_kernel(value_t* __restrict__ compressed_C,
template <typename value_idx, typename value_t, int tpb = 256>
void epilogue_on_csr(raft::resources const& handle,
value_t* compressed_C,
const value_idx* indptr,
const value_idx nnz,
const value_idx n_rows,
const value_idx* rows,
const value_idx* cols,
const value_t* Q_sq_norms,
const value_t* R_sq_norms,
raft::distance::DistanceType metric)
{
auto stream = resource::get_cuda_stream(handle);

rmm::device_uvector<value_idx> rows(nnz, stream);
raft::sparse::convert::csr_to_coo(indptr, n_rows, rows.data(), nnz, stream);

int blocks = raft::ceildiv<size_t>((size_t)nnz, tpb);
if (metric == raft::distance::DistanceType::L2Expanded) {
epilogue_on_csr_kernel<<<blocks, tpb, 0, stream>>>(
compressed_C,
rows.data(),
rows,
cols,
Q_sq_norms,
R_sq_norms,
Expand All @@ -89,7 +85,7 @@ void epilogue_on_csr(raft::resources const& handle,
} else if (metric == raft::distance::DistanceType::L2SqrtExpanded) {
epilogue_on_csr_kernel<<<blocks, tpb, 0, stream>>>(
compressed_C,
rows.data(),
rows,
cols,
Q_sq_norms,
R_sq_norms,
Expand All @@ -100,7 +96,7 @@ void epilogue_on_csr(raft::resources const& handle,
} else if (metric == raft::distance::DistanceType::CosineExpanded) {
epilogue_on_csr_kernel<<<blocks, tpb, 0, stream>>>(
compressed_C,
rows.data(),
rows,
cols,
Q_sq_norms,
R_sq_norms,
Expand All @@ -110,6 +106,82 @@ void epilogue_on_csr(raft::resources const& handle,
});
}
}

/**
 * @brief Sum `val` across the lanes of the calling warp.
 *
 * Performs a tree reduction with `__shfl_down_sync`, halving the shuffle
 * distance on every step (the same pattern `faster_dot_on_csr_kernel` uses).
 * The previous implementation returned `val` unchanged, i.e. it did not reduce
 * anything.
 *
 * The shuffle is issued with the full mask `0xffffffff`, so every lane of the
 * warp has to reach this call.
 *
 * @tparam value_t element type; must be a type accepted by `__shfl_down_sync`
 * @param[in] val this lane's contribution to the sum
 * @return on lane 0, the sum of `val` over all lanes of the warp; the values
 *         returned on the other lanes are partial sums and should not be used
 */
template <typename value_t>
__inline__ __device__ value_t warpReduceSum(value_t val)
{
#pragma unroll
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    val += __shfl_down_sync(0xffffffff, val, offset);
  }
  return val;
}

/**
 * @brief Compute one dot product per non-zero of a sparse (COO-style) pattern.
 *
 * Block `b` handles non-zero `b`: it reads the pair `(rows[b], cols[b])`,
 * computes `sum_k A[rows[b] * dim + k] * B[cols[b] * dim + k]` and stores the
 * result in `dot[b]`.
 *
 * Each thread accumulates a strided partial sum, partial sums are combined
 * inside every warp with `__shfl_down_sync`, and lane 0 of each warp adds its
 * warp total into a shared accumulator with `atomicAdd_block`.
 *
 * The shuffle uses the full mask `0xffffffff`, so the kernel is expected to be
 * launched with `blockDim.x` a multiple of the warp size (the host launcher in
 * this file uses 64/128/256/512).
 *
 * @tparam value_idx index type of `rows`, `cols`, `nnz` and `dim`
 * @tparam value_t   element type of `A`, `B` and `dot`
 * @tparam tpb       threads per block selected by the launcher (not read here)
 *
 * @param[out] dot  output, one value per non-zero (indexed by block id)
 * @param[in]  rows row index of every non-zero, used to address `A`
 * @param[in]  cols column index of every non-zero, used to address `B`
 * @param[in]  A    vectors addressed as `A[row * dim + k]`
 * @param[in]  B    vectors addressed as `B[col * dim + k]`
 * @param[in]  nnz  number of non-zeros; blocks with an id >= nnz do nothing
 * @param[in]  dim  length of every vector
 */
template <typename value_idx, typename value_t, int tpb>
RAFT_KERNEL faster_dot_on_csr_kernel(value_t* __restrict__ dot,
                                     const value_idx* __restrict__ rows,
                                     const value_idx* __restrict__ cols,
                                     const value_t* __restrict__ A,
                                     const value_t* __restrict__ B,
                                     const value_idx nnz,
                                     const value_idx dim)
{
  const auto dot_id  = blockIdx.x;
  const auto vec_id  = threadIdx.x;
  const auto lane_id = threadIdx.x & 0x1f;

  // The condition only depends on blockIdx, so the whole block leaves together
  // and no thread is left waiting at the barriers below.
  if (static_cast<size_t>(dot_id) >= static_cast<size_t>(nnz)) { return; }

  // Element offsets are computed in size_t: with a 32-bit value_idx the product
  // index * dim overflows for large inputs (e.g. 1'000'000 vectors of 4096 dims).
  const size_t row_offset = static_cast<size_t>(rows[dot_id]) * static_cast<size_t>(dim);
  const size_t col_offset = static_cast<size_t>(cols[dot_id]) * static_cast<size_t>(dim);
  const value_t* a_vec    = A + row_offset;
  const value_t* b_vec    = B + col_offset;

  // Block-wide accumulator; zeroed by one thread before anybody adds to it.
  __shared__ value_t g_dot_;

  if (threadIdx.x == 0) { g_dot_ = 0.0; }
  __syncthreads();

  // Per-thread partial sum over the elements vec_id, vec_id + blockDim.x, ...
  value_t l_dot_ = 0.0;

#pragma unroll
  for (value_idx k = vec_id; k < dim; k += blockDim.x) {
    l_dot_ += a_vec[k] * b_vec[k];
  }

  // Intra-warp tree reduction; afterwards lane 0 holds the warp's total.
#pragma unroll
  for (int offset = warpSize / 2; offset > 0; offset /= 2) {
    l_dot_ += __shfl_down_sync(0xffffffff, l_dot_, offset);
  }

  // One atomic per warp combines the warp totals.
  if (lane_id == 0) { atomicAdd_block(&g_dot_, l_dot_); }
  __syncthreads();

  if (threadIdx.x == 0) { dot[dot_id] = g_dot_; }
}

/**
 * @brief Launch `faster_dot_on_csr_kernel` to fill `dot` with one dot product
 *        per non-zero of a sparse pattern.
 *
 * One thread block is launched per non-zero on the CUDA stream obtained from
 * `handle`. The block size is picked from `dim`:
 * 64 threads for dim < 128, 128 for dim < 256, 256 for dim < 512, otherwise 512.
 *
 * The call only enqueues work on the stream; it does not synchronize.
 *
 * NOTE(review): the grid size is `nnz` narrowed to `int`, so `nnz` is expected
 * to fit into an `int` / the maximum grid x-dimension — confirm against callers
 * that use a 64-bit `value_idx`.
 *
 * @tparam value_idx index type of `rows`, `cols`, `nnz` and `dim`
 * @tparam value_t   element type of `A`, `B` and `dot`
 *
 * @param[in]  handle raft resources providing the CUDA stream
 * @param[out] dot    device output, `nnz` values
 * @param[in]  nnz    number of non-zeros; nothing is launched when it is <= 0
 * @param[in]  rows   device array with the row index of every non-zero
 * @param[in]  cols   device array with the column index of every non-zero
 * @param[in]  A      device vectors addressed by `rows` (see the kernel)
 * @param[in]  B      device vectors addressed by `cols` (see the kernel)
 * @param[in]  dim    length of every vector
 */
template <typename value_idx, typename value_t>
void faster_dot_on_csr(raft::resources const& handle,
                       value_t* dot,
                       const value_idx nnz,
                       const value_idx* rows,
                       const value_idx* cols,
                       const value_t* A,
                       const value_t* B,
                       const value_idx dim)
{
  // A grid with zero blocks is an invalid launch configuration, so an empty
  // pattern (e.g. a filter that removes every candidate) must not launch.
  if (nnz <= 0) { return; }

  auto stream = resource::get_cuda_stream(handle);

  const int blocks = static_cast<int>(nnz);
  if (dim < 128) {
    constexpr int tpb = 64;
    faster_dot_on_csr_kernel<value_idx, value_t, tpb>
      <<<blocks, tpb, 0, stream>>>(dot, rows, cols, A, B, nnz, dim);
  } else if (dim < 256) {
    constexpr int tpb = 128;
    faster_dot_on_csr_kernel<value_idx, value_t, tpb>
      <<<blocks, tpb, 0, stream>>>(dot, rows, cols, A, B, nnz, dim);
  } else if (dim < 512) {
    constexpr int tpb = 256;
    faster_dot_on_csr_kernel<value_idx, value_t, tpb>
      <<<blocks, tpb, 0, stream>>>(dot, rows, cols, A, B, nnz, dim);
  } else {
    constexpr int tpb = 512;
    faster_dot_on_csr_kernel<value_idx, value_t, tpb>
      <<<blocks, tpb, 0, stream>>>(dot, rows, cols, A, B, nnz, dim);
  }
}
} // namespace detail
} // namespace distance
} // namespace sparse
Expand Down
6 changes: 6 additions & 0 deletions cpp/test/sparse/neighbors/prefiltered_brute_force.cu
Original file line number Diff line number Diff line change
Expand Up @@ -405,6 +405,12 @@ TEST_P(PrefilteredBruteForceTest_float_int64, Result) { Run(); }

template <typename index_t>
const std::vector<PrefilteredBruteForceInputs<index_t>> selectk_inputs = {
{1, 100000, 255, 255, 0.4, raft::distance::DistanceType::L2Expanded},
{10, 100000, 512, 16, 0.5, raft::distance::DistanceType::L2Expanded},
{20, 100000, 2052, 16, 0.2, raft::distance::DistanceType::L2Expanded},
{1, 10000, 255, 16, 0.4, raft::distance::DistanceType::InnerProduct},
{20, 10000, 512, 16, 0.5, raft::distance::DistanceType::InnerProduct},
{100, 10000, 2052, 16, 0.2, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 1, 0, 0.1, raft::distance::DistanceType::L2Expanded},
{1000, 10000, 3, 0, 0.1, raft::distance::DistanceType::InnerProduct},
{1000, 10000, 5, 0, 0.1, raft::distance::DistanceType::L2SqrtExpanded},
Expand Down
Loading