@@ -455,6 +455,74 @@ void fit(const raft::resources& handle,
455455 raft::device_matrix_view<const int8_t , int64_t > X,
456456 raft::device_matrix_view<float , int64_t > centroids);
457457
458+ /* *
459+ * @brief Find balanced clusters with k-means algorithm.
460+ *
461+ * @code{.cpp}
462+ * #include <raft/core/resources.hpp>
463+ * #include <cuvs/cluster/kmeans.hpp>
464+ * using namespace cuvs::cluster;
465+ * ...
466+ * raft::resources handle;
467+ * cuvs::cluster::kmeans::balanced_params params;
468+ * int64_t n_features = 15, n_clusters = 8;
469+ * auto centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);
470+ *
471+ * kmeans::fit(handle,
472+ * params,
473+ * X,
474+ * centroids);
475+ * @endcode
476+ *
477+ * @param[in] handle The raft handle.
478+ * @param[in] params Parameters for KMeans model.
479+ * @param[in] X Training instances to cluster. The data must
480+ * be in row-major format.
481+ * [dim = n_samples x n_features]
482+ * @param[inout] centroids [out] The generated centroids from the
483+ * kmeans algorithm are stored at the address
484+ * pointed by 'centroids'.
485+ * [dim = n_clusters x n_features]
486+ */
487+ void fit (const raft::resources& handle,
488+ cuvs::cluster::kmeans::balanced_params const & params,
489+ raft::device_matrix_view<const half, int64_t > X,
490+ raft::device_matrix_view<float , int64_t > centroids);
491+
492+ /* *
493+ * @brief Find balanced clusters with k-means algorithm.
494+ *
495+ * @code{.cpp}
496+ * #include <raft/core/resources.hpp>
497+ * #include <cuvs/cluster/kmeans.hpp>
498+ * using namespace cuvs::cluster;
499+ * ...
500+ * raft::resources handle;
501+ * cuvs::cluster::kmeans::balanced_params params;
502+ * int64_t n_features = 15, n_clusters = 8;
503+ * auto centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);
504+ *
505+ * kmeans::fit(handle,
506+ * params,
507+ * X,
508+ * centroids);
509+ * @endcode
510+ *
511+ * @param[in] handle The raft handle.
512+ * @param[in] params Parameters for KMeans model.
513+ * @param[in] X Training instances to cluster. The data must
514+ * be in row-major format.
515+ * [dim = n_samples x n_features]
516+ * @param[inout] centroids [out] The generated centroids from the
517+ * kmeans algorithm are stored at the address
518+ * pointed by 'centroids'.
519+ * [dim = n_clusters x n_features]
520+ */
521+ void fit (const raft::resources& handle,
522+ cuvs::cluster::kmeans::balanced_params const & params,
523+ raft::device_matrix_view<const uint8_t , int64_t > X,
524+ raft::device_matrix_view<float , int64_t > centroids);
525+
458526/* *
459527 * @brief Predict the closest cluster each sample in X belongs to.
460528 *
@@ -819,6 +887,138 @@ void predict(const raft::resources& handle,
819887 raft::device_matrix_view<const float , int64_t > centroids,
820888 raft::device_vector_view<int , int64_t > labels);
821889
890+ /* *
891+ * @brief Predict the closest cluster each sample in X belongs to.
892+ *
893+ * @code{.cpp}
894+ * #include <raft/core/resources.hpp>
895+ * #include <cuvs/cluster/kmeans.hpp>
896+ * using namespace cuvs::cluster;
897+ * ...
898+ * raft::resources handle;
899+ * cuvs::cluster::kmeans::balanced_params params;
900+ * int64_t n_features = 15, n_clusters = 8;
901+ * auto centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);
902+ *
903+ * kmeans::fit(handle,
904+ * params,
905+ * X,
906+ * centroids.view());
907+ * ...
908+ * auto labels = raft::make_device_vector<uint32_t, int64_t>(handle, X.extent(0));
909+ *
910+ * kmeans::predict(handle,
911+ * params,
912+ * X,
913+ * centroids.view(),
914+ * labels.view());
915+ * @endcode
916+ *
917+ * @param[in] handle The raft handle.
918+ * @param[in] params Parameters for KMeans model.
919+ * @param[in] X New data to predict.
920+ * [dim = n_samples x n_features]
921+ * @param[in] centroids Cluster centroids. The data must be in
922+ * row-major format.
923+ * [dim = n_clusters x n_features]
924+ * @param[out] labels Index of the cluster each sample in X
925+ * belongs to.
926+ * [len = n_samples]
927+ */
928+ void predict (const raft::resources& handle,
929+ cuvs::cluster::kmeans::balanced_params const & params,
930+ raft::device_matrix_view<const float , int64_t > X,
931+ raft::device_matrix_view<const float , int64_t > centroids,
932+ raft::device_vector_view<uint32_t , int64_t > labels);
933+
934+ /* *
935+ * @brief Predict the closest cluster each sample in X belongs to.
936+ *
937+ * @code{.cpp}
938+ * #include <raft/core/resources.hpp>
939+ * #include <cuvs/cluster/kmeans.hpp>
940+ * using namespace cuvs::cluster;
941+ * ...
942+ * raft::resources handle;
943+ * cuvs::cluster::kmeans::balanced_params params;
944+ * int64_t n_features = 15, n_clusters = 8;
945+ * auto centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);
946+ *
947+ * kmeans::fit(handle,
948+ * params,
949+ * X,
950+ * centroids.view());
951+ * ...
952+ * auto labels = raft::make_device_vector<uint32_t, int64_t>(handle, X.extent(0));
953+ *
954+ * kmeans::predict(handle,
955+ * params,
956+ * X,
957+ * centroids.view(),
958+ * labels.view());
959+ * @endcode
960+ *
961+ * @param[in] handle The raft handle.
962+ * @param[in] params Parameters for KMeans model.
963+ * @param[in] X New data to predict.
964+ * [dim = n_samples x n_features]
965+ * @param[in] centroids Cluster centroids. The data must be in
966+ * row-major format.
967+ * [dim = n_clusters x n_features]
968+ * @param[out] labels Index of the cluster each sample in X
969+ * belongs to.
970+ * [len = n_samples]
971+ */
972+ void predict (const raft::resources& handle,
973+ cuvs::cluster::kmeans::balanced_params const & params,
974+ raft::device_matrix_view<const half, int64_t > X,
975+ raft::device_matrix_view<const float , int64_t > centroids,
976+ raft::device_vector_view<uint32_t , int64_t > labels);
977+
978+ /* *
979+ * @brief Predict the closest cluster each sample in X belongs to.
980+ *
981+ * @code{.cpp}
982+ * #include <raft/core/resources.hpp>
983+ * #include <cuvs/cluster/kmeans.hpp>
984+ * using namespace cuvs::cluster;
985+ * ...
986+ * raft::resources handle;
987+ * cuvs::cluster::kmeans::balanced_params params;
988+ * int64_t n_features = 15, n_clusters = 8;
989+ * auto centroids = raft::make_device_matrix<float, int64_t>(handle, n_clusters, n_features);
990+ *
991+ * kmeans::fit(handle,
992+ * params,
993+ * X,
994+ * centroids.view());
995+ * ...
996+ * auto labels = raft::make_device_vector<uint32_t, int64_t>(handle, X.extent(0));
997+ *
998+ * kmeans::predict(handle,
999+ * params,
1000+ * X,
1001+ * centroids.view(),
1002+ * labels.view());
1003+ * @endcode
1004+ *
1005+ * @param[in] handle The raft handle.
1006+ * @param[in] params Parameters for KMeans model.
1007+ * @param[in] X New data to predict.
1008+ * [dim = n_samples x n_features]
1009+ * @param[in] centroids Cluster centroids. The data must be in
1010+ * row-major format.
1011+ * [dim = n_clusters x n_features]
1012+ * @param[out] labels Index of the cluster each sample in X
1013+ * belongs to.
1014+ * [len = n_samples]
1015+ */
1016+ void predict (const raft::resources& handle,
1017+ cuvs::cluster::kmeans::balanced_params const & params,
1018+ raft::device_matrix_view<const uint8_t , int64_t > X,
1019+ raft::device_matrix_view<const float , int64_t > centroids,
1020+ raft::device_vector_view<uint32_t , int64_t > labels);
1021+
8221022/* *
8231023 * @brief Compute k-means clustering and predicts cluster index for each sample
8241024 * in the input.
0 commit comments