Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Resolve auto-merger conflicts between branch-23.04 & branch-23.06 #5340

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/pr.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,10 +102,10 @@ jobs:
build_type: pull-request
package-name: cuml
# Always want to test against latest dask/distributed.
test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
# On arm also need to install cupy from the specific webpage and CMake
# because treelite needs to be compiled (no wheels available for arm).
test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64 && pip install cmake && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64 && pip install cmake && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
# skipped test context: https://github.com/rapidsai/cuml/issues/5025
# parallelization is based on current test memory usage
test-unittest: "pytest -v ./python/cuml/tests -k 'not test_silhouette_score_batched and not test_sparse_pca_inputs' -n 8 --ignore=python/cuml/tests/dask && pytest -v ./python/cuml/tests -k 'test_sparse_pca_inputs' && pytest -v ./python/cuml/tests/dask"
Expand Down
4 changes: 2 additions & 2 deletions .github/workflows/test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,10 @@ jobs:
date: ${{ inputs.date }}
sha: ${{ inputs.sha }}
package-name: cuml
test-before-amd64: "pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
test-before-amd64: "pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
# On arm also need to install cupy from the specific webpage and CMake
# because treelite needs to be compiled (no wheels available for arm).
test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64 && pip install cmake && pip install git+https://github.com/dask/dask.git@main git+https://github.com/dask/distributed.git@main git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
test-before-arm64: "pip install 'cupy-cuda11x<12.0.0' -f https://pip.cupy.dev/aarch64 && pip install cmake && pip install git+https://github.com/dask/dask.git@2023.3.2 git+https://github.com/dask/distributed.git@2023.3.2.1 git+https://github.com/rapidsai/dask-cuda.git@branch-23.06"
# skipped test context: https://github.com/rapidsai/cuml/issues/5025
# parallelization is based on current test memory usage
test-unittest: "pytest -v ./python/cuml/tests -k 'not test_silhouette_score_batched and not test_sparse_pca_inputs' -n 8 --ignore=python/cuml/tests/dask && pytest -v ./python/cuml/tests -k 'test_sparse_pca_inputs' && pytest -v ./python/cuml/tests/dask"
17 changes: 16 additions & 1 deletion ci/test_python_singlegpu.sh
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ EXITCODE=0
trap "EXITCODE=1" ERR
set +e

rapids-logger "pytest cuml single GPU"
rapids-logger "pytest cuml single GPU..."
cd python/cuml/tests
pytest \
--numprocesses=8 \
Expand All @@ -19,6 +19,21 @@ pytest \
--cov=cuml \
--cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cuml-coverage.xml" \
--cov-report=term \
-m "not memleak" \
.

rapids-logger "memory leak pytests..."

pytest \
--numprocesses=1 \
--ignore=dask \
--cache-clear \
--junitxml="${RAPIDS_TESTS_DIR}/junit-cuml-memleak.xml" \
--cov-config=../../.coveragerc \
--cov=cuml \
--cov-report=xml:"${RAPIDS_COVERAGE_DIR}/cuml-memleak-coverage.xml" \
--cov-report=term \
-m "memleak" \
.

rapids-logger "Test script exiting with value: $EXITCODE"
Expand Down
5 changes: 3 additions & 2 deletions conda/environments/all_cuda-118_arch-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@ dependencies:
- cupy>=9.5.0,<12.0.0a0
- cxx-compiler
- cython>=0.29,<0.30
- dask-core==2023.3.2
- dask-cuda==23.6.*
- dask-cudf==23.6.*
- dask-ml
- dask>=2023.1.1
- distributed>=2023.1.1
- dask==2023.3.2
- distributed==2023.3.2.1
- doxygen=1.8.20
- faiss-proc=*=cuda
- gcc_linux-64=11.*
Expand Down
5 changes: 3 additions & 2 deletions conda/recipes/cuml/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -70,8 +70,9 @@ requirements:
- cudf ={{ minor_version }}
- cupy >=7.8.0,<12.0.0a0
- dask-cudf ={{ minor_version }}
- dask >=2023.1.1
- distributed >=2023.1.1
- dask ==2023.3.2
- dask-core==2023.3.2
- distributed ==2023.3.2.1
- joblib >=0.11
- libcuml ={{ version }}
- libcumlprims ={{ minor_version }}
Expand Down
4 changes: 2 additions & 2 deletions cpp/include/cuml/cluster/hdbscan.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -459,7 +459,7 @@ void compute_all_points_membership_vectors(
const float* X,
raft::distance::DistanceType metric,
float* membership_vec,
int batch_size);
size_t batch_size = 4096);

void compute_membership_vector(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
Expand All @@ -470,7 +470,7 @@ void compute_membership_vector(const raft::handle_t& handle,
int min_samples,
raft::distance::DistanceType metric,
float* membership_vec,
int batch_size);
size_t batch_size = 4096);

void out_of_sample_predict(const raft::handle_t& handle,
HDBSCAN::Common::CondensedHierarchy<int, float>& condensed_tree,
Expand Down
25 changes: 12 additions & 13 deletions cpp/src/hdbscan/detail/soft_clustering.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ void dist_membership_vector(const raft::handle_t& handle,
value_idx* exemplar_label_offsets,
value_t* dist_membership_vec,
raft::distance::DistanceType metric,
int batch_size,
size_t batch_size,
bool softmax = false)
{
auto stream = handle.get_stream();
Expand All @@ -82,16 +82,11 @@ void dist_membership_vector(const raft::handle_t& handle,
// compute the number of batches based on the batch size
value_idx n_batches;

if (batch_size == 0) {
n_batches = 1;
batch_size = n_queries;
}
else {
n_batches = raft::ceildiv((int)n_queries, (int)batch_size);
}
n_batches = raft::ceildiv((int)n_queries, (int)batch_size);

for(value_idx bid = 0; bid < n_batches; bid++) {
value_idx batch_offset = bid * batch_size;
value_idx samples_per_batch = min(batch_size, (int)n_queries - batch_offset);
value_idx samples_per_batch = min((value_idx)batch_size, (value_idx)n_queries - batch_offset);
rmm::device_uvector<value_t> dist(samples_per_batch * n_exemplars, stream);

// compute the distances using raft API
Expand Down Expand Up @@ -392,14 +387,16 @@ void all_points_membership_vectors(const raft::handle_t& handle,
const value_t* X,
raft::distance::DistanceType metric,
value_t* membership_vec,
value_idx batch_size)
size_t batch_size)
{
auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();

size_t m = prediction_data.n_rows;
size_t n = prediction_data.n_cols;
RAFT_EXPECTS(0 <= batch_size && batch_size <= m, "Invalid batch_size. batch_size should be >= 0 and <= the number of samples in the training data");

if (batch_size > m) batch_size = m;
RAFT_EXPECTS(0 < batch_size && batch_size <= m, "Invalid batch_size. batch_size should be > 0 and <= the number of samples in the training data");

auto parents = condensed_tree.get_parents();
auto children = condensed_tree.get_children();
Expand Down Expand Up @@ -507,11 +504,10 @@ void membership_vector(const raft::handle_t& handle,
raft::distance::DistanceType metric,
int min_samples,
value_t* membership_vec,
value_idx batch_size)
size_t batch_size)
{
RAFT_EXPECTS(metric == raft::distance::DistanceType::L2SqrtExpanded,
"Currently only L2 expanded distance is supported");
RAFT_EXPECTS(0 <= batch_size && batch_size <= n_prediction_points, "Invalid batch_size. batch_size should be >= 0 and <= the number of points to predict");

auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();
Expand All @@ -525,6 +521,9 @@ void membership_vector(const raft::handle_t& handle,
value_idx n_exemplars = prediction_data.get_n_exemplars();
value_t* lambdas = condensed_tree.get_lambdas();

if (batch_size > n_prediction_points) batch_size = n_prediction_points;
RAFT_EXPECTS(0 < batch_size && batch_size <= n_prediction_points, "Invalid batch_size. batch_size should be > 0 and <= the number of samples in the training data");

rmm::device_uvector<value_t> dist_membership_vec(n_prediction_points * n_selected_clusters,
stream);

Expand Down
4 changes: 2 additions & 2 deletions cpp/src/hdbscan/hdbscan.cu
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ void compute_all_points_membership_vectors(
const float* X,
raft::distance::DistanceType metric,
float* membership_vec,
int batch_size)
size_t batch_size)
{
HDBSCAN::detail::Predict::all_points_membership_vectors(
handle, condensed_tree, prediction_data, X, metric, membership_vec, batch_size);
Expand All @@ -108,7 +108,7 @@ void compute_membership_vector(const raft::handle_t& handle,
int min_samples,
raft::distance::DistanceType metric,
float* membership_vec,
int batch_size)
size_t batch_size)
{
// Note that (min_samples+1) is parsed to the approximate_predict function. This was done for the
// core distance computation to consistent with Scikit learn Contrib.
Expand Down
9 changes: 2 additions & 7 deletions cpp/test/sg/hdbscan_test.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,7 @@
#include <cuml/cluster/hdbscan.hpp>
#include <hdbscan/detail/condense.cuh>
#include <hdbscan/detail/extract.cuh>
#include <hdbscan/detail/predict.cuh>
#include <hdbscan/detail/reachability.cuh>
#include <hdbscan/detail/soft_clustering.cuh>
#include <hdbscan/detail/utils.h>

#include <raft/spatial/knn/specializations.cuh>
#include <raft/stats/adjusted_rand_index.cuh>
Expand Down Expand Up @@ -460,8 +457,7 @@ class AllPointsMembershipVectorsTest
prediction_data_,
data.data(),
raft::distance::DistanceType::L2SqrtExpanded,
membership_vec.data(),
0);
membership_vec.data());

ASSERT_TRUE(MLCommon::devArrMatch(membership_vec.data(),
params.expected_probabilities.data(),
Expand Down Expand Up @@ -755,8 +751,7 @@ class MembershipVectorTest : public ::testing::TestWithParam<MembershipVectorInp
params.n_points_to_predict,
params.min_samples,
raft::distance::DistanceType::L2SqrtExpanded,
membership_vec.data(),
0);
membership_vec.data());

ASSERT_TRUE(MLCommon::devArrMatch(membership_vec.data(),
params.expected_probabilities.data(),
Expand Down
7 changes: 5 additions & 2 deletions dependencies.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -130,10 +130,10 @@ dependencies:
- output_types: [conda, requirements, pyproject]
packages:
- cudf==23.6.*
- dask>=2023.1.1
- dask==2023.3.2
- dask-cuda==23.6.*
- dask-cudf==23.6.*
- distributed>=2023.1.1
- distributed==2023.3.2.1
- joblib>=0.11
- numba
# TODO: Are seaborn and scipy really hard dependencies, or should
Expand All @@ -146,6 +146,9 @@ dependencies:
- output_types: [conda, requirements]
packages:
- cupy>=9.5.0,<12.0.0a0
- output_types: conda
packages:
- dask-core==2023.3.2
- output_types: pyproject
packages:
- &cupy_pip cupy-cuda11x>=9.5.0,<12.0.0a0
Expand Down
21 changes: 13 additions & 8 deletions python/cuml/cluster/hdbscan/prediction.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML":
float* X,
DistanceType metric,
float* membership_vec,
int batch_size)
size_t batch_size)

void compute_membership_vector(
const handle_t& handle,
Expand All @@ -107,7 +107,7 @@ cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML":
int min_samples,
DistanceType metric,
float* membership_vec,
int batch_size);
size_t batch_size);

void out_of_sample_predict(const handle_t &handle,
CondensedHierarchy[int, float] &condensed_tree,
Expand All @@ -131,7 +131,7 @@ _metrics_mapping = {
}


def all_points_membership_vectors(clusterer, batch_size=0):
def all_points_membership_vectors(clusterer, batch_size=4096):

"""
Predict soft cluster membership vectors for all points in the
Expand All @@ -145,11 +145,13 @@ def all_points_membership_vectors(clusterer, batch_size=0):
A clustering object that has been fit to the data and
had ``prediction_data=True`` set.

batch_size : int, optional, default=0
batch_size : int, optional, default=min(4096, n_rows)
Lowers memory requirement by computing distance-based membership in
smaller batches of points in the training data. Batch size of 0 uses
all of the training points, batch size of 1000 computes distances for
1000 points at a time.
1000 points at a time. The default batch_size is 4096. If the number
of rows in the original dataset is less than 4096, this defaults to
the number of rows.

Returns
-------
Expand Down Expand Up @@ -214,6 +216,7 @@ def all_points_membership_vectors(clusterer, batch_size=0):
<CondensedHierarchy[int, float]*><size_t> clusterer.condensed_tree_ptr

cdef handle_t* handle_ = <handle_t*><size_t>clusterer.handle.getHandle()

compute_all_points_membership_vectors(handle_[0],
deref(condensed_tree),
deref(prediction_data_),
Expand All @@ -229,7 +232,7 @@ def all_points_membership_vectors(clusterer, batch_size=0):
clusterer.n_clusters_))


def membership_vector(clusterer, points_to_predict, batch_size=0, convert_dtype=True):
def membership_vector(clusterer, points_to_predict, batch_size=4096, convert_dtype=True):
"""Predict soft cluster membership. The result produces a vector
for each point in ``points_to_predict`` that gives a probability that
the given point is a member of a cluster for each of the selected clusters
Expand All @@ -247,11 +250,13 @@ def membership_vector(clusterer, points_to_predict, batch_size=0, convert_dtype=
have the same dimensionality as the original dataset over which
clusterer was fit.

batch_size : int, optional, default=0
batch_size : int, optional, default=min(4096, n_points_to_predict)
Lowers memory requirement by computing distance-based membership in
smaller batches of points in the training data. Batch size of 0 uses
all of the training points, batch size of 1000 computes distances for
1000 points at a time.
1000 points at a time. The default batch_size is 4096. If the number
of rows in the original dataset is less than 4096, this defaults to
the number of rows.

Returns
-------
Expand Down
12 changes: 6 additions & 6 deletions python/cuml/tests/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -531,7 +531,7 @@ def test_hdbscan_plots():
@pytest.mark.parametrize("cluster_selection_epsilon", [0.0, 0.5])
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("batch_size", [0, 128])
@pytest.mark.parametrize("batch_size", [128, 1000])
def test_all_points_membership_vectors_blobs(
nrows,
ncols,
Expand Down Expand Up @@ -593,7 +593,7 @@ def test_all_points_membership_vectors_blobs(
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("connectivity", ["knn"])
@pytest.mark.parametrize("batch_size", [0, 128])
@pytest.mark.parametrize("batch_size", [128, 1000])
def test_all_points_membership_vectors_moons(
nrows,
min_samples,
Expand Down Expand Up @@ -650,7 +650,7 @@ def test_all_points_membership_vectors_moons(
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("connectivity", ["knn"])
@pytest.mark.parametrize("batch_size", [0, 128])
@pytest.mark.parametrize("batch_size", [128, 1000])
def test_all_points_membership_vectors_circles(
nrows,
min_samples,
Expand Down Expand Up @@ -981,7 +981,7 @@ def test_approximate_predict_digits(
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("allow_single_cluster", [True, False])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("batch_size", [0, 128])
@pytest.mark.parametrize("batch_size", [128])
def test_membership_vector_blobs(
nrows,
n_points_to_predict,
Expand Down Expand Up @@ -1057,7 +1057,7 @@ def test_membership_vector_blobs(
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("connectivity", ["knn"])
@pytest.mark.parametrize("batch_size", [0, 16])
@pytest.mark.parametrize("batch_size", [16])
def test_membership_vector_moons(
nrows,
n_points_to_predict,
Expand Down Expand Up @@ -1121,7 +1121,7 @@ def test_membership_vector_moons(
@pytest.mark.parametrize("max_cluster_size", [0])
@pytest.mark.parametrize("cluster_selection_method", ["eom", "leaf"])
@pytest.mark.parametrize("connectivity", ["knn"])
@pytest.mark.parametrize("batch_size", [0, 16])
@pytest.mark.parametrize("batch_size", [16])
def test_membership_vector_circles(
nrows,
n_points_to_predict,
Expand Down
2 changes: 1 addition & 1 deletion python/cuml/tests/test_mbsgd_regressor.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def test_mbsgd_regressor_vs_skl(lrate, penalty, make_dataset):
skl_sgd_regressor.fit(cp.asnumpy(X_train), cp.asnumpy(y_train).ravel())
skl_pred = skl_sgd_regressor.predict(cp.asnumpy(X_test))
skl_r2 = r2_score(skl_pred, cp.asnumpy(y_test), convert_dtype=datatype)
assert abs(cu_r2 - skl_r2) <= 0.02
assert abs(cu_r2 - skl_r2) <= 0.021


@pytest.mark.parametrize(
Expand Down
Loading