diff --git a/cpp/include/cuvs/cluster/kmeans.hpp b/cpp/include/cuvs/cluster/kmeans.hpp index d299d9f483..f03888298e 100644 --- a/cpp/include/cuvs/cluster/kmeans.hpp +++ b/cpp/include/cuvs/cluster/kmeans.hpp @@ -167,6 +167,13 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * on the host. Data is processed in GPU-sized batches, streaming from host to device. * The batch size is controlled by params.streaming_batch_size. * + * Multi-GPU dispatch is selected automatically based on the handle state: + * - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X + * is split across GPUs internally with an OpenMP parallel region and NCCL. + * - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as + * this worker's partition, and RAFT communicators are used for collectives. + * - Otherwise: single-GPU batched k-means. + * * @code{.cpp} * #include * #include @@ -196,7 +203,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 }; * raft::make_host_scalar_view(&n_iter)); * @endcode * - * @param[in] handle The raft handle. + * @param[in] handle The raft handle. Multi-GPU dispatch is selected automatically when a + * multi-GPU resource or an initialized communicator is attached. * @param[in] params Parameters for KMeans model. Batch size is read from * params.streaming_batch_size. * @param[in] X Training instances on HOST memory. 
The data must diff --git a/cpp/src/cluster/detail/kmeans_common.cuh b/cpp/src/cluster/detail/kmeans_common.cuh index 250563dd12..8af6114b4b 100644 --- a/cpp/src/cluster/detail/kmeans_common.cuh +++ b/cpp/src/cluster/detail/kmeans_common.cuh @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -43,6 +44,9 @@ #include #include #include +#include + +#include #include #include @@ -129,43 +133,36 @@ void countLabels(raft::resources const& handle, stream)); } -template -void checkWeight(raft::resources const& handle, - raft::device_vector_view weight, - rmm::device_uvector& workspace) +/** + * @brief Compute the sum of sample weights. + * + * Device-accessible mdspans are reduced on device via mapThenSumReduce; + * host mdspans are summed on the host. + * + * @return Sum of weights. + */ +template +DataT weightSum( + raft::resources const& handle, + raft::mdspan, raft::layout_right, Accessor> weight) { - cudaStream_t stream = raft::resource::get_cuda_stream(handle); - auto wt_aggr = raft::make_device_scalar(handle, 0); - auto n_samples = weight.extent(0); - - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(), - temp_storage_bytes, - weight.data_handle(), - wt_aggr.data_handle(), - n_samples, - stream)); - DataT wt_sum = 0; - raft::copy(handle, - raft::make_host_scalar_view(&wt_sum), - raft::make_device_scalar_view(wt_aggr.data_handle())); - raft::resource::sync_stream(handle, stream); - - if (wt_sum != n_samples) { - RAFT_LOG_DEBUG( - "[Warning!] 
KMeans: normalizing the user provided sample weight to " - "sum up to %d samples", - n_samples); - - auto scale = static_cast(n_samples) / wt_sum; - raft::linalg::map( - handle, weight, raft::mul_const_op{scale}, raft::make_const_mdspan(weight)); + auto n_samples = weight.extent(0); + + DataT wt_sum = DataT{0}; + if constexpr (raft::is_device_mdspan_v) { + auto stream = raft::resource::get_cuda_stream(handle); + auto d_wt_sum = raft::make_device_scalar(handle, DataT{0}); + raft::linalg::mapThenSumReduce( + d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle()); + raft::copy(&wt_sum, d_wt_sum.data_handle(), 1, stream); + raft::resource::sync_stream(handle); + } else { + for (IndexT i = 0; i < n_samples; ++i) { + wt_sum += weight(i); + } } + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); + return wt_sum; } template @@ -262,7 +259,7 @@ void sampleCentroids(raft::resources const& handle, raft::copy(handle, raft::make_host_scalar_view(&nPtsSampledInRank), raft::make_device_scalar_view(nSelected.data_handle())); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle(); thrust::for_each_n(raft::resource::get_thrust_policy(handle), @@ -367,7 +364,9 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms = + std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \ extern template void minClusterAndDistanceCompute( \ @@ -380,7 +379,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) 
EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int) @@ -399,7 +399,9 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, + std::optional> + precomputed_centroid_norms = std::nullopt); #define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \ extern template void minClusterDistanceCompute( \ @@ -412,7 +414,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t) EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t) @@ -484,14 +487,18 @@ void countSamplesInCluster(raft::resources const& handle, * @tparam IndexT Index type * @tparam LabelsIterator Iterator type for cluster labels * - * @param[in] handle RAFT resources handle - * @param[in] X Input samples [n_samples x n_features] - * @param[in] sample_weights Weights for each sample [n_samples] - * @param[in] cluster_labels Cluster assignment for each sample (iterator) - * @param[in] n_clusters Number of clusters - * @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features] - * @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters] - * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] handle RAFT resources handle + * @param[in] X Input samples [n_samples x n_features] + * @param[in] sample_weights Weights for each sample [n_samples] + * @param[in] cluster_labels Cluster assignment for each sample (iterator) + * @param[in] n_clusters Number of clusters + * @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features] + * @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. 
Follows the same + * `reset_sums`-controlled overwrite-vs-accumulate semantics as `centroid_sums` + * @param[inout] workspace Workspace buffer for intermediate operations + * @param[in] reset_sums If true (default), outputs are reset to zero before reducing; + * if false, this call's contribution is accumulated into the + * existing `centroid_sums` and `weight_per_cluster` + */ template void compute_centroid_adjustments( @@ -502,7 +509,8 @@ void compute_centroid_adjustments( IndexT n_clusters, raft::device_matrix_view centroid_sums, raft::device_vector_view weight_per_cluster, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + bool reset_sums = true) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -518,7 +526,8 @@ void compute_centroid_adjustments( X.extent(1), n_clusters, centroid_sums.data_handle(), - stream); + stream, + reset_sums); raft::linalg::reduce_cols_by_key(sample_weights.data_handle(), cluster_labels, @@ -526,7 +535,8 @@ void compute_centroid_adjustments( static_cast(1), static_cast(n_samples), n_clusters, - stream); + stream, + reset_sums); } /** * @brief Finalize centroids by dividing accumulated sums by counts. @@ -594,8 +604,109 @@ DataT compute_centroid_shift(raft::resources const& handle, new_centroids.data_handle()); DataT result = 0; raft::copy(&result, sqrdNorm.data_handle(), 1, stream); - raft::resource::sync_stream(handle, stream); + raft::resource::sync_stream(handle); return result; } +/** + * @brief Process a single batch of data in the Lloyd iteration. + * + * Given one batch of data + precomputed norms + weights + current centroids it + * 1. finds the nearest centroid for every sample, + * 2. accumulates weighted centroid sums and counts into the running accumulators, + * 3. accumulates the weighted clustering cost (inertia). + * + * Data norms must be precomputed by the caller and passed in via L2NormBatch. 
+ * + * @tparam DataT Data / weight type (float, double) + * @tparam IndexT Index type (int, int64_t) + * + * @param[in] handle RAFT resources handle + * @param[in] batch_data Device batch data [batch_size x n_features] + * @param[in] batch_weights Device batch weights [batch_size] + * @param[in] centroids Current centroids [n_clusters x n_features] + * @param[in] metric Distance metric + * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute + * @param[in] batch_centroids_param Batch-centroids param for minClusterAndDistanceCompute + * @param[inout] minClusterAndDistance Work buffer [batch_size]; weighted distances on return + * @param[in] L2NormBatch Precomputed data norms [batch_size] + * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch + * @param[inout] workspace Resizable scratch (nearest-centroid search and cost reduction) + * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into) + * @param[inout] weight_per_cluster Running sum of weights [n_clusters] (added into) + * @param[inout] clustering_cost Running cost scalar (device) (added into) + * @param[inout] batch_workspace Resizable scratch for compute_centroid_adjustments + * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters]. + * When provided, skips internal centroid norm computation. 
+ */ +template +void process_batch( + raft::resources const& handle, + raft::device_matrix_view batch_data, + raft::device_vector_view batch_weights, + raft::device_matrix_view centroids, + cuvs::distance::DistanceType metric, + int batch_samples_param, + int batch_centroids_param, + raft::device_vector_view, IndexT> minClusterAndDistance, + raft::device_vector_view L2NormBatch, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + rmm::device_uvector& workspace, + raft::device_matrix_view centroid_sums, + raft::device_vector_view weight_per_cluster, + raft::device_scalar_view clustering_cost, + rmm::device_uvector& batch_workspace, + std::optional> centroid_norms = std::nullopt) +{ + cudaStream_t stream = raft::resource::get_cuda_stream(handle); + + minClusterAndDistanceCompute(handle, + batch_data, + centroids, + minClusterAndDistance, + L2NormBatch, + L2NormBuf_OR_DistBuf, + metric, + batch_samples_param, + batch_centroids_param, + workspace, + centroid_norms); + + KeyValueIndexOp conversion_op; + thrust::transform_iterator, + const raft::KeyValuePair*> + labels_itr(minClusterAndDistance.data_handle(), conversion_op); + + compute_centroid_adjustments(handle, + batch_data, + batch_weights, + labels_itr, + static_cast(centroid_sums.extent(0)), + centroid_sums, + weight_per_cluster, + batch_workspace, + /*reset_sums=*/false); + + raft::linalg::map( + handle, + minClusterAndDistance, + [=] __device__(const raft::KeyValuePair kvp, DataT wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance), + batch_weights); + + auto batch_cost = raft::make_device_scalar(handle, DataT{0}); + computeClusterCost( + handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{}); + raft::linalg::add(clustering_cost.data_handle(), + clustering_cost.data_handle(), + batch_cost.data_handle(), + 1, + stream); +} + } // namespace cuvs::cluster::kmeans::detail diff --git 
a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index fdec2bdd73..736baf6b7d 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -6,6 +6,7 @@ #pragma once #include "../kmeans.cuh" +#include "kmeans_common.cuh" #include #include @@ -15,14 +16,15 @@ #include #include #include +#include #include #include #include #include #include #include -#include #include +#include #include #include #include @@ -33,7 +35,9 @@ #include #include +#include #include +#include #include #include @@ -467,15 +471,9 @@ void checkWeights(const raft::resources& handle, const auto& comm = raft::resource::get_comms(handle); - auto n_samples = weight.extent(0); - size_t temp_storage_bytes = 0; - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); - - workspace.resize(temp_storage_bytes, stream); - - RAFT_CUDA_TRY(cub::DeviceReduce::Sum( - workspace.data(), temp_storage_bytes, weight.data_handle(), wt_aggr.data(), n_samples, stream)); + auto n_samples = weight.extent(0); + raft::linalg::mapThenSumReduce( + wt_aggr.data(), n_samples, raft::identity_op{}, stream, weight.data_handle()); comm.allreduce(wt_aggr.data(), // sendbuff wt_aggr.data(), // recvbuff @@ -484,16 +482,21 @@ void checkWeights(const raft::resources& handle, stream); DataT wt_sum = wt_aggr.value(stream); raft::resource::sync_stream(handle, stream); + RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)"); - if (wt_sum != n_samples) { + const auto target = static_cast(n_samples); + const DataT tol = target * std::numeric_limits::epsilon(); + if (std::abs(wt_sum - target) > tol) { CUVS_LOG_KMEANS(handle, "[Warning!] 
KMeans: normalizing the user provided sample weights to " "sum up to %d samples", n_samples); - DataT scale = n_samples / wt_sum; raft::linalg::map( - handle, weight, raft::mul_const_op(scale), raft::make_const_mdspan(weight)); + handle, + weight, + raft::compose_op(raft::mul_const_op{target}, raft::div_const_op{wt_sum}), + raft::make_const_mdspan(weight)); } } @@ -557,9 +560,15 @@ void fit(const raft::resources& handle, // resource auto newCentroids = raft::make_device_matrix(handle, n_clusters, n_features); - // temporary buffer to store the weights per cluster, destructor releases - // the resource - auto wtInCluster = raft::make_device_vector(handle, n_clusters); + // Running weighted sum of samples per cluster and sum of weights per cluster, + // accumulated by process_batch and allreduced each iteration. + auto centroid_sums = raft::make_device_matrix(handle, n_clusters, n_features); + auto weight_per_cluster = raft::make_device_vector(handle, n_clusters); + auto clustering_cost = raft::make_device_scalar(handle, DataT{0}); + + // Separate workspace for compute_centroid_adjustments (distinct from the one + // used by minClusterAndDistanceCompute inside process_batch). 
+ rmm::device_uvector batch_workspace(0, stream); // L2 norm of X: ||x||^2 auto L2NormX = raft::make_device_vector(handle, n_samples); @@ -571,6 +580,11 @@ void fit(const raft::resources& handle, X.data_handle(), n_samples, n_features), L2NormX.view()); } + auto L2NormX_const = + raft::make_device_vector_view(L2NormX.data_handle(), L2NormX.extent(0)); + + auto weight_const = + raft::make_device_vector_view(weight.data_handle(), weight.extent(0)); DataT priorClusteringCost = 0; for (n_iter[0] = 1; n_iter[0] <= params.max_iter; ++n_iter[0]) { @@ -581,171 +595,82 @@ void fit(const raft::resources& handle, auto const_centroids = raft::make_device_matrix_view( centroids.data_handle(), centroids.extent(0), centroids.extent(1)); - // computes minClusterAndDistance[0:n_samples) where - // minClusterAndDistance[i] is a pair where - // 'key' is index to an sample in 'centroids' (index of the nearest - // centroid) and 'value' is the distance between the sample 'X[i]' and the - // 'centroid[key]' - cuvs::cluster::kmeans::min_cluster_and_distance(handle, - X, - const_centroids, - minClusterAndDistance.view(), - L2NormX.view(), - L2NormBuf_OR_DistBuf, - params.metric, - params.batch_samples, - params.batch_centroids, - workspace); - - workspace.resize(n_samples, stream); - - cuda::transform_iterator keys_itr( - minClusterAndDistance.data_handle(), - cuvs::cluster::kmeans::detail::KeyValueIndexOp{}); - raft::linalg::reduce_rows_by_key((DataT*)X.data_handle(), - X.extent(1), - keys_itr, - weight.data_handle(), - workspace.data(), - X.extent(0), - X.extent(1), - static_cast(n_clusters), - newCentroids.data_handle(), - stream); - - // Reduce weights by key to compute weight in each cluster - raft::linalg::reduce_cols_by_key(weight.data_handle(), - keys_itr, - wtInCluster.data_handle(), - (IndexT)1, - (IndexT)weight.extent(0), - (IndexT)n_clusters, - stream); - // merge the local histogram from all ranks - comm.allreduce(wtInCluster.data_handle(), // sendbuff - 
wtInCluster.data_handle(), // recvbuff - wtInCluster.size(), // count + // Reset running accumulators. process_batch accumulates into these. + raft::matrix::fill(handle, centroid_sums.view(), DataT{0}); + raft::matrix::fill(handle, weight_per_cluster.view(), DataT{0}); + raft::matrix::fill(handle, clustering_cost.view(), DataT{0}); + + // Local accumulation: process the entire local shard as a single batch. + cuvs::cluster::kmeans::detail::process_batch(handle, + X, + weight_const, + const_centroids, + params.metric, + params.batch_samples, + params.batch_centroids, + minClusterAndDistance.view(), + L2NormX_const, + L2NormBuf_OR_DistBuf, + workspace, + centroid_sums.view(), + weight_per_cluster.view(), + clustering_cost.view(), + batch_workspace); + + // Reduce partial sums, counts, and inertia across all ranks. + comm.allreduce(centroid_sums.data_handle(), + centroid_sums.data_handle(), + centroid_sums.size(), raft::comms::op_t::SUM, stream); - - // reduces newCentroids from all ranks - comm.allreduce(newCentroids.data_handle(), // sendbuff - newCentroids.data_handle(), // recvbuff - newCentroids.size(), // count + comm.allreduce(weight_per_cluster.data_handle(), + weight_per_cluster.data_handle(), + weight_per_cluster.size(), + raft::comms::op_t::SUM, + stream); + comm.allreduce(clustering_cost.data_handle(), + clustering_cost.data_handle(), + 1, raft::comms::op_t::SUM, stream); - // Computes newCentroids[i] = newCentroids[i]/wtInCluster[i] where - // newCentroids[n_clusters x n_features] - 2D array, newCentroids[i] has - // sum of all the samples assigned to cluster-i - // wtInCluster[n_clusters] - 1D array, wtInCluster[i] contains # of - // samples in cluster-i. - // Note - when wtInCluster[i] is 0, newCentroid[i] is reset to 0 - - raft::linalg::matrix_vector_op( + // Finalize: divide centroid sums by weight per cluster; empty clusters keep + // their old centroid. 
+ cuvs::cluster::kmeans::detail::finalize_centroids( handle, - raft::make_const_mdspan(newCentroids.view()), - raft::make_const_mdspan(wtInCluster.view()), - newCentroids.view(), - cuda::proclaim_return_type([=] __device__(DataT mat, DataT vec) { - if (vec == 0) - return DataT(0); - else - return mat / vec; - })); - - // copy the centroids[i] to newCentroids[i] when wtInCluster[i] is 0 - cub::ArgIndexInputIterator itr_wt(wtInCluster.data_handle()); - raft::matrix::gather_if( - centroids.data_handle(), - centroids.extent(1), - centroids.extent(0), - itr_wt, - itr_wt, - wtInCluster.extent(0), - newCentroids.data_handle(), - cuda::proclaim_return_type( - [=] __device__(raft::KeyValuePair map) { // predicate - // copy when the # of samples in the cluster is 0 - if (map.value == 0) - return true; - else - return false; - }), - cuda::proclaim_return_type( - [=] __device__(raft::KeyValuePair map) { // map - return map.key; - }), - stream); - - // compute the squared norm between the newCentroids and the original - // centroids, destructor releases the resource - auto sqrdNorm = raft::make_device_scalar(handle, 1); - raft::linalg::mapThenSumReduce( - sqrdNorm.data_handle(), - newCentroids.size(), - cuda::proclaim_return_type([=] __device__(const DataT a, const DataT b) { - DataT diff = a - b; - return diff * diff; - }), - stream, - centroids.data_handle(), - newCentroids.data_handle()); - - DataT sqrdNormError = 0; - raft::copy(handle, raft::make_host_scalar_view(&sqrdNormError), sqrdNorm.view()); + raft::make_const_mdspan(centroid_sums.view()), + raft::make_const_mdspan(weight_per_cluster.view()), + const_centroids, + newCentroids.view()); + + // Convergence: squared norm shift between old and new centroids. 
+ DataT sqrdNormError = cuvs::cluster::kmeans::detail::compute_centroid_shift( + handle, raft::make_const_mdspan(centroids), raft::make_const_mdspan(newCentroids.view())); raft::copy(handle, raft::make_device_vector_view(centroids.data_handle(), newCentroids.size()), raft::make_device_vector_view(newCentroids.data_handle(), newCentroids.size())); + // Read the globally-reduced clustering cost back to host for the + // inertia-based convergence check (always active). + DataT curClusteringCost = 0; + raft::copy(handle, + raft::make_host_scalar_view(&curClusteringCost), + raft::make_const_mdspan(clustering_cost.view())); + ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, + "An error occurred in the distributed operation. This can result " + "from a failed rank"); + bool done = false; - if (params.inertia_check) { - rmm::device_scalar> clusterCostD(stream); - - // calculate cluster cost phi_x(C) - cuvs::cluster::kmeans::cluster_cost( - handle, - minClusterAndDistance.view(), - workspace, - raft::make_device_scalar_view(clusterCostD.data()), - cuda::proclaim_return_type>( - [] __device__(const raft::KeyValuePair& a, - const raft::KeyValuePair& b) { - raft::KeyValuePair res; - res.key = 0; - res.value = a.value + b.value; - return res; - })); - - // Cluster cost phi_x(C) from all ranks - comm.allreduce(&(clusterCostD.data()->value), - &(clusterCostD.data()->value), - 1, - raft::comms::op_t::SUM, - stream); - - DataT curClusteringCost = 0; - raft::copy(handle, - raft::make_host_scalar_view(&curClusteringCost), - raft::make_device_scalar_view(&(clusterCostD.data()->value))); - - ASSERT(comm.sync_stream(stream) == raft::comms::status_t::SUCCESS, - "An error occurred in the distributed operation. 
This can result " - "from a failed rank"); - ASSERT(curClusteringCost != (DataT)0.0, - "Too few points and centroids being found is getting 0 cost from " - "centers\n"); - - if (n_iter[0] > 1) { - DataT delta = curClusteringCost / priorClusteringCost; - if (delta > 1 - params.tol) done = true; - } - priorClusteringCost = curClusteringCost; + if (curClusteringCost == DataT{0}) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (n_iter[0] > 1 && priorClusteringCost > DataT{0}) { + DataT delta = curClusteringCost / priorClusteringCost; + if (delta > 1 - params.tol) done = true; } + priorClusteringCost = curClusteringCost; - raft::resource::sync_stream(handle, stream); if (sqrdNormError < params.tol) done = true; if (done) { @@ -754,6 +679,8 @@ void fit(const raft::resources& handle, break; } } + + inertia[0] = priorClusteringCost; } }; // namespace cuvs::cluster::kmeans::mg::detail diff --git a/cpp/src/cluster/detail/kmeans_mg_batched.cuh b/cpp/src/cluster/detail/kmeans_mg_batched.cuh new file mode 100644 index 0000000000..ccf8df4272 --- /dev/null +++ b/cpp/src/cluster/detail/kmeans_mg_batched.cuh @@ -0,0 +1,569 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ +#pragma once + +#include "kmeans.cuh" +#include "kmeans_batched.cuh" +#include "kmeans_common.cuh" + +#include "../../core/omp_wrapper.hpp" +#include "../../neighbors/detail/ann_utils.cuh" + +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +#include +#include +#include +#include +#include +#include + +namespace cuvs::cluster::kmeans::mg::detail { + +// --------------------------------------------------------------------------- +// NCCL data-type helper +// --------------------------------------------------------------------------- +template +ncclDataType_t nccl_dtype(); + +template <> +inline ncclDataType_t nccl_dtype() +{ + return ncclFloat; +} +template <> +inline ncclDataType_t nccl_dtype() +{ + return ncclDouble; +} +template <> +inline ncclDataType_t nccl_dtype() +{ + return ncclInt64; +} + +// --------------------------------------------------------------------------- +// Comm macros — select raw NCCL vs RAFT comms at each call site. +// These are local to this translation unit and undef'd at the bottom. 
+// --------------------------------------------------------------------------- + +#define SNMG_ALLREDUCE(sendbuf, recvbuf, count) \ + do { \ + using _snmg_val_t = std::remove_pointer_t; \ + if (use_nccl) { \ + RAFT_NCCL_TRY(ncclAllReduce( \ + sendbuf, recvbuf, count, nccl_dtype<_snmg_val_t>(), ncclSum, nccl_comm, stream)); \ + } else { \ + const auto& _snmg_comm = raft::resource::get_comms(dev_res); \ + _snmg_comm.allreduce(sendbuf, recvbuf, count, raft::comms::op_t::SUM, stream); \ + } \ + } while (0) + +#define SNMG_BCAST(buf, count, root) \ + do { \ + using _snmg_bcast_t = std::remove_pointer_t; \ + if (use_nccl) { \ + RAFT_NCCL_TRY( \ + ncclBroadcast(buf, buf, count, nccl_dtype<_snmg_bcast_t>(), root, nccl_comm, stream)); \ + } else { \ + const auto& _snmg_comm = raft::resource::get_comms(dev_res); \ + _snmg_comm.bcast(buf, count, root, stream); \ + } \ + } while (0) + +#define SNMG_GROUP_START() \ + do { \ + if (use_nccl) { RAFT_NCCL_TRY(ncclGroupStart()); } \ + } while (0) + +#define SNMG_GROUP_END() \ + do { \ + if (use_nccl) { RAFT_NCCL_TRY(ncclGroupEnd()); } \ + } while (0) + +// --------------------------------------------------------------------------- +// mnmg_fit — shared multi-GPU core (Paths 1 & 2) +// --------------------------------------------------------------------------- +template +void mnmg_fit(const raft::resources& handle, + const cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X_local, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + // --- Setup: rank, num_ranks, dev_res, comm mechanism --- + bool use_nccl = raft::resource::is_multi_gpu(handle); + int rank, num_ranks; + ncclComm_t nccl_comm{}; + + if (use_nccl) { + rank = cuvs::core::omp::get_thread_num(); + num_ranks = raft::resource::get_num_ranks(handle); + nccl_comm = raft::resource::get_nccl_comm_for_rank(handle, rank); + } else { + const auto& comm = 
raft::resource::get_comms(handle); + rank = comm.get_rank(); + num_ranks = comm.get_size(); + } + + const raft::resources& dev_res = + use_nccl ? raft::resource::set_current_device_to_rank(handle, rank) : handle; + + auto stream = raft::resource::get_cuda_stream(dev_res); + auto n_local = X_local.extent(0); + auto n_features = X_local.extent(1); + auto n_clusters = static_cast(params.n_clusters); + auto metric = params.metric; + + RAFT_EXPECTS(n_clusters > 0, "n_clusters must be positive"); + RAFT_EXPECTS(static_cast(centroids.extent(0)) == n_clusters, + "centroids.extent(0) must equal n_clusters"); + RAFT_EXPECTS(centroids.extent(1) == n_features, "centroids.extent(1) must equal n_features"); + RAFT_EXPECTS(num_ranks > 0, "num_ranks must be positive"); + + RAFT_LOG_DEBUG("SNMG KMeans fit: rank=%d/%d, n_local=%zu, n_features=%zu, n_clusters=%d", + rank, + num_ranks, + static_cast(n_local), + static_cast(n_features), + static_cast(n_clusters)); + + // --- Resolve streaming batch size --- + IdxT streaming_batch_size = static_cast(params.streaming_batch_size); + if (streaming_batch_size <= 0 || streaming_batch_size > n_local) { + streaming_batch_size = std::max(n_local, IdxT{1}); + } + + bool has_data = (n_local > 0); + + // --- Allocate work buffers once (O2) --- + auto rank_centroids = raft::make_device_matrix(dev_res, n_clusters, n_features); + auto new_centroids = raft::make_device_matrix(dev_res, n_clusters, n_features); + auto centroid_sums = raft::make_device_matrix(dev_res, n_clusters, n_features); + auto weight_per_cluster = raft::make_device_vector(dev_res, n_clusters); + auto batch_sums = raft::make_device_matrix(dev_res, n_clusters, n_features); + auto batch_counts = raft::make_device_vector(dev_res, n_clusters); + auto clustering_cost = raft::make_device_vector(dev_res, 1); + auto batch_clustering_cost = raft::make_device_vector(dev_res, 1); + IdxT alloc_batch_size = has_data ? 
streaming_batch_size : IdxT{1}; + auto batch_weights = raft::make_device_vector(dev_res, alloc_batch_size); + auto minClusterAndDistance = + raft::make_device_vector, IdxT>(dev_res, alloc_batch_size); + auto L2NormBatch = raft::make_device_vector(dev_res, alloc_batch_size); + rmm::device_uvector L2NormBuf_OR_DistBuf(0, stream); + rmm::device_uvector workspace(0, stream); + + // --- Weight normalization via allreduce (only when sample weights are provided) --- + T weight_scale = T{1}; + if (sample_weight.has_value()) { + auto d_n_local = raft::make_device_scalar(dev_res, static_cast(n_local)); + SNMG_ALLREDUCE(d_n_local.data_handle(), d_n_local.data_handle(), 1); + raft::resource::sync_stream(dev_res); + IdxT global_n{}; + raft::copy(&global_n, d_n_local.data_handle(), 1, stream); + raft::resource::sync_stream(dev_res); + + T local_wt_sum = T{0}; + const T* sw = sample_weight->data_handle(); + for (IdxT i = 0; i < n_local; ++i) { + local_wt_sum += sw[i]; + } + auto d_wt = raft::make_device_scalar(dev_res, local_wt_sum); + SNMG_ALLREDUCE(d_wt.data_handle(), d_wt.data_handle(), 1); + raft::resource::sync_stream(dev_res); + T global_wt{}; + raft::copy(&global_wt, d_wt.data_handle(), 1, stream); + raft::resource::sync_stream(dev_res); + RAFT_EXPECTS(std::isfinite(global_wt) && global_wt > T{0}, + "invalid parameter (sum of sample weights must be finite and positive)"); + const auto global_n_wt = static_cast(global_n); + const T tol = global_n_wt * std::numeric_limits::epsilon(); + if (std::abs(global_wt - global_n_wt) > tol) { weight_scale = global_n_wt / global_wt; } + } + + // --- n_init handling --- + auto n_init = params.n_init; + if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array && n_init != 1) { + RAFT_LOG_DEBUG( + "Explicit initial center position passed: performing only one init in " + "k-means instead of n_init=%d", + n_init); + n_init = 1; + } + + auto best_centroids = n_init > 1 + ? 
raft::make_device_matrix(dev_res, n_clusters, n_features) + : raft::make_device_matrix(dev_res, 0, 0); + T best_inertia = std::numeric_limits::max(); + IdxT best_n_iter = 0; + + // Per-rank local state (avoids data races on shared host scalars in OMP) + T local_inertia = T{0}; + IdxT local_n_iter = 0; + + std::mt19937 gen(params.rng_state.seed); + + // Allreduce scratch for synchronized convergence + auto d_done = raft::make_device_scalar(dev_res, 0); + + // Construct the batch iterator once; reset it each Lloyd iter / n_init iter. + std::optional> data_batches_opt; + if (has_data) { + data_batches_opt.emplace(X_local.data_handle(), + n_local, + n_features, + streaming_batch_size, + stream, + rmm::mr::get_current_device_resource_ref(), + true); + } + + // --- Main n_init loop --- + for (int seed_iter = 0; seed_iter < n_init; ++seed_iter) { + cuvs::cluster::kmeans::params iter_params = params; + iter_params.rng_state.seed = gen(); + + // --- Centroid initialization (rank 0 only, then broadcast) --- + if (iter_params.init != cuvs::cluster::kmeans::params::InitMethod::Array) { + if (rank == 0) { + cuvs::cluster::kmeans::detail::init_centroids_from_host_sample( + dev_res, iter_params, streaming_batch_size, X_local, rank_centroids.view(), workspace); + } + } else { + if (rank == 0) { + raft::copy( + rank_centroids.data_handle(), centroids.data_handle(), n_clusters * n_features, stream); + } + } + raft::resource::sync_stream(dev_res); + SNMG_BCAST(rank_centroids.data_handle(), n_clusters * n_features, 0); + raft::resource::sync_stream(dev_res); + + if (has_data && !sample_weight.has_value()) { + raft::matrix::fill(dev_res, batch_weights.view(), T{1}); + } + + T prior_cluster_cost = T{0}; + + // --- Lloyd iterations --- + for (local_n_iter = 1; local_n_iter <= iter_params.max_iter; ++local_n_iter) { + RAFT_LOG_DEBUG("SNMG KMeans: iteration %d on rank %d", local_n_iter, rank); + + raft::matrix::fill(dev_res, centroid_sums.view(), T{0}); + raft::matrix::fill(dev_res, 
weight_per_cluster.view(), T{0}); + raft::matrix::fill(dev_res, clustering_cost.view(), T{0}); + + auto rank_centroids_const = raft::make_device_matrix_view( + rank_centroids.data_handle(), n_clusters, n_features); + + // Phase 1: local batch accumulation (skip if no local data) + if (has_data) { + auto& data_batches = *data_batches_opt; + data_batches.reset(); + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); + + raft::matrix::fill(dev_res, batch_clustering_cost.view(), T{0}); + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); + + cuvs::cluster::kmeans::detail::copy_and_scale_batch_weights(dev_res, + sample_weight, + data_batch.offset(), + current_batch_size, + weight_scale, + batch_weights); + + auto batch_weights_view = raft::make_device_vector_view( + batch_weights.data_handle(), current_batch_size); + + auto L2NormBatch_view = + raft::make_device_vector_view(L2NormBatch.data_handle(), current_batch_size); + + if (metric == cuvs::distance::DistanceType::L2Expanded || + metric == cuvs::distance::DistanceType::L2SqrtExpanded) { + raft::linalg::norm( + dev_res, + raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features), + L2NormBatch_view); + } + + auto L2NormBatch_const = raft::make_const_mdspan(L2NormBatch_view); + + auto minClusterAndDistance_view = + raft::make_device_vector_view, IdxT>( + minClusterAndDistance.data_handle(), current_batch_size); + + cuvs::cluster::kmeans::detail::minClusterAndDistanceCompute( + dev_res, + batch_data_view, + rank_centroids_const, + minClusterAndDistance_view, + L2NormBatch_const, + L2NormBuf_OR_DistBuf, + metric, + params.batch_samples, + params.batch_centroids, + workspace); + + auto minClusterAndDistance_const = raft::make_const_mdspan(minClusterAndDistance_view); + + cuvs::cluster::kmeans::detail::accumulate_batch_centroids( + dev_res, + batch_data_view, + 
minClusterAndDistance_const, + batch_weights_view, + centroid_sums.view(), + weight_per_cluster.view(), + batch_sums.view(), + batch_counts.view()); + + raft::linalg::map( + dev_res, + minClusterAndDistance_view, + [=] __device__(const raft::KeyValuePair kvp, T wt) { + raft::KeyValuePair res; + res.value = kvp.value * wt; + res.key = kvp.key; + return res; + }, + raft::make_const_mdspan(minClusterAndDistance_view), + batch_weights_view); + + cuvs::cluster::kmeans::detail::computeClusterCost( + dev_res, + minClusterAndDistance_view, + workspace, + raft::make_device_scalar_view(batch_clustering_cost.data_handle()), + raft::value_op{}, + raft::add_op{}); + + raft::linalg::add(dev_res, + raft::make_const_mdspan(clustering_cost.view()), + raft::make_const_mdspan(batch_clustering_cost.view()), + clustering_cost.view()); + } + } + + // Phase 2: grouped allreduce + SNMG_GROUP_START(); + SNMG_ALLREDUCE( + centroid_sums.data_handle(), centroid_sums.data_handle(), n_clusters * n_features); + SNMG_ALLREDUCE( + weight_per_cluster.data_handle(), weight_per_cluster.data_handle(), n_clusters); + SNMG_ALLREDUCE(clustering_cost.data_handle(), clustering_cost.data_handle(), 1); + SNMG_GROUP_END(); + raft::resource::sync_stream(dev_res); + + // Phase 3: finalize centroids + auto centroid_sums_const = raft::make_device_matrix_view( + centroid_sums.data_handle(), n_clusters, n_features); + auto weight_per_cluster_const = + raft::make_device_vector_view(weight_per_cluster.data_handle(), n_clusters); + + cuvs::cluster::kmeans::detail::finalize_centroids(dev_res, + centroid_sums_const, + weight_per_cluster_const, + rank_centroids_const, + new_centroids.view()); + + // Phase 4: convergence check — synchronized across all ranks + T sqrdNormError = cuvs::cluster::kmeans::detail::compute_centroid_shift( + dev_res, + raft::make_const_mdspan(rank_centroids.view()), + raft::make_const_mdspan(new_centroids.view())); + + raft::copy( + rank_centroids.data_handle(), new_centroids.data_handle(), 
n_clusters * n_features, stream); + + bool done = false; + + raft::copy(&local_inertia, clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(dev_res); + + if (local_inertia == T{0}) { + RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids."); + } else if (local_n_iter > 1 && prior_cluster_cost > T{0}) { + T delta = local_inertia / prior_cluster_cost; + if (delta > 1 - params.tol) { done = true; } + } + prior_cluster_cost = local_inertia; + + if (sqrdNormError < params.tol) { done = true; } + + // Allreduce the convergence flag so all ranks agree (prevents NCCL deadlock + // from floating-point non-determinism in compute_centroid_shift) + int64_t done_val = done ? 1 : 0; + raft::copy(d_done.data_handle(), &done_val, 1, stream); + raft::resource::sync_stream(dev_res); + SNMG_ALLREDUCE(d_done.data_handle(), d_done.data_handle(), 1); + raft::resource::sync_stream(dev_res); + raft::copy(&done_val, d_done.data_handle(), 1, stream); + raft::resource::sync_stream(dev_res); + done = (done_val > 0); + + if (done) { + RAFT_LOG_DEBUG( + "SNMG KMeans: threshold triggered after %d iterations on rank %d", local_n_iter, rank); + break; + } + } + + // Final inertia recomputation against converged centroids + raft::matrix::fill(dev_res, clustering_cost.view(), T{0}); + if (has_data) { + auto rank_centroids_const = raft::make_device_matrix_view( + rank_centroids.data_handle(), n_clusters, n_features); + + auto& data_batches = *data_batches_opt; + data_batches.reset(); + for (const auto& data_batch : data_batches) { + IdxT current_batch_size = static_cast(data_batch.size()); + + auto batch_data_view = raft::make_device_matrix_view( + data_batch.data(), current_batch_size, n_features); + + cuvs::cluster::kmeans::detail::copy_and_scale_batch_weights(dev_res, + sample_weight, + data_batch.offset(), + current_batch_size, + weight_scale, + batch_weights); + + std::optional> batch_sw = std::nullopt; + if (sample_weight.has_value()) { + 
batch_sw = raft::make_device_vector_view(batch_weights.data_handle(), + current_batch_size); + } + + raft::matrix::fill(dev_res, batch_clustering_cost.view(), T{0}); + cuvs::cluster::kmeans::cluster_cost( + dev_res, + batch_data_view, + rank_centroids_const, + raft::make_device_scalar_view(batch_clustering_cost.data_handle()), + batch_sw); + + raft::linalg::add(dev_res, + raft::make_const_mdspan(clustering_cost.view()), + raft::make_const_mdspan(batch_clustering_cost.view()), + clustering_cost.view()); + } + } + SNMG_ALLREDUCE(clustering_cost.data_handle(), clustering_cost.data_handle(), 1); + raft::resource::sync_stream(dev_res); + raft::copy(&local_inertia, clustering_cost.data_handle(), 1, stream); + raft::resource::sync_stream(dev_res); + + RAFT_LOG_DEBUG("SNMG KMeans: n_init %d/%d completed, inertia=%f, n_iter=%d on rank %d", + seed_iter + 1, + n_init, + static_cast(local_inertia), + local_n_iter, + rank); + + // Best-of-n_init tracking + if (n_init > 1 && local_inertia < best_inertia) { + best_inertia = local_inertia; + best_n_iter = local_n_iter; + raft::copy(best_centroids.data_handle(), + rank_centroids.data_handle(), + n_clusters * n_features, + stream); + } + } + + // --- Final output (rank 0 writes to caller-provided views) --- + if (n_init > 1) { + raft::copy( + rank_centroids.data_handle(), best_centroids.data_handle(), n_clusters * n_features, stream); + local_inertia = best_inertia; + local_n_iter = best_n_iter; + } + + if (rank == 0) { + raft::copy( + centroids.data_handle(), rank_centroids.data_handle(), n_clusters * n_features, stream); + raft::resource::sync_stream(dev_res); + inertia[0] = local_inertia; + n_iter[0] = local_n_iter; + } +} + +// --------------------------------------------------------------------------- +// batched_fit_omp — OpenMP wrapper for Path 1 (cuVS / SNMG) +// --------------------------------------------------------------------------- +template +void batched_fit_omp(const raft::resources& clique, + const 
cuvs::cluster::kmeans::params& params, + raft::host_matrix_view X, + std::optional> sample_weight, + raft::device_matrix_view centroids, + raft::host_scalar_view inertia, + raft::host_scalar_view n_iter) +{ + raft::resource::get_nccl_comms(clique); + int num_ranks = raft::resource::get_num_ranks(clique); + IdxT n_samples = X.extent(0); + IdxT n_features = X.extent(1); + + IdxT base = n_samples / num_ranks; + IdxT rem = n_samples % num_ranks; + + cuvs::core::omp::check_threads(num_ranks); +#pragma omp parallel num_threads(num_ranks) + { + int r = cuvs::core::omp::get_thread_num(); + IdxT offset = r * base + std::min(r, rem); + IdxT n_local = base + (r < rem ? 1 : 0); + + auto X_local = raft::make_host_matrix_view( + X.data_handle() + offset * n_features, n_local, n_features); + + std::optional> sw_local; + if (sample_weight.has_value()) { + sw_local = + raft::make_host_vector_view(sample_weight->data_handle() + offset, n_local); + } + + mnmg_fit(clique, params, X_local, sw_local, centroids, inertia, n_iter); + } +} + +// Undef local macros +#undef SNMG_ALLREDUCE +#undef SNMG_BCAST +#undef SNMG_GROUP_START +#undef SNMG_GROUP_END + +} // namespace cuvs::cluster::kmeans::mg::detail diff --git a/cpp/src/cluster/detail/minClusterDistanceCompute.cu b/cpp/src/cluster/detail/minClusterDistanceCompute.cu index 8370ff922f..bcfc381753 100644 --- a/cpp/src/cluster/detail/minClusterDistanceCompute.cu +++ b/cpp/src/cluster/detail/minClusterDistanceCompute.cu @@ -7,6 +7,8 @@ #include +#include + namespace cuvs::cluster::kmeans::detail { // Calculates a pair for every sample in input 'X' where key is an @@ -23,36 +25,34 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, int batch_samples, int batch_centroids, - rmm::device_uvector& workspace) + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); auto n_features = X.extent(1); auto 
n_clusters = centroids.extent(0); - // todo(lsugy): change batch size computation when using fusedL2NN! - bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || + bool is_fused = metric == cuvs::distance::DistanceType::L2Expanded || metric == cuvs::distance::DistanceType::L2SqrtExpanded; auto dataBatchSize = is_fused ? (IndexT)n_samples : getDataBatchSize(batch_samples, n_samples); auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - centroids, - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + centroids, + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { - // TODO: Unless pool allocator is used, passing in a workspace for this - // isn't really increasing performance because this needs to do a re-allocation - // anyways. ref https://github.com/rapidsai/raft/issues/930 L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? 
precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -87,7 +87,7 @@ void minClusterAndDistanceCompute( datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -154,7 +154,8 @@ void minClusterAndDistanceCompute( cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int64_t) @@ -164,16 +165,18 @@ INSTANTIATE_MIN_CLUSTER_AND_DISTANCE(double, int) #undef INSTANTIATE_MIN_CLUSTER_AND_DISTANCE template -void minClusterDistanceCompute(raft::resources const& handle, - raft::device_matrix_view X, - raft::device_matrix_view centroids, - raft::device_vector_view minClusterDistance, - raft::device_vector_view L2NormX, - rmm::device_uvector& L2NormBuf_OR_DistBuf, - cuvs::distance::DistanceType metric, - int batch_samples, - int batch_centroids, - rmm::device_uvector& workspace) +void minClusterDistanceCompute( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::device_vector_view minClusterDistance, + raft::device_vector_view L2NormX, + rmm::device_uvector& L2NormBuf_OR_DistBuf, + cuvs::distance::DistanceType metric, + int batch_samples, + int batch_centroids, + rmm::device_uvector& workspace, + std::optional> precomputed_centroid_norms) { cudaStream_t stream = raft::resource::get_cuda_stream(handle); auto n_samples = X.extent(0); @@ -186,21 +189,22 @@ void minClusterDistanceCompute(raft::resources const& handle, auto centroidsBatchSize = getCentroidsBatchSize(batch_centroids, n_clusters); if (is_fused) { - 
L2NormBuf_OR_DistBuf.resize(n_clusters, stream); - raft::linalg::norm( - handle, - raft::make_device_matrix_view( - centroids.data_handle(), centroids.extent(0), centroids.extent(1)), - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + if (!precomputed_centroid_norms.has_value()) { + L2NormBuf_OR_DistBuf.resize(n_clusters, stream); + raft::linalg::norm( + handle, + raft::make_device_matrix_view( + centroids.data_handle(), centroids.extent(0), centroids.extent(1)), + raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters)); + } } else { L2NormBuf_OR_DistBuf.resize(dataBatchSize * centroidsBatchSize, stream); } - // Note - pairwiseDistance and centroidsNorm share the same buffer - // centroidsNorm [n_clusters] - tensor wrapper around centroids L2 Norm - auto centroidsNorm = - raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); - // pairwiseDistance[ns x nc] - tensor wrapper around the distance buffer + auto centroidsNorm_view = + precomputed_centroid_norms.has_value() + ? 
precomputed_centroid_norms.value() + : raft::make_device_vector_view(L2NormBuf_OR_DistBuf.data(), n_clusters); auto pairwiseDistance = raft::make_device_matrix_view( L2NormBuf_OR_DistBuf.data(), dataBatchSize, centroidsBatchSize); @@ -232,7 +236,7 @@ void minClusterDistanceCompute(raft::resources const& handle, datasetView.data_handle(), centroids.data_handle(), L2NormXView.data_handle(), - centroidsNorm.data_handle(), + centroidsNorm_view.data_handle(), ns, n_clusters, n_features, @@ -290,7 +294,8 @@ void minClusterDistanceCompute(raft::resources const& handle, cuvs::distance::DistanceType metric, \ int batch_samples, \ int batch_centroids, \ - rmm::device_uvector& workspace); + rmm::device_uvector& workspace, \ + std::optional>); INSTANTIATE_MIN_CLUSTER_DISTANCE(float, int64_t) INSTANTIATE_MIN_CLUSTER_DISTANCE(double, int64_t) diff --git a/cpp/src/cluster/kmeans.cuh b/cpp/src/cluster/kmeans.cuh index e4f9821990..d1394210ed 100644 --- a/cpp/src/cluster/kmeans.cuh +++ b/cpp/src/cluster/kmeans.cuh @@ -405,7 +405,7 @@ void min_cluster_distance(raft::resources const& handle, } /** - * @brief Compute (optionally weighted) cluster cost (inertia). 
+ * @brief Compute (optionally weighted) cluster cost (inertia) * * @tparam DataT float or double * @tparam IndexT Index type @@ -413,7 +413,7 @@ void min_cluster_distance(raft::resources const& handle, * @param[in] handle The raft handle * @param[in] X Input data [n_samples x n_features] * @param[in] centroids Cluster centroids [n_clusters x n_features] - * @param[out] cost Sum of squared distances to nearest centroid + * @param[out] cost Sum of squared distances to nearest centroid (device) * @param[in] sample_weight Optional per-sample weights [n_samples] */ template @@ -421,7 +421,7 @@ void cluster_cost( raft::resources const& handle, raft::device_matrix_view X, raft::device_matrix_view centroids, - raft::host_scalar_view cost, + raft::device_scalar_view cost, std::optional> sample_weight = std::nullopt) { auto stream = raft::resource::get_cuda_stream(handle); @@ -432,7 +432,6 @@ void cluster_cost( rmm::device_uvector workspace(n_samples * sizeof(IndexT), stream); auto x_norms = raft::make_device_vector(handle, n_samples); - raft::linalg::norm(handle, X, x_norms.view()); auto min_cluster_distance = raft::make_device_vector(handle, n_samples); @@ -453,7 +452,6 @@ void cluster_cost( n_clusters, workspace); - // Apply sample weights if provided if (sample_weight.has_value()) { raft::linalg::map(handle, min_cluster_distance.view(), @@ -462,12 +460,35 @@ void cluster_cost( sample_weight.value()); } - auto device_cost = raft::make_device_scalar(handle, DataT(0)); - cuvs::cluster::kmeans::cluster_cost( - handle, min_cluster_distance.view(), workspace, device_cost.view(), raft::add_op{}); - raft::copy(handle, cost, raft::make_const_mdspan(device_cost.view())); + handle, min_cluster_distance.view(), workspace, cost, raft::add_op{}); +} +/** + * @brief Compute (optionally weighted) cluster cost (inertia) — host-scalar output. + * + * Convenience wrapper that copies the result to host and synchronizes. 
+ * + * @tparam DataT float or double + * @tparam IndexT Index type + * + * @param[in] handle The raft handle + * @param[in] X Input data [n_samples x n_features] + * @param[in] centroids Cluster centroids [n_clusters x n_features] + * @param[out] cost Sum of squared distances to nearest centroid (host) + * @param[in] sample_weight Optional per-sample weights [n_samples] + */ +template +void cluster_cost( + raft::resources const& handle, + raft::device_matrix_view X, + raft::device_matrix_view centroids, + raft::host_scalar_view cost, + std::optional> sample_weight = std::nullopt) +{ + auto device_cost = raft::make_device_scalar(handle, DataT(0)); + cuvs::cluster::kmeans::cluster_cost(handle, X, centroids, device_cost.view(), sample_weight); + raft::copy(handle, cost, raft::make_const_mdspan(device_cost.view())); raft::resource::sync_stream(handle); } diff --git a/cpp/src/cluster/kmeans_fit_double.cu b/cpp/src/cluster/kmeans_fit_double.cu index d7e4748e33..7851eeff6e 100644 --- a/cpp/src/cluster/kmeans_fit_double.cu +++ b/cpp/src/cluster/kmeans_fit_double.cu @@ -8,6 +8,10 @@ #include "kmeans_impl.cuh" #include +#ifdef CUVS_BUILD_MG_ALGOS +#include "detail/kmeans_mg_batched.cuh" +#endif + namespace cuvs::cluster::kmeans { #define INSTANTIATE_FIT_MAIN(DataT, IndexT) \ @@ -72,8 +76,18 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); +#ifdef CUVS_BUILD_MG_ALGOS + if (raft::resource::is_multi_gpu(handle)) { + mg::detail::batched_fit_omp( + handle, params, X, sample_weight, centroids, inertia, n_iter); + } else if (raft::resource::comms_initialized(handle)) { + mg::detail::mnmg_fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); + } else +#endif + { + detail::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + } } } // namespace cuvs::cluster::kmeans diff --git 
a/cpp/src/cluster/kmeans_fit_float.cu b/cpp/src/cluster/kmeans_fit_float.cu index f86fabcfbd..30f10dedb5 100644 --- a/cpp/src/cluster/kmeans_fit_float.cu +++ b/cpp/src/cluster/kmeans_fit_float.cu @@ -8,6 +8,10 @@ #include "kmeans_impl.cuh" #include +#ifdef CUVS_BUILD_MG_ALGOS +#include "detail/kmeans_mg_batched.cuh" +#endif + namespace cuvs::cluster::kmeans { #define INSTANTIATE_FIT_MAIN(DataT, IndexT) \ @@ -72,8 +76,18 @@ void fit(raft::resources const& handle, raft::host_scalar_view inertia, raft::host_scalar_view n_iter) { - cuvs::cluster::kmeans::detail::fit( - handle, params, X, sample_weight, centroids, inertia, n_iter); +#ifdef CUVS_BUILD_MG_ALGOS + if (raft::resource::is_multi_gpu(handle)) { + mg::detail::batched_fit_omp( + handle, params, X, sample_weight, centroids, inertia, n_iter); + } else if (raft::resource::comms_initialized(handle)) { + mg::detail::mnmg_fit( + handle, params, X, sample_weight, centroids, inertia, n_iter); + } else +#endif + { + detail::fit(handle, params, X, sample_weight, centroids, inertia, n_iter); + } } } // namespace cuvs::cluster::kmeans diff --git a/cpp/tests/CMakeLists.txt b/cpp/tests/CMakeLists.txt index b30f108789..b694c8b6c8 100644 --- a/cpp/tests/CMakeLists.txt +++ b/cpp/tests/CMakeLists.txt @@ -318,6 +318,14 @@ if(BUILD_MG_ALGOS) PERCENT 100 ADDITIONAL_DEP NCCL::NCCL ) + + ConfigureTest( + NAME CLUSTER_KMEANS_MG_BATCHED_TEST + PATH cluster/kmeans_mg_batched.cu + GPUS 1 + PERCENT 100 + ADDITIONAL_DEP NCCL::NCCL + ) endif() ConfigureTest( diff --git a/cpp/tests/cluster/kmeans_mg_batched.cu b/cpp/tests/cluster/kmeans_mg_batched.cu new file mode 100644 index 0000000000..5f3bcee725 --- /dev/null +++ b/cpp/tests/cluster/kmeans_mg_batched.cu @@ -0,0 +1,388 @@ +/* + * SPDX-FileCopyrightText: Copyright (c) 2026, NVIDIA CORPORATION. 
+ * SPDX-License-Identifier: Apache-2.0 + */ + +#include "../test_utils.cuh" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include + +#include +#include +#include +#include + +namespace cuvs { + +template +struct KmeansSNMGInputs { + int n_row; + int n_col; + int n_clusters; + T tol; + int weight_mode; // 0 = no weights, 1 = uniform, 2 = mild non-uniform, 3 = extreme non-uniform + int streaming_batch_size; + int n_init; + cuvs::cluster::kmeans::params::InitMethod init = cuvs::cluster::kmeans::params::Array; + bool inertia_check = true; + int max_iter = 20; +}; + +template +class KmeansSNMGTest : public ::testing::TestWithParam> { + protected: + KmeansSNMGTest() : clique_() { clique_.set_memory_pool(50); } + + void runTest() + { + testparams_ = ::testing::TestWithParam>::GetParam(); + + int n_samples = testparams_.n_row; + int n_features = testparams_.n_col; + int n_clusters = testparams_.n_clusters; + int num_ranks = raft::resource::get_num_ranks(clique_); + + auto stream = raft::resource::get_cuda_stream(clique_); + + auto X = raft::make_device_matrix(clique_, n_samples, n_features); + auto labels = raft::make_device_vector(clique_, n_samples); + + raft::random::make_blobs(X.data_handle(), + labels.data_handle(), + n_samples, + n_features, + n_clusters, + stream, + true, + nullptr, + nullptr, + T(1.0), + false, + (T)-10.0f, + (T)10.0f, + (uint64_t)1234); + + // Copy X to host + std::vector h_X(n_samples * n_features); + raft::update_host(h_X.data(), X.data_handle(), n_samples * n_features, stream); + raft::resource::sync_stream(clique_, stream); + + auto h_X_view = + raft::make_host_matrix_view(h_X.data(), n_samples, n_features); + + auto d_centroids_snmg = raft::make_device_matrix(clique_, n_clusters, n_features); + auto d_centroids_ref = raft::make_device_matrix(clique_, n_clusters, n_features); + + if (testparams_.init == cuvs::cluster::kmeans::params::Array) { + std::vector 
h_labels(n_samples); + raft::update_host(h_labels.data(), labels.data_handle(), n_samples, stream); + raft::resource::sync_stream(clique_, stream); + + std::vector h_centroids(n_clusters * n_features, T(0)); + std::vector counts(n_clusters, 0); + for (int i = 0; i < n_samples; ++i) { + int c = h_labels[i]; + counts[c]++; + for (int j = 0; j < n_features; ++j) + h_centroids[c * n_features + j] += h_X[i * n_features + j]; + } + for (int c = 0; c < n_clusters; ++c) { + if (counts[c] > 0) { + for (int j = 0; j < n_features; ++j) + h_centroids[c * n_features + j] /= T(counts[c]); + } + } + + raft::update_device( + d_centroids_snmg.data_handle(), h_centroids.data(), n_clusters * n_features, stream); + raft::copy(d_centroids_ref.data_handle(), + d_centroids_snmg.data_handle(), + n_clusters * n_features, + stream); + raft::resource::sync_stream(clique_, stream); + } + + // --- Prepare sample weights --- + std::optional> h_sw = std::nullopt; + std::vector h_sample_weight; + if (testparams_.weight_mode > 0) { + h_sample_weight.resize(n_samples); + for (int i = 0; i < n_samples; ++i) { + if (testparams_.weight_mode == 3) + h_sample_weight[i] = (i % 10 == 0) ? 
T(100) : T(1); + else if (testparams_.weight_mode == 2) + h_sample_weight[i] = T(1) + T(i % 5); + else + h_sample_weight[i] = T(1); + } + h_sw = raft::make_host_vector_view(h_sample_weight.data(), n_samples); + } + + // --- Run SNMG fit --- + cuvs::cluster::kmeans::params snmg_params; + snmg_params.n_clusters = n_clusters; + snmg_params.tol = testparams_.tol; + snmg_params.max_iter = testparams_.max_iter; + snmg_params.n_init = testparams_.n_init; + snmg_params.rng_state.seed = 42; + snmg_params.init = testparams_.init; + snmg_params.inertia_check = testparams_.inertia_check; + snmg_params.streaming_batch_size = testparams_.streaming_batch_size; + + T snmg_inertia = T{0}; + int64_t snmg_n_iter = 0; + + cuvs::cluster::kmeans::fit(clique_, + snmg_params, + h_X_view, + h_sw, + d_centroids_snmg.view(), + raft::make_host_scalar_view(&snmg_inertia), + raft::make_host_scalar_view(&snmg_n_iter)); + + raft::resource::sync_stream(clique_, stream); + + // --- Run single-GPU reference fit --- + raft::resources sg_handle; + auto sg_stream = raft::resource::get_cuda_stream(sg_handle); + + auto d_centroids_sg = raft::make_device_matrix(sg_handle, n_clusters, n_features); + if (testparams_.init == cuvs::cluster::kmeans::params::Array) { + raft::copy(d_centroids_sg.data_handle(), + d_centroids_ref.data_handle(), + n_clusters * n_features, + sg_stream); + raft::resource::sync_stream(sg_handle, sg_stream); + } + + cuvs::cluster::kmeans::params sg_params = snmg_params; + + T sg_inertia = T{0}; + int64_t sg_n_iter = 0; + + cuvs::cluster::kmeans::fit(sg_handle, + sg_params, + h_X_view, + h_sw, + d_centroids_sg.view(), + raft::make_host_scalar_view(&sg_inertia), + raft::make_host_scalar_view(&sg_n_iter)); + + raft::resource::sync_stream(sg_handle, sg_stream); + + // --- Predict labels using both centroid sets on single GPU --- + rmm::device_uvector d_labels_snmg(n_samples, sg_stream); + rmm::device_uvector d_labels_sg(n_samples, sg_stream); + rmm::device_uvector d_labels_ref(n_samples, 
sg_stream); + + raft::copy(d_labels_ref.data(), labels.data_handle(), n_samples, sg_stream); + + auto X_dev_view = + raft::make_device_matrix_view(X.data_handle(), n_samples, n_features); + + cuvs::cluster::kmeans::params pred_params; + pred_params.n_clusters = n_clusters; + + // Copy SNMG centroids to single-GPU handle for predict + auto d_centroids_snmg_copy = + raft::make_device_matrix(sg_handle, n_clusters, n_features); + raft::copy(d_centroids_snmg_copy.data_handle(), + d_centroids_snmg.data_handle(), + n_clusters * n_features, + sg_stream); + + auto d_centroids_sg_int = raft::make_device_matrix(sg_handle, n_clusters, n_features); + raft::copy(d_centroids_sg_int.data_handle(), + d_centroids_sg.data_handle(), + n_clusters * n_features, + sg_stream); + + T pred_inertia_snmg = T{0}; + cuvs::cluster::kmeans::predict( + sg_handle, + pred_params, + X_dev_view, + std::nullopt, + raft::make_device_matrix_view( + d_centroids_snmg_copy.data_handle(), n_clusters, n_features), + raft::make_device_vector_view(d_labels_snmg.data(), n_samples), + true, + raft::make_host_scalar_view(&pred_inertia_snmg)); + + T pred_inertia_sg = T{0}; + cuvs::cluster::kmeans::predict( + sg_handle, + pred_params, + X_dev_view, + std::nullopt, + raft::make_device_matrix_view( + d_centroids_sg_int.data_handle(), n_clusters, n_features), + raft::make_device_vector_view(d_labels_sg.data(), n_samples), + true, + raft::make_host_scalar_view(&pred_inertia_sg)); + + raft::resource::sync_stream(sg_handle, sg_stream); + + // --- Evaluate: compare SNMG labels with reference (make_blobs) labels --- + ari_vs_ref_ = raft::stats::adjusted_rand_index( + d_labels_ref.data(), d_labels_snmg.data(), n_samples, sg_stream); + + // ARI between SNMG and single-GPU results + ari_vs_sg_ = raft::stats::adjusted_rand_index( + d_labels_sg.data(), d_labels_snmg.data(), n_samples, sg_stream); + + raft::resource::sync_stream(sg_handle, sg_stream); + + snmg_inertia_ = snmg_inertia; + sg_inertia_ = sg_inertia; + snmg_n_iter_ = 
snmg_n_iter; + sg_n_iter_ = sg_n_iter; + + // --- Centroid-level comparison for deterministic (Array) init --- + if (testparams_.init == cuvs::cluster::kmeans::params::Array) { + std::vector h_c_snmg(n_clusters * n_features); + std::vector h_c_sg(n_clusters * n_features); + raft::update_host( + h_c_snmg.data(), d_centroids_snmg_copy.data_handle(), n_clusters * n_features, sg_stream); + raft::update_host( + h_c_sg.data(), d_centroids_sg_int.data_handle(), n_clusters * n_features, sg_stream); + raft::resource::sync_stream(sg_handle, sg_stream); + + double max_rel = 0; + for (int i = 0; i < n_clusters * n_features; ++i) { + double denom = std::max(double{1e-8}, std::abs(static_cast(h_c_sg[i]))); + double rel = + std::abs(static_cast(h_c_snmg[i]) - static_cast(h_c_sg[i])) / denom; + max_rel = std::max(max_rel, rel); + } + max_centroid_rel_diff_ = max_rel; + has_centroid_comparison_ = true; + } + + if (ari_vs_ref_ < 0.94 || ari_vs_sg_ < 0.94) { + std::cout << "SNMG KMeans: ARI vs ref = " << ari_vs_ref_ << ", ARI vs SG = " << ari_vs_sg_ + << ", num_ranks = " << num_ranks << ", snmg_inertia = " << snmg_inertia + << ", sg_inertia = " << sg_inertia << ", snmg_n_iter = " << snmg_n_iter + << ", sg_n_iter = " << sg_n_iter << std::endl; + } + } + + void SetUp() override { runTest(); } + + void checkResult() + { + // make_blobs generates well-separated clusters (spread=1.0, range [-10,10]). + // ARI >= 0.94 allows for minor label disagreement from floating-point + // non-determinism across GPUs while still catching real clustering failures. 
+ ASSERT_GE(ari_vs_ref_, 0.94); + ASSERT_GE(ari_vs_sg_, 0.94); + ASSERT_GT(snmg_n_iter_, int64_t{0}); + ASSERT_LE(snmg_n_iter_, static_cast(testparams_.max_iter)); + if (testparams_.init == cuvs::cluster::kmeans::params::Array) { + EXPECT_GE(ari_vs_sg_, 0.98); + if (sg_inertia_ > 0) { + EXPECT_LT(std::abs(snmg_inertia_ - sg_inertia_) / sg_inertia_, decltype(sg_inertia_){0.02}); + } + } + if (has_centroid_comparison_) { + EXPECT_LT(max_centroid_rel_diff_, 0.02) + << "SNMG vs SG centroid max relative diff = " << max_centroid_rel_diff_; + } + } + + raft::device_resources_snmg clique_; + KmeansSNMGInputs testparams_; + double ari_vs_ref_ = 0; + double ari_vs_sg_ = 0; + T snmg_inertia_ = T{0}; + T sg_inertia_ = T{0}; + int64_t snmg_n_iter_ = 0; + int64_t sg_n_iter_ = 0; + double max_centroid_rel_diff_ = 0; + bool has_centroid_comparison_ = false; +}; + +// ============================================================================ +// Float test inputs +// ============================================================================ +const std::vector> snmg_inputsf = { + // Field order: n_row, n_col, n_clusters, tol, weight_mode, streaming_batch_size, n_init[, init[, inertia_check[, max_iter]]] + {1000, 32, 5, 0.0001f, 0, 1000, 1}, + {1000, 32, 5, 0.0001f, 1, 1000, 1}, + {1000, 32, 5, 0.0001f, 0, 128, 1}, + {10000, 16, 10, 0.0001f, 0, 2000, 1}, + {10000, 16, 10, 0.0001f, 1, 2000, 1}, + {10000, 16, 10, 0.0001f, 0, 500, 1}, + {1001, 32, 5, 0.0001f, 0, 1001, 1}, + {1000, 32, 5, 0.0001f, 0, 1000, 1, cuvs::cluster::kmeans::params::KMeansPlusPlus}, + {1001, 32, 5, 0.0001f, 0, 128, 1}, + // Non-uniform weights: exercises weight_scale = global_n / global_wt normalization + {1000, 32, 5, 0.0001f, 2, 1000, 1}, + {10000, 16, 10, 0.0001f, 2, 2000, 1}, + // Extreme non-uniform weights (100:1 ratio): stresses weight normalization + {1000, 32, 5, 0.0001f, 3, 1000, 1}, + // Extreme batch size = 1: single-element work buffers, many batch iterations + {100, 8, 3, 0.001f, 0, 1, 1}, + // Very small dataset: some ranks may get only
2-3 rows with 4+ GPUs + {10, 4, 3, 0.001f, 0, 10, 1}, + // Fewer rows than GPUs: exercises empty-rank (has_data=false) partitions on 4+ GPU systems + {3, 4, 2, 0.001f, 0, 3, 1}, + // Trivial single cluster: convergence should be immediate + {1000, 16, 1, 0.0001f, 0, 1000, 1}, + // Batch size > n_samples: tests per-rank clamping logic + {1000, 32, 5, 0.0001f, 0, 5000, 1}, + // n_init > 1 with KMeansPlusPlus: best-of-n seed management across ranks + {1000, 32, 5, 0.0001f, 0, 1000, 3, cuvs::cluster::kmeans::params::KMeansPlusPlus}, + // inertia_check=false: convergence only via centroid shift + {1000, 32, 5, 0.0001f, 0, 1000, 1, cuvs::cluster::kmeans::params::Array, false}, + {1000, 32, 5, 0.0001f, 0, 128, 1, cuvs::cluster::kmeans::params::Array, false}, + // max_iter saturation: algorithm should stop at max_iter without convergence + {1000, 32, 5, 0.0001f, 0, 1000, 1, cuvs::cluster::kmeans::params::Array, true, 2}, +}; + +// ============================================================================ +// Double test inputs +// ============================================================================ +const std::vector> snmg_inputsd = { + // Same field order as snmg_inputsf; tol literals are double here. + {1000, 32, 5, 0.0001, 0, 1000, 1}, + {1000, 32, 5, 0.0001, 0, 128, 1}, + {1000, 32, 5, 0.0001, 1, 1000, 1}, + {1000, 32, 5, 0.0001, 2, 1000, 1}, + {10000, 16, 10, 0.0001, 0, 2000, 1}, + {100, 8, 3, 0.001, 0, 1, 1}, + {10, 4, 3, 0.001, 0, 10, 1}, + {3, 4, 2, 0.001, 0, 3, 1}, + {1000, 16, 1, 0.0001, 0, 1000, 1}, + {1000, 32, 5, 0.0001, 0, 5000, 1}, + {1000, 32, 5, 0.0001, 0, 1000, 1, cuvs::cluster::kmeans::params::KMeansPlusPlus}, + {1000, 32, 5, 0.0001, 0, 1000, 1, cuvs::cluster::kmeans::params::Array, false}, + {1000, 32, 5, 0.0001, 0, 1000, 1, cuvs::cluster::kmeans::params::Array, true, 2}, +}; + +// ============================================================================ +// Test fixtures +// ============================================================================ +typedef KmeansSNMGTest KmeansSNMGTestF; +typedef
KmeansSNMGTest KmeansSNMGTestD; + +TEST_P(KmeansSNMGTestF, Result) { checkResult(); } + +TEST_P(KmeansSNMGTestD, Result) { checkResult(); } + +INSTANTIATE_TEST_SUITE_P(KmeansSNMGTests, KmeansSNMGTestF, ::testing::ValuesIn(snmg_inputsf)); +INSTANTIATE_TEST_SUITE_P(KmeansSNMGTests, KmeansSNMGTestD, ::testing::ValuesIn(snmg_inputsd)); + +} // namespace cuvs