K-Means

#include <raft/cluster/kmeans_balanced.cuh>

template<typename DataT, typename MathT, typename IndexT, typename MappingOpT = raft::identity_op>
void raft::cluster::kmeans_balanced::fit(const raft::resources &handle, kmeans_balanced_params const &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<MathT, IndexT> centroids, MappingOpT mapping_op = raft::identity_op())

Find clusters of balanced sizes with a hierarchical k-means algorithm.

This variant of the k-means algorithm first clusters the dataset into mesoclusters, then clusters the subset of points associated with each mesocluster into fine clusters, and finally runs a few k-means iterations over the whole dataset with all the centroids to obtain the final clusters.

Each k-means iteration applies expectation-maximization-balancing:

  • Balancing: adjust the centers of clusters that have few members. If the size of a cluster is below a threshold, its center is moved towards a bigger cluster.

  • Expectation: predict the labels (i.e., find the closest cluster centroid to each point).

  • Maximization: compute the optimal centroids (i.e., find the center of gravity of each cluster).
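The three steps above can be sketched on the host as follows. This is a minimal illustration, not RAFT's GPU implementation: it assumes squared L2 distance, row-major `std::vector<float>` storage, and a hypothetical balancing rule that pulls any cluster smaller than `threshold` halfway towards the centroid of the largest cluster. All function names are illustrative.

```cpp
#include <cassert>
#include <cstddef>
#include <limits>
#include <vector>

// Host-only sketch, NOT RAFT's implementation: squared L2 distance,
// row-major storage, and a hypothetical balancing rule.
using Matrix = std::vector<float>;  // row-major [n_rows x dim]

// Expectation: label each point with the index of its closest centroid.
std::vector<int> expectation(const Matrix& X, const Matrix& C, std::size_t dim) {
  assert(dim > 0 && X.size() % dim == 0 && C.size() % dim == 0);
  const std::size_t n = X.size() / dim, k = C.size() / dim;
  std::vector<int> labels(n, 0);
  for (std::size_t i = 0; i < n; ++i) {
    float best = std::numeric_limits<float>::max();
    for (std::size_t c = 0; c < k; ++c) {
      float d = 0.0f;
      for (std::size_t j = 0; j < dim; ++j) {
        const float diff = X[i * dim + j] - C[c * dim + j];
        d += diff * diff;
      }
      if (d < best) {
        best = d;
        labels[i] = static_cast<int>(c);
      }
    }
  }
  return labels;
}

// Maximization: move each centroid to the mean of its points and return the
// cluster sizes. Empty clusters keep their previous centroid.
std::vector<std::size_t> maximization(const Matrix& X, const std::vector<int>& labels,
                                      Matrix& C, std::size_t dim) {
  std::vector<std::size_t> sizes(C.size() / dim, 0);
  Matrix sums(C.size(), 0.0f);
  for (std::size_t i = 0; i < labels.size(); ++i) {
    const std::size_t c = static_cast<std::size_t>(labels[i]);
    ++sizes[c];
    for (std::size_t j = 0; j < dim; ++j) sums[c * dim + j] += X[i * dim + j];
  }
  for (std::size_t c = 0; c < sizes.size(); ++c) {
    if (sizes[c] == 0) continue;
    for (std::size_t j = 0; j < dim; ++j)
      C[c * dim + j] = sums[c * dim + j] / static_cast<float>(sizes[c]);
  }
  return sizes;
}

// Balancing (hypothetical rule): pull every cluster smaller than `threshold`
// halfway towards the centroid of the largest cluster. Returns true if any
// centroid moved.
bool balancing(Matrix& C, const std::vector<std::size_t>& sizes, std::size_t dim,
               std::size_t threshold) {
  if (sizes.empty()) return false;
  std::size_t big = 0;
  for (std::size_t c = 1; c < sizes.size(); ++c)
    if (sizes[c] > sizes[big]) big = c;
  bool moved = false;
  for (std::size_t c = 0; c < sizes.size(); ++c) {
    if (c == big || sizes[c] >= threshold) continue;
    for (std::size_t j = 0; j < dim; ++j)
      C[c * dim + j] = 0.5f * (C[c * dim + j] + C[big * dim + j]);
    moved = true;
  }
  return moved;
}

// One iteration in the documented order: balancing, expectation, maximization.
// `sizes` carries the cluster sizes of the previous iteration (empty at first).
bool em_balancing_iteration(const Matrix& X, Matrix& C, std::vector<std::size_t>& sizes,
                            std::size_t dim, std::size_t threshold) {
  const bool moved = balancing(C, sizes, dim, threshold);
  const std::vector<int> labels = expectation(X, C, dim);
  sizes = maximization(X, labels, C, dim);
  return moved;
}
```

Balancing runs first but relies on the sizes from the previous iteration, which is why the sketch threads `sizes` through successive calls.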

The number of mesoclusters is chosen by rounding the square root of the number of clusters. For example, with 512 clusters there are 23 mesoclusters (sqrt(512) ≈ 22.6). The number of fine clusters per mesocluster is chosen in proportion to the number of points in each mesocluster.
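The mesocluster count follows directly from the rule above; the proportional split needs a rounding policy that the text does not specify. The sketch below assumes a hypothetical one: walk the mesoclusters in order, give each non-empty one `round(remaining_clusters * size / remaining_points)` fine clusters (at least one), and give the last mesocluster the remainder so the total always equals `n_clusters`.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Number of mesoclusters, as documented: round(sqrt(n_clusters)).
std::size_t n_mesoclusters(std::size_t n_clusters) {
  assert(n_clusters > 0);
  return static_cast<std::size_t>(std::lround(std::sqrt(static_cast<double>(n_clusters))));
}

// Hypothetical proportional split of n_clusters among mesoclusters of the
// given sizes (number of points). The result always sums to n_clusters.
std::vector<std::size_t> fine_clusters_per_mesocluster(
    const std::vector<std::size_t>& meso_sizes, std::size_t n_clusters) {
  assert(!meso_sizes.empty());
  std::size_t rows_rem = 0;
  for (std::size_t s : meso_sizes) rows_rem += s;
  std::size_t clusters_rem = n_clusters;
  std::vector<std::size_t> fine(meso_sizes.size(), 0);
  for (std::size_t i = 0; i + 1 < meso_sizes.size(); ++i) {
    if (meso_sizes[i] == 0) continue;  // empty mesoclusters get no fine clusters
    const double share = static_cast<double>(clusters_rem) *
                         static_cast<double>(meso_sizes[i]) /
                         static_cast<double>(rows_rem);
    std::size_t n = static_cast<std::size_t>(std::llround(share));
    n = std::min(std::max<std::size_t>(n, 1), clusters_rem);  // at least 1, never overshoot
    fine[i] = n;
    clusters_rem -= n;
    rows_rem -= meso_sizes[i];
  }
  fine.back() = clusters_rem;  // the last mesocluster takes whatever is left
  return fine;
}
```

Allocating against the *remaining* clusters and points, rather than the totals, keeps rounding errors from accumulating: the last non-empty mesocluster always receives exactly what is left.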

This variant of k-means uses random initialization and a fixed number of iterations, though iterations can be repeated if the balancing step moved the centroids.
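Random initialization and the iteration schedule can be illustrated with the following sketch. It assumes that initialization picks `k` distinct sample rows with a seeded RNG, and it uses a hypothetical repeat policy: if balancing moved centroids during the last scheduled iteration, one more iteration is appended, up to `max_extra` additional ones. The exact rule RAFT applies is not stated here, so treat `run_iterations` as illustrative only.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <iterator>
#include <random>
#include <vector>

// Random initialization (assumed scheme): choose k distinct sample indices
// whose rows would become the initial centroids.
std::vector<std::size_t> random_init_indices(std::size_t n_samples, std::size_t k,
                                             unsigned seed) {
  assert(k <= n_samples);
  std::vector<std::size_t> all(n_samples);
  for (std::size_t i = 0; i < n_samples; ++i) all[i] = i;
  std::vector<std::size_t> picked;
  std::mt19937 rng(seed);
  std::sample(all.begin(), all.end(), std::back_inserter(picked), k, rng);
  return picked;  // selection sampling over a vector keeps ascending order
}

// Fixed iteration count with a hypothetical extension rule. `step(iter)`
// performs one balancing-EM iteration and returns true if balancing moved
// any centroid. Returns the total number of iterations executed.
std::size_t run_iterations(std::size_t n_iters, std::size_t max_extra,
                           const std::function<bool(std::size_t)>& step) {
  std::size_t total = n_iters, extra = 0;
  for (std::size_t iter = 0; iter < total; ++iter) {
    const bool moved = step(iter);
    if (moved && iter + 1 == total && extra < max_extra) {
      ++total;  // balancing disturbed the centroids on the last pass: repeat
      ++extra;
    }
  }
  return total;
}
```

Capping the extra iterations keeps the run time bounded even if balancing keeps moving centroids.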

Additionally, this algorithm supports quantized datasets of arbitrary types, but the core of the algorithm works with a floating-point type; a conversion function can therefore be provided to map the data type to the math type.
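The role of `MappingOpT` can be shown with a host-only sketch. The `int8_to_float` functor and its 1/128 scaling are hypothetical examples, and `identity_mapping` is a local stand-in for `raft::identity_op`, not the RAFT type. Because RAFT applies the mapping on the GPU, a real functor would presumably also need to be callable from device code (e.g. `__host__ __device__`), which this plain C++ sketch omits.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Local stand-in for an identity mapping (used when DataT == MathT).
struct identity_mapping {
  template <typename T>
  T operator()(T x) const { return x; }
};

// Hypothetical mapping for int8-quantized data: rescale to [-1, 1).
struct int8_to_float {
  float operator()(std::int8_t x) const { return static_cast<float>(x) / 128.0f; }
};

// Convert DataT values to MathT through the mapping op, which is where the
// conversion happens before distances and means are computed in MathT.
template <typename MathT, typename DataT, typename MappingOpT>
std::vector<MathT> map_dataset(const std::vector<DataT>& data, MappingOpT op) {
  std::vector<MathT> out;
  out.reserve(data.size());
  for (const DataT& x : data) out.push_back(static_cast<MathT>(op(x)));
  return out;
}
```

For example, `map_dataset<float>(std::vector<std::int8_t>{-128, 0, 64}, int8_to_float{})` yields `{-1.0f, 0.0f, 0.5f}`; centroids are then computed on these floats rather than on the raw int8 values.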

#include <raft/core/handle.hpp>
#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/cluster/kmeans_balanced_types.hpp>
...
raft::handle_t handle;
raft::cluster::kmeans_balanced_params params;
// X: raft::device_matrix_view<const DataT, int> of shape [n_samples x n_features]
auto centroids = raft::make_device_matrix<float, int>(handle, n_clusters, n_features);
raft::cluster::kmeans_balanced::fit(handle, params, X, centroids.view());

Template Parameters:
  • DataT – Type of the input data.

  • MathT – Type of the centroids and mapped data.

  • IndexT – Type used for indexing.

  • MappingOpT – Type of the mapping function.

Parameters:
  • handle[in] The raft resources

  • params[in] Structure containing the hyper-parameters

  • X[in] Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features]

  • centroids[out] The generated centroids [dim = n_clusters x n_features]

  • mapping_op[in] (optional) Functor to convert from the input datatype to the arithmetic datatype. If DataT == MathT, this must be the identity.

template<typename DataT, typename MathT, typename IndexT, typename LabelT, typename MappingOpT = raft::identity_op>
void raft::cluster::kmeans_balanced::predict(const raft::resources &handle, kmeans_balanced_params const &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<const MathT, IndexT> centroids, raft::device_vector_view<LabelT, IndexT> labels, MappingOpT mapping_op = raft::identity_op())

Predict the closest cluster each sample in X belongs to.

#include <raft/core/handle.hpp>
#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/cluster/kmeans_balanced_types.hpp>
...
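raft::handle_t handle;
raft::cluster::kmeans_balanced_params params;
// centroids: raft::device_matrix_view<const float, int>, e.g. the centroids produced by fit()
// labels hold integer cluster indices, one per row of X
auto labels = raft::make_device_vector<uint32_t, int>(handle, n_rows);
raft::cluster::kmeans_balanced::predict(handle, params, X, centroids, labels.view());

Template Parameters: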
  • DataT – Type of the input data.

  • MathT – Type of the centroids and mapped data.

  • IndexT – Type used for indexing.

  • LabelT – Type of the output labels.

  • MappingOpT – Type of the mapping function.

Parameters:
  • handle[in] The raft resources

  • params[in] Structure containing the hyper-parameters

  • X[in] Dataset for which to infer the closest clusters. [dim = n_samples x n_features]

  • centroids[in] The input centroids [dim = n_clusters x n_features]

  • labels[out] The output labels [dim = n_samples]

  • mapping_op[in] (optional) Functor to convert from the input datatype to the arithmetic datatype. If DataT == MathT, this must be the identity.

template<typename DataT, typename MathT, typename IndexT, typename LabelT, typename MappingOpT = raft::identity_op>
void raft::cluster::kmeans_balanced::fit_predict(const raft::resources &handle, kmeans_balanced_params const &params, raft::device_matrix_view<const DataT, IndexT> X, raft::device_matrix_view<MathT, IndexT> centroids, raft::device_vector_view<LabelT, IndexT> labels, MappingOpT mapping_op = raft::identity_op())

Compute hierarchical balanced k-means clustering and predict cluster index for each sample in the input.

#include <raft/core/handle.hpp>
#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/cluster/kmeans_balanced_types.hpp>
...
raft::handle_t handle;
raft::cluster::kmeans_balanced_params params;
auto centroids = raft::make_device_matrix<float, int>(handle, n_clusters, n_features);
// labels hold integer cluster indices, one per row of X
auto labels = raft::make_device_vector<uint32_t, int>(handle, n_rows);
raft::cluster::kmeans_balanced::fit_predict(
    handle, params, X, centroids.view(), labels.view());

Template Parameters:
  • DataT – Type of the input data.

  • MathT – Type of the centroids and mapped data.

  • IndexT – Type used for indexing.

  • LabelT – Type of the output labels.

  • MappingOpT – Type of the mapping function.

Parameters:
  • handle[in] The raft resources

  • params[in] Structure containing the hyper-parameters

  • X[in] Training instances to cluster. The data must be in row-major format. [dim = n_samples x n_features]

  • centroids[out] The output centroids [dim = n_clusters x n_features]

  • labels[out] The output labels [dim = n_samples]

  • mapping_op[in] (optional) Functor to convert from the input datatype to the arithmetic datatype. If DataT and MathT are the same, this must be the identity.

struct kmeans_balanced_params : public raft::cluster::kmeans_base_params
#include <raft/cluster/kmeans_balanced_types.hpp>

Simple object to specify the hyper-parameters of the balanced k-means algorithm.

The following metrics are currently supported by balanced k-means:

  • InnerProduct

  • L2Expanded

  • L2SqrtExpanded
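The three metrics can be contrasted with a host-only sketch of nearest-centroid selection. The `Metric` enum is a local stand-in, not `raft::distance::DistanceType`. The sketch assumes that "Expanded" means the squared L2 distance is computed as ||x||² + ||c||² − 2·x·c, that L2SqrtExpanded takes its square root, and that InnerProduct selects the centroid with the *largest* inner product.

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Local stand-in for the supported metrics (not the RAFT enum).
enum class Metric { L2Expanded, L2SqrtExpanded, InnerProduct };

float dot(const float* a, const float* b, std::size_t dim) {
  float s = 0.0f;
  for (std::size_t j = 0; j < dim; ++j) s += a[j] * b[j];
  return s;
}

// Distance (L2 variants) or similarity (InnerProduct) between x and c.
float score(const float* x, const float* c, std::size_t dim, Metric m) {
  const float xc = dot(x, c, dim);
  if (m == Metric::InnerProduct) return xc;
  // Expanded form: ||x||^2 + ||c||^2 - 2 x.c, clamped against rounding below 0.
  const float d2 = std::max(dot(x, x, dim) + dot(c, c, dim) - 2.0f * xc, 0.0f);
  return m == Metric::L2SqrtExpanded ? std::sqrt(d2) : d2;
}

// Index of the closest centroid: minimum distance for the L2 metrics,
// maximum inner product for InnerProduct (assumed convention).
std::size_t closest_centroid(const float* x, const std::vector<float>& C,
                             std::size_t dim, Metric m) {
  assert(dim > 0 && !C.empty() && C.size() % dim == 0);
  const bool maximize = (m == Metric::InnerProduct);
  float best = maximize ? -std::numeric_limits<float>::infinity()
                        : std::numeric_limits<float>::infinity();
  std::size_t best_idx = 0;
  for (std::size_t c = 0; c < C.size() / dim; ++c) {
    const float s = score(x, &C[c * dim], dim, m);
    if (maximize ? s > best : s < best) {
      best = s;
      best_idx = c;
    }
  }
  return best_idx;
}
```

Because the square root is monotonic, L2Expanded and L2SqrtExpanded pick the same centroid; InnerProduct can differ, since it favors centroids with a large norm in the direction of x.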

Public Members

uint32_t n_iters = 20

Number of training iterations