Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allowing large data in kmeans #5228

Merged
merged 15 commits into from
Feb 14, 2023
Merged
5 changes: 4 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -376,7 +376,10 @@ if(BUILD_CUML_CPP_LIBRARY)
if(all_algo OR kmeans_algo)
target_sources(${CUML_CPP_TARGET}
PRIVATE
src/kmeans/kmeans.cu)
src/kmeans/kmeans_transform.cu
src/kmeans/kmeans_fit_predict.cu
src/kmeans/kmeans_predict.cu
)
endif()

if(all_algo OR knn_algo)
Expand Down
95 changes: 54 additions & 41 deletions cpp/include/cuml/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,48 +74,27 @@ void fit_predict(const raft::handle_t& handle,
int* labels,
double& inertia,
int& n_iter);
void fit_predict(const raft::handle_t& handle,
const KMeansParams& params,
const float* X,
int64_t n_samples,
int64_t n_features,
const float* sample_weight,
float* centroids,
int64_t* labels,
float& inertia,
int64_t& n_iter);

/**
* @brief Compute k-means clustering.
*
* @param[in] handle The handle to the cuML library context that
manages the CUDA resources.
* @param[in] params Parameters for KMeans model.
* @param[in] X Training instances to cluster. It must be noted
that the data must be in row-major format and stored in device accessible
* location.
* @param[in] n_samples Number of samples in the input X.
* @param[in] n_features Number of features or the dimensions of each
* sample.
* @param[in] sample_weight The weights for each observation in X.
* @param[inout] centroids [in] When init is InitMethod::Array, use
centroids as the initial cluster centers
* [out] Otherwise, generated centroids from the
kmeans algorithm is stored at the address pointed by 'centroids'.
* @param[out] inertia Sum of squared distances of samples to their
closest cluster center.
* @param[out] n_iter Number of iterations run.
*/

void fit(const raft::handle_t& handle,
const KMeansParams& params,
const float* X,
int n_samples,
int n_features,
const float* sample_weight,
float* centroids,
float& inertia,
int& n_iter);

void fit(const raft::handle_t& handle,
const KMeansParams& params,
const double* X,
int n_samples,
int n_features,
const double* sample_weight,
double* centroids,
double& inertia,
int& n_iter);
void fit_predict(const raft::handle_t& handle,
const KMeansParams& params,
const double* X,
int64_t n_samples,
int64_t n_features,
const double* sample_weight,
double* centroids,
int64_t* labels,
double& inertia,
int64_t& n_iter);

/**
* @brief Predict the closest cluster each sample in X belongs to.
Expand Down Expand Up @@ -160,7 +139,27 @@ void predict(const raft::handle_t& handle,
bool normalize_weights,
int* labels,
double& inertia);
void predict(const raft::handle_t& handle,
const KMeansParams& params,
const float* centroids,
const float* X,
int64_t n_samples,
int64_t n_features,
const float* sample_weight,
bool normalize_weights,
int64_t* labels,
float& inertia);

void predict(const raft::handle_t& handle,
const KMeansParams& params,
const double* centroids,
const double* X,
int64_t n_samples,
int64_t n_features,
const double* sample_weight,
bool normalize_weights,
int64_t* labels,
double& inertia);
/**
* @brief Transform X to a cluster-distance space.
*
Expand Down Expand Up @@ -193,6 +192,20 @@ void transform(const raft::handle_t& handle,
int n_samples,
int n_features,
double* X_new);
void transform(const raft::handle_t& handle,
const KMeansParams& params,
const float* centroids,
const float* X,
int64_t n_samples,
int64_t n_features,
float* X_new);

void transform(const raft::handle_t& handle,
const KMeansParams& params,
const double* centroids,
const double* X,
int64_t n_samples,
int64_t n_features,
double* X_new);
}; // end namespace kmeans
}; // end namespace ML
19 changes: 19 additions & 0 deletions cpp/include/cuml/cluster/kmeans_mg.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,25 @@ void fit(const raft::handle_t& handle,
double& inertia,
int& n_iter);

void fit(const raft::handle_t& handle,
const KMeansParams& params,
const float* X,
int64_t n_samples,
int64_t n_features,
const float* sample_weight,
float* centroids,
float& inertia,
int64_t& n_iter);

void fit(const raft::handle_t& handle,
const KMeansParams& params,
const double* X,
int64_t n_samples,
int64_t n_features,
const double* sample_weight,
double* centroids,
double& inertia,
int64_t& n_iter);
}; // end namespace opg
}; // end namespace kmeans
}; // end namespace ML
203 changes: 0 additions & 203 deletions cpp/src/kmeans/kmeans.cu

This file was deleted.

Loading