Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.streaming_batch_size.
*
* Multi-GPU dispatch is selected automatically based on the handle state:
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
* is split across GPUs internally with an OpenMP parallel region and NCCL.
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as
* this worker's partition, and RAFT communicators are used for collectives.
* - Otherwise: single-GPU batched k-means.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.hpp>
Expand Down Expand Up @@ -196,7 +203,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle.
* @param[in] handle The raft handle. When a multi-GPU resource is
* attached, multi-GPU dispatch is used automatically.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.streaming_batch_size.
* @param[in] X Training instances on HOST memory. The data must
Expand Down
215 changes: 163 additions & 52 deletions cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/memory_type.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
Expand All @@ -43,6 +44,9 @@
#include <cuda.h>
#include <cuda/iterator>
#include <thrust/for_each.h>
#include <thrust/iterator/transform_iterator.h>

#include <raft/linalg/add.cuh>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -129,43 +133,36 @@ void countLabels(raft::resources const& handle,
stream));
}

template <typename DataT, typename IndexT>
void checkWeight(raft::resources const& handle,
raft::device_vector_view<DataT, IndexT> weight,
rmm::device_uvector<char>& workspace)
/**
* @brief Compute the sum of sample weights.
*
* Device-accessible mdspans are reduced on device via mapThenSumReduce;
* host mdspans are summed on the host.
*
* @return Sum of weights.
*/
template <typename DataT, typename IndexT, typename Accessor>
DataT weightSum(
raft::resources const& handle,
raft::mdspan<const DataT, raft::vector_extent<IndexT>, raft::layout_right, Accessor> weight)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto wt_aggr = raft::make_device_scalar<DataT>(handle, 0);
auto n_samples = weight.extent(0);

size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceReduce::Sum(
nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream));

workspace.resize(temp_storage_bytes, stream);

RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(),
temp_storage_bytes,
weight.data_handle(),
wt_aggr.data_handle(),
n_samples,
stream));
DataT wt_sum = 0;
raft::copy(handle,
raft::make_host_scalar_view(&wt_sum),
raft::make_device_scalar_view(wt_aggr.data_handle()));
raft::resource::sync_stream(handle, stream);

if (wt_sum != n_samples) {
RAFT_LOG_DEBUG(
"[Warning!] KMeans: normalizing the user provided sample weight to "
"sum up to %d samples",
n_samples);

auto scale = static_cast<DataT>(n_samples) / wt_sum;
raft::linalg::map(
handle, weight, raft::mul_const_op<DataT>{scale}, raft::make_const_mdspan(weight));
auto n_samples = weight.extent(0);

DataT wt_sum = DataT{0};
if constexpr (raft::is_device_mdspan_v<decltype(weight)>) {
auto stream = raft::resource::get_cuda_stream(handle);
auto d_wt_sum = raft::make_device_scalar<DataT>(handle, DataT{0});
raft::linalg::mapThenSumReduce(
d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle());
raft::copy(&wt_sum, d_wt_sum.data_handle(), 1, stream);
raft::resource::sync_stream(handle);
} else {
for (IndexT i = 0; i < n_samples; ++i) {
wt_sum += weight(i);
}
}
RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)");
return wt_sum;
}

template <typename IndexT>
Expand Down Expand Up @@ -262,7 +259,7 @@ void sampleCentroids(raft::resources const& handle,
raft::copy(handle,
raft::make_host_scalar_view(&nPtsSampledInRank),
raft::make_device_scalar_view(nSelected.data_handle()));
raft::resource::sync_stream(handle, stream);
raft::resource::sync_stream(handle);

uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle();
thrust::for_each_n(raft::resource::get_thrust_policy(handle),
Expand Down Expand Up @@ -367,7 +364,9 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>> precomputed_centroid_norms =
std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \
extern template void minClusterAndDistanceCompute<DataT, IndexT>( \
Expand All @@ -380,7 +379,8 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>>);

EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int)
Expand All @@ -399,7 +399,9 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>>
precomputed_centroid_norms = std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \
extern template void minClusterDistanceCompute<DataT, IndexT>( \
Expand All @@ -412,7 +414,8 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>>);

EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t)
Expand Down Expand Up @@ -484,14 +487,18 @@ void countSamplesInCluster(raft::resources const& handle,
* @tparam IndexT Index type
* @tparam LabelsIterator Iterator type for cluster labels
*
* @param[in] handle RAFT resources handle
* @param[in] X Input samples [n_samples x n_features]
* @param[in] sample_weights Weights for each sample [n_samples]
* @param[in] cluster_labels Cluster assignment for each sample (iterator)
* @param[in] n_clusters Number of clusters
* @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features]
* @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters]
* @param[inout] workspace Workspace buffer for intermediate operations
* @param[in] handle RAFT resources handle
* @param[in] X Input samples [n_samples x n_features]
* @param[in] sample_weights Weights for each sample [n_samples]
* @param[in] cluster_labels Cluster assignment for each sample (iterator)
* @param[in] n_clusters Number of clusters
* @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features]
* @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. Follows the same
* overwrite-vs-accumulate semantics as `centroid_sums`
* @param[inout] workspace Workspace buffer for intermediate operations
* @param[in] reset_sums If true (default), outputs are reset to zero before reducing;
* if false, this call's contribution is accumulated into the
* existing `centroid_sums`
*/
template <typename DataT, typename IndexT, typename LabelsIterator>
void compute_centroid_adjustments(
Expand All @@ -502,7 +509,8 @@ void compute_centroid_adjustments(
IndexT n_clusters,
raft::device_matrix_view<DataT, IndexT, raft::row_major> centroid_sums,
raft::device_vector_view<DataT, IndexT> weight_per_cluster,
rmm::device_uvector<char>& workspace)
rmm::device_uvector<char>& workspace,
bool reset_sums = true)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
Expand All @@ -518,15 +526,17 @@ void compute_centroid_adjustments(
X.extent(1),
n_clusters,
centroid_sums.data_handle(),
stream);
stream,
reset_sums);

raft::linalg::reduce_cols_by_key(sample_weights.data_handle(),
cluster_labels,
weight_per_cluster.data_handle(),
static_cast<IndexT>(1),
static_cast<IndexT>(n_samples),
n_clusters,
stream);
stream,
reset_sums);
}
/**
* @brief Finalize centroids by dividing accumulated sums by counts.
Expand Down Expand Up @@ -594,8 +604,109 @@ DataT compute_centroid_shift(raft::resources const& handle,
new_centroids.data_handle());
DataT result = 0;
raft::copy(&result, sqrdNorm.data_handle(), 1, stream);
raft::resource::sync_stream(handle, stream);
raft::resource::sync_stream(handle);
return result;
}

/**
 * @brief Process one batch of data inside a Lloyd (k-means) iteration.
 *
 * For a single device-resident batch with precomputed sample norms, this
 * routine
 *  1. assigns every sample in the batch to its nearest centroid,
 *  2. folds the batch's weighted centroid sums and weight counts into the
 *     running accumulators (accumulate mode — nothing is reset here),
 *  3. folds the batch's weighted clustering cost (inertia) into the running
 *     cost scalar.
 *
 * The caller must precompute the batch data norms and pass them via
 * `L2NormBatch`; centroid norms may optionally be supplied to skip their
 * recomputation.
 *
 * @tparam DataT Data / weight type (float, double)
 * @tparam IndexT Index type (int, int64_t)
 *
 * @param[in] handle RAFT resources handle
 * @param[in] batch_data Device batch data [batch_size x n_features]
 * @param[in] batch_weights Device batch weights [batch_size]
 * @param[in] centroids Current centroids [n_clusters x n_features]
 * @param[in] metric Distance metric
 * @param[in] batch_samples_param Batch-samples param forwarded to
 *                                minClusterAndDistanceCompute
 * @param[in] batch_centroids_param Batch-centroids param forwarded to
 *                                  minClusterAndDistanceCompute
 * @param[inout] minClusterAndDistance Work buffer [batch_size]; holds
 *               (label, distance) pairs and is overwritten in step 3
 * @param[in] L2NormBatch Precomputed data norms [batch_size]
 * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch
 * @param[inout] workspace Resizable scratch for assignment / cost reduction
 * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features]
 *               (added into)
 * @param[inout] weight_per_cluster Running weight counts [n_clusters]
 *               (added into)
 * @param[inout] clustering_cost Running cost scalar (device) (added into)
 * @param[inout] batch_workspace Resizable scratch used by the centroid-sum
 *               accumulation
 * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters].
 *            When provided, internal centroid norm computation is skipped.
 */
template <typename DataT, typename IndexT>
void process_batch(
  raft::resources const& handle,
  raft::device_matrix_view<const DataT, IndexT> batch_data,
  raft::device_vector_view<const DataT, IndexT> batch_weights,
  raft::device_matrix_view<const DataT, IndexT> centroids,
  cuvs::distance::DistanceType metric,
  int batch_samples_param,
  int batch_centroids_param,
  raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
  raft::device_vector_view<const DataT, IndexT> L2NormBatch,
  rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
  rmm::device_uvector<char>& workspace,
  raft::device_matrix_view<DataT, IndexT> centroid_sums,
  raft::device_vector_view<DataT, IndexT> weight_per_cluster,
  raft::device_scalar_view<DataT> clustering_cost,
  rmm::device_uvector<char>& batch_workspace,
  std::optional<raft::device_vector_view<const DataT, IndexT>> centroid_norms = std::nullopt)
{
  cudaStream_t cuda_stream = raft::resource::get_cuda_stream(handle);

  // Step 1: nearest-centroid assignment — fills minClusterAndDistance with
  // a (label, distance) pair for every sample of the batch.
  minClusterAndDistanceCompute<DataT, IndexT>(handle,
                                              batch_data,
                                              centroids,
                                              minClusterAndDistance,
                                              L2NormBatch,
                                              L2NormBuf_OR_DistBuf,
                                              metric,
                                              batch_samples_param,
                                              batch_centroids_param,
                                              workspace,
                                              centroid_norms);

  // Expose only the label (key) half of each pair as an input iterator.
  KeyValueIndexOp<IndexT, DataT> to_label;
  thrust::transform_iterator<KeyValueIndexOp<IndexT, DataT>,
                             const raft::KeyValuePair<IndexT, DataT>*>
    label_iter(minClusterAndDistance.data_handle(), to_label);

  // Step 2: accumulate this batch's weighted sums and counts into the
  // running accumulators; reset_sums=false keeps prior batches' totals.
  compute_centroid_adjustments(handle,
                               batch_data,
                               batch_weights,
                               label_iter,
                               static_cast<IndexT>(centroid_sums.extent(0)),
                               centroid_sums,
                               weight_per_cluster,
                               batch_workspace,
                               /*reset_sums=*/false);

  // Scale each pair's distance by the sample weight, in place. NOTE: this
  // clobbers minClusterAndDistance, so it must run only after the
  // centroid-sum accumulation above has consumed the labels.
  raft::linalg::map(
    handle,
    minClusterAndDistance,
    [=] __device__(const raft::KeyValuePair<IndexT, DataT> kv, DataT w) {
      raft::KeyValuePair<IndexT, DataT> weighted;
      weighted.key   = kv.key;
      weighted.value = kv.value * w;
      return weighted;
    },
    raft::make_const_mdspan(minClusterAndDistance),
    batch_weights);

  // Step 3: reduce the weighted distances to this batch's cost, then fold it
  // into the running clustering cost on the stream.
  auto partial_cost = raft::make_device_scalar<DataT>(handle, DataT{0});
  computeClusterCost(
    handle, minClusterAndDistance, workspace, partial_cost.view(), raft::value_op{}, raft::add_op{});
  raft::linalg::add(clustering_cost.data_handle(),
                    clustering_cost.data_handle(),
                    partial_cost.data_handle(),
                    1,
                    cuda_stream);
}

} // namespace cuvs::cluster::kmeans::detail
Loading
Loading