Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mdspanifying (currently tested) raft::matrix #846

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
72 commits
Select commit Hold shift + click to select a range
5d697b3
Breaking apart mdspan/mdarray into host_ and device_ variants
cjnolet Sep 7, 2022
0e6cc86
Updates
cjnolet Sep 7, 2022
71922f6
Fixing style
cjnolet Sep 7, 2022
b0e5a02
Separating host_span and device_span as well
cjnolet Sep 7, 2022
d69b163
Cleanup and getting to build
cjnolet Sep 7, 2022
50d750b
Updates
cjnolet Sep 7, 2022
6fda1fd
Fixing docs
cjnolet Sep 8, 2022
dded6ed
Updating readme to use proper header paths
cjnolet Sep 8, 2022
3aeb530
More updates based on review feedback
cjnolet Sep 8, 2022
ca354f3
Mdspanifying spatial/knn functions
cjnolet Sep 14, 2022
46b0750
Getting knn test to build
cjnolet Sep 14, 2022
a52bd9c
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 14, 2022
76a469d
Merge branch 'branch-22.10' into imp-2210-host_device_mdspan
cjnolet Sep 14, 2022
b6c758c
Fixing style
cjnolet Sep 14, 2022
3838690
Merge branch 'imp-2210-host_device_mdspan' into fea-2210-mdspanified_knn
cjnolet Sep 15, 2022
105a3a6
Trying to FIND_RAFT_CPP on by default
cjnolet Sep 15, 2022
7eae6e3
Fixing bad merge
cjnolet Sep 15, 2022
019e358
Merge remote-tracking branch 'rapidsai/branch-22.10' into imp-2210-ho…
cjnolet Sep 15, 2022
08c5648
Merge branch 'imp-2210-host_device_mdspan' into fea-2210-mdspanified_knn
cjnolet Sep 15, 2022
141a2d1
Fixing knn wrapper
cjnolet Sep 19, 2022
6cd27af
mdspanidying random ball cover
cjnolet Sep 19, 2022
4870554
mdspan-ifying ivf_flat, rbc, and epsilon neighborhoods
cjnolet Sep 19, 2022
428a9e6
Fixing last compile error
cjnolet Sep 19, 2022
a612bc2
Updating ball cover specializations and API
cjnolet Sep 20, 2022
e63d121
Removing stream destroy from eps neigh tests
cjnolet Sep 20, 2022
44f7aca
Starting on col wise sort
cjnolet Sep 20, 2022
d111743
Updating docs
cjnolet Sep 20, 2022
13d23fd
Merge branch 'fea-2210-mdspanified_knn' into fea-2210-mdspanified_matrix
cjnolet Sep 20, 2022
3a14fa5
MOre gather and colwise sort
cjnolet Sep 21, 2022
f27b0cb
Breaking matrix functions out into individual files.
cjnolet Sep 21, 2022
b1e834c
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 22, 2022
32a9a34
Merge branch 'branch-22.10' into fea-2210-mdspanified_matrix
cjnolet Sep 22, 2022
9002337
Fixing style
cjnolet Sep 22, 2022
6c34351
Fixing style for tests
cjnolet Sep 22, 2022
fea8448
iUpdates
cjnolet Sep 26, 2022
bbf2ab1
Updates based on review feedback
cjnolet Sep 26, 2022
f528697
Updates based on review feedback
cjnolet Sep 26, 2022
90bbb33
Getting to nbuild
cjnolet Sep 26, 2022
8cb3ec5
Fixing style
cjnolet Sep 26, 2022
c4bd2d1
Removing files from raft;:Matrix which still need to be tested
cjnolet Sep 26, 2022
b166814
Progress on matrix API
cjnolet Sep 27, 2022
4c552bc
Adding weight back into rbc
cjnolet Sep 27, 2022
53a254e
Style check
cjnolet Sep 27, 2022
278ce4d
More updates
cjnolet Sep 27, 2022
14c4ff1
Trying to figure out gather and linewise op
cjnolet Sep 27, 2022
8db49ab
Still trying to figure out why gaqther isn't being invoked
cjnolet Sep 27, 2022
4ea623f
Style fix
cjnolet Sep 27, 2022
d5db8e3
Merge remote-tracking branch 'rapidsai/branch-22.10' into fea-2210-md…
cjnolet Sep 27, 2022
477d18a
Finisxhing up matrix
cjnolet Sep 27, 2022
f63b458
Getting tests to pass. Still need to figure out linewise_op failures
cjnolet Sep 28, 2022
7edf83a
Cleaning up, docs, tests passing. ready for review
cjnolet Sep 28, 2022
5b705d7
More docs cleanup
cjnolet Sep 28, 2022
b49c9f4
Updates
cjnolet Sep 28, 2022
5fa5bbf
Implementing review feedback
cjnolet Sep 29, 2022
ce94f63
Renaming
cjnolet Sep 29, 2022
9939464
Updating docs to include [in] and [out]
cjnolet Sep 29, 2022
32775a6
Syncing handle after argmax prim
cjnolet Sep 29, 2022
5363a2d
Removing defaults on template args
cjnolet Sep 30, 2022
b4377cb
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 30, 2022
dec2f81
Adding limit-tests to build.sh. Removing default template args per
cjnolet Sep 30, 2022
b70b44b
Updating docs
cjnolet Sep 30, 2022
e25caef
More style cleanup
cjnolet Sep 30, 2022
37ae236
Pulling out argmax for now since the test seems to be failing in centos.
cjnolet Sep 30, 2022
423034b
Merge branch 'fea-2210-mdspanified_knn' into fea-2210-mdspanified_mat…
cjnolet Sep 30, 2022
9ce7d4f
Lots of updates from review feedback.
cjnolet Sep 30, 2022
5313cb9
Merge branch 'branch-22.10' into fea-2210-mdspanified_knn
cjnolet Sep 30, 2022
c024518
Merge branch 'fea-2210-mdspanified_knn' into fea-2210-mdspanified_mat…
cjnolet Sep 30, 2022
3db8a27
Moving re-defined validation logic out of mdspan.hpp
cjnolet Sep 30, 2022
4c06e91
Review feedback
cjnolet Oct 2, 2022
946bfcb
More review feedbck
cjnolet Oct 2, 2022
c7245f6
Merge branch 'branch-22.10' into fea-2210-mdspanified_matrix_tested
cjnolet Oct 2, 2022
e0e5eae
Fixing styl
cjnolet Oct 2, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
Adding limit-tests to build.sh. Removing default template args per
reviews
  • Loading branch information
cjnolet committed Sep 30, 2022
commit dec2f813d8a55d22393f399128a2755f6f10575b
25 changes: 22 additions & 3 deletions build.sh
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ ARGS=$*
REPODIR=$(cd $(dirname $0); pwd)

VALIDARGS="clean libraft pylibraft raft-dask docs tests bench clean -v -g --install --compile-libs --compile-nn --compile-dist --allgpuarch --no-nvtx --show_depr_warn -h --buildfaiss --minimal-deps"
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>]
HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<tool>] [--limit-tests=<targets>]
where <target> is:
clean - remove all existing build artifacts and configuration (start over)
libraft - build the raft C++ code only. Also builds the C-wrapper library
Expand All @@ -40,6 +40,7 @@ HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<to
the only option to be supported)
--minimal-deps - disables dependencies like thrust so they can be overridden.
can be useful for a pure header-only install
--limit-tests - semicolon-separated list of test executables to compile (e.g. TEST_SPATIAL;TEST_CLUSTER)
--allgpuarch - build for all supported GPU architectures
--buildfaiss - build faiss statically into raft
--install - install cmake targets
Expand All @@ -50,7 +51,7 @@ HELP="$0 [<target> ...] [<flag> ...] [--cmake-args=\"<args>\"] [--cache-tool=<to
to speedup the build process.
-h - print this text

default action (no args) is to build both libraft and raft-dask targets
default action (no args) is to build libraft, tests, pylibraft and raft-dask targets
"
LIBRAFT_BUILD_DIR=${LIBRAFT_BUILD_DIR:=${REPODIR}/cpp/build}
SPHINX_BUILD_DIR=${REPODIR}/docs
Expand All @@ -70,6 +71,8 @@ COMPILE_NN_LIBRARY=OFF
COMPILE_DIST_LIBRARY=OFF
ENABLE_NN_DEPENDENCIES=OFF

TEST_TARGETS="CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NN_TEST;SPATIAL_TEST;STATS_TEST;UTILS_TEST"

ENABLE_thrust_DEPENDENCY=ON

CACHE_ARGS=""
Expand Down Expand Up @@ -136,6 +139,21 @@ function cacheTool {
fi
}

function limitTests {
# Check for option to limit the set of test binaries to build
if [[ -n $(echo $ARGS | { grep -E "\-\-limit\-tests" || true; } ) ]]; then
# There are possible weird edge cases that may cause this regex filter to output nothing and fail silently
# the true pipe will catch any weird edge cases that may happen and will cause the program to fall back
# on the invalid option error
LIMIT_TEST_TARGETS=$(echo $ARGS | sed -e 's/.*--limit-tests=//' -e 's/ .*//')
if [[ -n ${LIMIT_TEST_TARGETS} ]]; then
# Remove the full LIMIT_TEST_TARGETS argument from list of args so that it passes validArgs function
ARGS=${ARGS//--limit-tests=$LIMIT_TEST_TARGETS/}
TEST_TARGETS=${LIMIT_TEST_TARGETS}
fi
fi
}

if hasArg -h || hasArg --help; then
echo "${HELP}"
exit 0
Expand All @@ -145,6 +163,7 @@ fi
if (( ${NUMARGS} != 0 )); then
cmakeArgs
cacheTool
limitTests
for a in ${ARGS}; do
if ! (echo " ${VALIDARGS} " | grep -q " ${a} "); then
echo "Invalid option: ${a}"
Expand Down Expand Up @@ -194,7 +213,7 @@ if hasArg tests || (( ${NUMARGS} == 0 )); then
COMPILE_DIST_LIBRARY=ON
ENABLE_NN_DEPENDENCIES=ON
COMPILE_NN_LIBRARY=ON
CMAKE_TARGET="${CMAKE_TARGET};CLUSTER_TEST;CORE_TEST;DISTANCE_TEST;LABEL_TEST;LINALG_TEST;MATRIX_TEST;RANDOM_TEST;SOLVERS_TEST;SPARSE_TEST;SPARSE_DIST_TEST;SPARSE_NN_TEST;SPATIAL_TEST;STATS_TEST;UTILS_TEST"
CMAKE_TARGET="${CMAKE_TARGET};${TEST_TARGETS}"
fi

if hasArg bench || (( ${NUMARGS} == 0 )); then
Expand Down
10 changes: 2 additions & 8 deletions cpp/include/raft/spatial/knn/ball_cover.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,7 @@ namespace knn {
* @param[in] handle library resource management handle
* @param[inout] index an empty (and not previous built) instance of BallCoverIndex
*/
template <typename idx_t,
typename value_t,
typename int_t,
typename matrix_idx_t>
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_build_index(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index)
{
Expand Down Expand Up @@ -87,10 +84,7 @@ void rbc_build_index(const raft::handle_t& handle,
* many datasets can still have great recall even by only
* looking in the closest landmark.
*/
template <typename idx_t,
typename value_t,
typename int_t,
typename matrix_idx_t>
template <typename idx_t, typename value_t, typename int_t, typename matrix_idx_t>
void rbc_all_knn_query(const raft::handle_t& handle,
BallCoverIndex<idx_t, value_t, int_t, matrix_idx_t>& index,
int_t k,
Expand Down
51 changes: 19 additions & 32 deletions cpp/include/raft/spatial/knn/ivf_flat.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -100,19 +100,16 @@ inline auto build(
*
* @return the constructed ivf-flat index
*/
template <typename value_t,
typename idx_t,
typename int_t,
typename matrix_idx_t>
template <typename value_t, typename idx_t>
auto build_index(const handle_t& handle,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> dataset,
raft::device_matrix_view<const value_t, idx_t, row_major> dataset,
const index_params& params) -> index<value_t, idx_t>
{
return raft::spatial::knn::ivf_flat::detail::build(handle,
params,
dataset.data_handle(),
static_cast<idx_t>(dataset.extent(0)),
static_cast<int_t>(dataset.extent(1)));
static_cast<idx_t>(dataset.extent(1)));
}

/**
Expand Down Expand Up @@ -191,15 +188,12 @@ inline auto extend(const handle_t& handle,
*
* @return the constructed extended ivf-flat index
*/
template <typename value_t,
typename idx_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
template <typename value_t, typename idx_t>
auto extend(const handle_t& handle,
const index<value_t, idx_t>& orig_index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, matrix_idx_t>> new_indices =
std::nullopt) -> index<value_t, idx_t>
raft::device_matrix_view<const value_t, idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices = std::nullopt)
-> index<value_t, idx_t>
{
return raft::spatial::knn::ivf_flat::detail::extend<value_t, idx_t>(
handle,
Expand Down Expand Up @@ -248,15 +242,11 @@ inline void extend(const handle_t& handle,
* If the original index is empty (`orig_index.size() == 0`), you can pass `std::nullopt`
* here to imply a continuous range `[0...n_rows)`.
*/
template <typename value_t,
typename idx_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
void extend(
const handle_t& handle,
index<value_t, idx_t>* index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, matrix_idx_t>> new_indices = std::nullopt)
template <typename value_t, typename idx_t>
void extend(const handle_t& handle,
index<value_t, idx_t>* index,
raft::device_matrix_view<const value_t, idx_t, row_major> new_vectors,
std::optional<raft::device_vector_view<const idx_t, idx_t>> new_indices = std::nullopt)
{
*index = extend(handle,
*index,
Expand Down Expand Up @@ -363,25 +353,22 @@ inline void search(const handle_t& handle,
* @param[in] params configure the search
* @param[in] k the number of neighbors to find for each query.
*/
template <typename value_t,
typename idx_t,
typename int_t = std::uint32_t,
typename matrix_idx_t = std::uint32_t>
template <typename value_t, typename idx_t, typename int_t>
void search(const handle_t& handle,
const index<value_t, idx_t>& index,
raft::device_matrix_view<const value_t, matrix_idx_t, row_major> queries,
raft::device_matrix_view<idx_t, matrix_idx_t, row_major> neighbors,
raft::device_matrix_view<idx_t, matrix_idx_t, float> distances,
raft::device_matrix_view<const value_t, idx_t, row_major> queries,
raft::device_matrix_view<idx_t, idx_t, row_major> neighbors,
raft::device_matrix_view<idx_t, idx_t, float> distances,
const search_params& params,
int_t k)
{
RAFT_EXPECTS(
queries.extent(0) == neighbors.extent(0) && queries.extent(0) == distances.extent(0),
"Number of rows in output neighbors and distances matrices must equal the number of queries.");

RAFT_EXPECTS(neighbors.extent(1) == distances.extent(1) &&
neighbors.extent(1) == static_cast<matrix_idx_t>(k),
"Number of columns in output neighbors and distances matrices must equal k");
RAFT_EXPECTS(
neighbors.extent(1) == distances.extent(1) && neighbors.extent(1) == static_cast<idx_t>(k),
"Number of columns in output neighbors and distances matrices must equal k");

RAFT_EXPECTS(queries.extent(1) == index.dim(),
"Number of query dimensions should equal number of dimensions in the index.");
Expand Down
18 changes: 9 additions & 9 deletions cpp/test/spatial/ann_ivf_flat.cu
Original file line number Diff line number Diff line change
Expand Up @@ -208,10 +208,10 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs> {
index_params.add_data_on_build = false;
index_params.kmeans_trainset_fraction = 0.5;

auto database_view = raft::make_device_matrix_view<const DataT>(
auto database_view = raft::make_device_matrix_view<const DataT, IdxT>(
(const DataT*)database.data(), ps.num_db_vecs, ps.dim);

auto index = ivf_flat::build_index<DataT, IdxT>(handle_, database_view, index_params);
auto index = ivf_flat::build_index(handle_, database_view, index_params);

rmm::device_uvector<IdxT> vector_indices(ps.num_db_vecs, stream_);
thrust::sequence(handle_.get_thrust_policy(),
Expand All @@ -221,16 +221,16 @@ class AnnIVFFlatTest : public ::testing::TestWithParam<AnnIvfFlatInputs> {

int64_t half_of_data = ps.num_db_vecs / 2;

auto half_of_data_view = raft::make_device_matrix_view<const DataT, int>(
auto half_of_data_view = raft::make_device_matrix_view<const DataT, IdxT>(
(const DataT*)database.data(), half_of_data, ps.dim);

auto index_2 = ivf_flat::extend<DataT, IdxT>(handle_, index, half_of_data_view);
auto index_2 = ivf_flat::extend(handle_, index, half_of_data_view);

ivf_flat::extend<DataT, IdxT>(handle_,
&index_2,
database.data() + half_of_data * ps.dim,
vector_indices.data() + half_of_data,
int64_t(ps.num_db_vecs) - half_of_data);
ivf_flat::extend(handle_,
&index_2,
database.data() + half_of_data * ps.dim,
vector_indices.data() + half_of_data,
int64_t(ps.num_db_vecs) - half_of_data);

ivf_flat::search(handle_,
search_params,
Expand Down