Skip to content
Open
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
77 commits
Select commit Hold shift + click to select a range
66d7fd3
combine impls
tarang-jain Apr 10, 2026
07707af
Multi-GPU Batched KMeans
viclafargue Apr 13, 2026
efc270f
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 13, 2026
0a09e6f
rm inertia_check
tarang-jain Apr 13, 2026
99a5730
change to warning
tarang-jain Apr 13, 2026
a077406
style
tarang-jain Apr 13, 2026
d659875
add init_size param
tarang-jain Apr 13, 2026
ec2e8b7
Merge branch 'main' into combine-batch
tarang-jain Apr 13, 2026
03a6473
docs
tarang-jain Apr 13, 2026
42a8d9d
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 13, 2026
86af2fa
rm direct cuda api calls
tarang-jain Apr 13, 2026
d4e4e2c
std::swap instead of raft::copy
tarang-jain Apr 14, 2026
0819af5
cache batch norms
tarang-jain Apr 14, 2026
e0f079c
centroid norms can also be cached per iteration
tarang-jain Apr 14, 2026
c2f7390
mg n_iter
tarang-jain Apr 14, 2026
b9c3102
pre-commit
tarang-jain Apr 14, 2026
e3956c1
do not break c abi
tarang-jain Apr 14, 2026
986d78a
Merge branch 'main' into combine-batch
tarang-jain Apr 14, 2026
7197b71
cluster_cost on device
viclafargue Apr 14, 2026
84ab315
Updated testing
viclafargue Apr 14, 2026
47d4b94
templating
viclafargue Apr 15, 2026
a8e1d26
Merge branch 'main' into combine-batch
tarang-jain Apr 16, 2026
384d054
fix checkWeight
tarang-jain Apr 21, 2026
455b286
merge upstream:
tarang-jain Apr 21, 2026
5462809
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 21, 2026
6ba759c
fix compilation
tarang-jain Apr 21, 2026
e76eaac
rel_tol
tarang-jain Apr 22, 2026
afbefdf
pass workspace
tarang-jain Apr 22, 2026
e62a63c
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 22, 2026
e4f08bf
style
tarang-jain Apr 22, 2026
6e4a8f0
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 22, 2026
4a8a85c
do not use batch scratch space; rm update_centroids
tarang-jain Apr 22, 2026
bbf2a9f
move the debug log
tarang-jain Apr 22, 2026
410092c
add new suffixed param struct
tarang-jain Apr 22, 2026
c515c1e
address pr reviews
tarang-jain Apr 22, 2026
e8e63ab
fix docstring
tarang-jain Apr 22, 2026
30c457c
fix wt_sum warning
tarang-jain Apr 22, 2026
ab96623
rm DeprecationWarning and instead add FutureWarning
tarang-jain Apr 22, 2026
269f23c
unweighted to never materialize batch weights
tarang-jain Apr 22, 2026
80a22ca
add cpp tests
tarang-jain Apr 23, 2026
ac06b05
update cpp tests
tarang-jain Apr 23, 2026
855624a
Merge branch 'main' into mg-batched-kmeans
viclafargue Apr 23, 2026
0a6748d
refactor
viclafargue Apr 23, 2026
7055272
rename to mnmg_fit
viclafargue Apr 23, 2026
0569340
revert batch norms cache
tarang-jain Apr 23, 2026
8cac63a
increase zero cost threshold
tarang-jain Apr 24, 2026
f6df4ae
apply cuda event plus re-add h_norm_cache
tarang-jain Apr 24, 2026
9fc74b1
rm cosine expanded stuff
tarang-jain Apr 24, 2026
dec3dc4
resolve merge conflicts
tarang-jain Apr 28, 2026
0d030a2
change suffix of the params struct
tarang-jain Apr 28, 2026
b1c034e
replace 06 by 08, add todo and note
tarang-jain Apr 28, 2026
a482495
update to v2
tarang-jain Apr 28, 2026
8ecfdc1
avoid stream sync inside weight sum
tarang-jain Apr 29, 2026
1e1525e
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
ec22e07
empty
tarang-jain Apr 29, 2026
d2e410d
empty
tarang-jain Apr 29, 2026
b791c38
Merge branch 'main' into combine-batch
tarang-jain Apr 29, 2026
a05a006
new signatures with new struct
tarang-jain Apr 29, 2026
73293cf
Merge branch 'combine-batch' of https://github.com/tarang-jain/cuvs i…
tarang-jain Apr 29, 2026
880c7b9
Merge branch 'main' of https://github.com/rapidsai/cuvs into combine-…
tarang-jain Apr 30, 2026
e2035ec
revert change to calls in py and rust; add c tests
tarang-jain Apr 30, 2026
e28c200
Merge branch 'main' into combine-batch
tarang-jain May 1, 2026
55bbdad
use to_dlpack
tarang-jain May 5, 2026
9a9b8ee
cache device weights
tarang-jain May 5, 2026
a800b27
rm event
tarang-jain May 5, 2026
3db8582
update names
tarang-jain May 5, 2026
c048352
rename
tarang-jain May 5, 2026
2f968f8
rm docs
tarang-jain May 5, 2026
affe85a
empty
tarang-jain May 5, 2026
c6dea64
fix norm cache
tarang-jain May 5, 2026
7dfab3e
revert changes to minClusterDistanceCompute
tarang-jain May 6, 2026
7a383da
update tests to use mdspan instead of rmm
tarang-jain May 6, 2026
ce6c4b5
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
5a06a44
Merge branch 'main' into combine-batch
tarang-jain May 6, 2026
28cda6a
Merge branch 'combine-batch' into mg-batched-kmeans
viclafargue May 7, 2026
bfb5290
Addressing review
viclafargue May 7, 2026
add9db1
optimize convergence check
viclafargue May 7, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion cpp/include/cuvs/cluster/kmeans.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,13 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* on the host. Data is processed in GPU-sized batches, streaming from host to device.
* The batch size is controlled by params.streaming_batch_size.
*
* Multi-GPU dispatch is selected automatically based on the handle state:
* - If `raft::resource::is_multi_gpu(handle)` (cuVS SNMG): the full dataset X
* is split across GPUs internally with an OpenMP parallel region and NCCL.
* - If `raft::resource::comms_initialized(handle)` (Dask/Ray/MPI): X is treated as
* this worker's partition, and RAFT communicators are used for collectives.
* - Otherwise: single-GPU batched k-means.
Comment thread
coderabbitai[bot] marked this conversation as resolved.
*
* @code{.cpp}
* #include <raft/core/resources.hpp>
* #include <cuvs/cluster/kmeans.hpp>
Expand Down Expand Up @@ -196,7 +203,8 @@ enum class kmeans_type { KMeans = 0, KMeansBalanced = 1 };
* raft::make_host_scalar_view(&n_iter));
* @endcode
*
* @param[in] handle The raft handle.
* @param[in] handle The raft handle. When a multi-GPU resource is
* attached, multi-GPU dispatch is used automatically.
* @param[in] params Parameters for KMeans model. Batch size is read from
* params.streaming_batch_size.
* @param[in] X Training instances on HOST memory. The data must
Expand Down
215 changes: 163 additions & 52 deletions cpp/src/cluster/detail/kmeans_common.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#include <raft/core/kvp.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/mdarray.hpp>
#include <raft/core/memory_type.hpp>
#include <raft/core/operators.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
Expand All @@ -43,6 +44,9 @@
#include <cuda.h>
#include <cuda/iterator>
#include <thrust/for_each.h>
#include <thrust/iterator/transform_iterator.h>

#include <raft/linalg/add.cuh>

#include <algorithm>
#include <cmath>
Expand Down Expand Up @@ -129,43 +133,36 @@ void countLabels(raft::resources const& handle,
stream));
}

template <typename DataT, typename IndexT>
void checkWeight(raft::resources const& handle,
raft::device_vector_view<DataT, IndexT> weight,
rmm::device_uvector<char>& workspace)
/**
* @brief Compute the sum of sample weights.
*
* Device-accessible mdspans are reduced on device via mapThenSumReduce;
* host mdspans are summed on the host.
*
* @return Sum of weights.
*/
template <typename DataT, typename IndexT, typename Accessor>
DataT weightSum(
raft::resources const& handle,
raft::mdspan<const DataT, raft::vector_extent<IndexT>, raft::layout_right, Accessor> weight)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto wt_aggr = raft::make_device_scalar<DataT>(handle, 0);
auto n_samples = weight.extent(0);

size_t temp_storage_bytes = 0;
RAFT_CUDA_TRY(cub::DeviceReduce::Sum(
nullptr, temp_storage_bytes, weight.data_handle(), wt_aggr.data_handle(), n_samples, stream));

workspace.resize(temp_storage_bytes, stream);

RAFT_CUDA_TRY(cub::DeviceReduce::Sum(workspace.data(),
temp_storage_bytes,
weight.data_handle(),
wt_aggr.data_handle(),
n_samples,
stream));
DataT wt_sum = 0;
raft::copy(handle,
raft::make_host_scalar_view(&wt_sum),
raft::make_device_scalar_view(wt_aggr.data_handle()));
raft::resource::sync_stream(handle, stream);

if (wt_sum != n_samples) {
RAFT_LOG_DEBUG(
"[Warning!] KMeans: normalizing the user provided sample weight to "
"sum up to %d samples",
n_samples);

auto scale = static_cast<DataT>(n_samples) / wt_sum;
raft::linalg::map(
handle, weight, raft::mul_const_op<DataT>{scale}, raft::make_const_mdspan(weight));
auto n_samples = weight.extent(0);

DataT wt_sum = DataT{0};
if constexpr (raft::is_device_mdspan_v<decltype(weight)>) {
auto stream = raft::resource::get_cuda_stream(handle);
auto d_wt_sum = raft::make_device_scalar<DataT>(handle, DataT{0});
raft::linalg::mapThenSumReduce(
d_wt_sum.data_handle(), n_samples, raft::identity_op{}, stream, weight.data_handle());
raft::copy(&wt_sum, d_wt_sum.data_handle(), 1, stream);
raft::resource::sync_stream(handle);
Comment thread
viclafargue marked this conversation as resolved.
Outdated
} else {
for (IndexT i = 0; i < n_samples; ++i) {
wt_sum += weight(i);
}
}
RAFT_EXPECTS(wt_sum > DataT{0}, "invalid parameter (sum of sample weights must be positive)");
return wt_sum;
}

template <typename IndexT>
Expand Down Expand Up @@ -262,7 +259,7 @@ void sampleCentroids(raft::resources const& handle,
raft::copy(handle,
raft::make_host_scalar_view(&nPtsSampledInRank),
raft::make_device_scalar_view(nSelected.data_handle()));
raft::resource::sync_stream(handle, stream);
raft::resource::sync_stream(handle);

uint8_t* rawPtr_isSampleCentroid = isSampleCentroid.data_handle();
thrust::for_each_n(raft::resource::get_thrust_policy(handle),
Expand Down Expand Up @@ -367,7 +364,9 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>> precomputed_centroid_norms =
std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(DataT, IndexT) \
extern template void minClusterAndDistanceCompute<DataT, IndexT>( \
Expand All @@ -380,7 +379,8 @@ void minClusterAndDistanceCompute(
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>>);

EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_AND_DISTANCE(float, int)
Expand All @@ -399,7 +399,9 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric,
int batch_samples,
int batch_centroids,
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace,
std::optional<raft::device_vector_view<const DataT, IndexT>>
precomputed_centroid_norms = std::nullopt);

#define EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(DataT, IndexT) \
extern template void minClusterDistanceCompute<DataT, IndexT>( \
Expand All @@ -412,7 +414,8 @@ void minClusterDistanceCompute(raft::resources const& handle,
cuvs::distance::DistanceType metric, \
int batch_samples, \
int batch_centroids, \
rmm::device_uvector<char>& workspace);
rmm::device_uvector<char>& workspace, \
std::optional<raft::device_vector_view<const DataT, IndexT>>);

EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(float, int64_t)
EXTERN_TEMPLATE_MIN_CLUSTER_DISTANCE(double, int64_t)
Expand Down Expand Up @@ -484,14 +487,18 @@ void countSamplesInCluster(raft::resources const& handle,
* @tparam IndexT Index type
* @tparam LabelsIterator Iterator type for cluster labels
*
* @param[in] handle RAFT resources handle
* @param[in] X Input samples [n_samples x n_features]
* @param[in] sample_weights Weights for each sample [n_samples]
* @param[in] cluster_labels Cluster assignment for each sample (iterator)
* @param[in] n_clusters Number of clusters
* @param[out] centroid_sums Output weighted sum per cluster [n_clusters x n_features]
* @param[out] weight_per_cluster Output sum of weights per cluster [n_clusters]
* @param[inout] workspace Workspace buffer for intermediate operations
* @param[in] handle RAFT resources handle
* @param[in] X Input samples [n_samples x n_features]
* @param[in] sample_weights Weights for each sample [n_samples]
* @param[in] cluster_labels Cluster assignment for each sample (iterator)
* @param[in] n_clusters Number of clusters
* @param[inout] centroid_sums Weighted sum per cluster [n_clusters x n_features]
* @param[inout] weight_per_cluster Sum of weights per cluster [n_clusters]. Follows the same
* overwrite-vs-accumulate semantics as `centroid_sums`
* @param[inout] workspace Workspace buffer for intermediate operations
* @param[in] reset_sums If true (default), outputs are reset to zero before reducing;
* if false, this call's contribution is accumulated into the
* existing `centroid_sums`
*/
template <typename DataT, typename IndexT, typename LabelsIterator>
void compute_centroid_adjustments(
Expand All @@ -502,7 +509,8 @@ void compute_centroid_adjustments(
IndexT n_clusters,
raft::device_matrix_view<DataT, IndexT, raft::row_major> centroid_sums,
raft::device_vector_view<DataT, IndexT> weight_per_cluster,
rmm::device_uvector<char>& workspace)
rmm::device_uvector<char>& workspace,
bool reset_sums = true)
{
cudaStream_t stream = raft::resource::get_cuda_stream(handle);
auto n_samples = X.extent(0);
Expand All @@ -518,15 +526,17 @@ void compute_centroid_adjustments(
X.extent(1),
n_clusters,
centroid_sums.data_handle(),
stream);
stream,
reset_sums);

raft::linalg::reduce_cols_by_key(sample_weights.data_handle(),
cluster_labels,
weight_per_cluster.data_handle(),
static_cast<IndexT>(1),
static_cast<IndexT>(n_samples),
n_clusters,
stream);
stream,
reset_sums);
}
/**
* @brief Finalize centroids by dividing accumulated sums by counts.
Expand Down Expand Up @@ -594,8 +604,109 @@ DataT compute_centroid_shift(raft::resources const& handle,
new_centroids.data_handle());
DataT result = 0;
raft::copy(&result, sqrdNorm.data_handle(), 1, stream);
raft::resource::sync_stream(handle, stream);
raft::resource::sync_stream(handle);
return result;
}

/**
 * @brief Process a single batch of data in the Lloyd iteration.
 *
 * Given one batch of data + precomputed norms + weights + current centroids it
 * 1. finds the nearest centroid for every sample,
 * 2. accumulates weighted centroid sums and counts into the running accumulators,
 * 3. accumulates the weighted clustering cost (inertia).
 *
 * Data norms must be precomputed by the caller and passed in via L2NormBatch.
 *
 * @tparam DataT Data / weight type (float, double)
 * @tparam IndexT Index type (int, int64_t)
 *
 * @param[in] handle RAFT resources handle
 * @param[in] batch_data Device batch data [batch_size x n_features]
 * @param[in] batch_weights Device batch weights [batch_size]
 * @param[in] centroids Current centroids [n_clusters x n_features]
 * @param[in] metric Distance metric
 * @param[in] batch_samples_param Batch-samples param forwarded to minClusterAndDistanceCompute
 * @param[in] batch_centroids_param Batch-centroids param forwarded to
 *                                  minClusterAndDistanceCompute
 * @param[inout] minClusterAndDistance Work buffer [batch_size]
 * @param[in] L2NormBatch Precomputed data norms [batch_size]
 * @param[inout] L2NormBuf_OR_DistBuf Resizable scratch
 * @param[inout] workspace Resizable scratch
 * @param[inout] centroid_sums Running weighted sums [n_clusters x n_features] (added into)
 * @param[inout] weight_per_cluster Running weight counts [n_clusters] (added into)
 * @param[inout] clustering_cost Running cost scalar (device) (added into)
 * @param[inout] batch_workspace Resizable scratch used by the centroid-adjustment
 *                               reduction (compute_centroid_adjustments)
 * @param[in] centroid_norms Optional precomputed centroid norms [n_clusters].
 *                           When provided, skips internal centroid norm computation.
 */
template <typename DataT, typename IndexT>
void process_batch(
  raft::resources const& handle,
  raft::device_matrix_view<const DataT, IndexT> batch_data,
  raft::device_vector_view<const DataT, IndexT> batch_weights,
  raft::device_matrix_view<const DataT, IndexT> centroids,
  cuvs::distance::DistanceType metric,
  int batch_samples_param,
  int batch_centroids_param,
  raft::device_vector_view<raft::KeyValuePair<IndexT, DataT>, IndexT> minClusterAndDistance,
  raft::device_vector_view<const DataT, IndexT> L2NormBatch,
  rmm::device_uvector<DataT>& L2NormBuf_OR_DistBuf,
  rmm::device_uvector<char>& workspace,
  raft::device_matrix_view<DataT, IndexT> centroid_sums,
  raft::device_vector_view<DataT, IndexT> weight_per_cluster,
  raft::device_scalar_view<DataT> clustering_cost,
  rmm::device_uvector<char>& batch_workspace,
  std::optional<raft::device_vector_view<const DataT, IndexT>> centroid_norms = std::nullopt)
{
  cudaStream_t stream = raft::resource::get_cuda_stream(handle);

  // Step 1: nearest centroid per sample — fills minClusterAndDistance with
  // (cluster index, distance) pairs for this batch.
  minClusterAndDistanceCompute<DataT, IndexT>(handle,
                                              batch_data,
                                              centroids,
                                              minClusterAndDistance,
                                              L2NormBatch,
                                              L2NormBuf_OR_DistBuf,
                                              metric,
                                              batch_samples_param,
                                              batch_centroids_param,
                                              workspace,
                                              centroid_norms);

  // View each (key, value) pair's key as the sample's cluster label, without
  // materializing a separate labels array.
  KeyValueIndexOp<IndexT, DataT> conversion_op;
  thrust::transform_iterator<KeyValueIndexOp<IndexT, DataT>,
                             const raft::KeyValuePair<IndexT, DataT>*>
    labels_itr(minClusterAndDistance.data_handle(), conversion_op);

  // Step 2: accumulate (reset_sums=false) this batch's weighted sums and
  // counts into the running per-cluster accumulators.
  compute_centroid_adjustments(handle,
                               batch_data,
                               batch_weights,
                               labels_itr,
                               static_cast<IndexT>(centroid_sums.extent(0)),
                               centroid_sums,
                               weight_per_cluster,
                               batch_workspace,
                               /*reset_sums=*/false);

  // Scale each sample's min distance by its weight (in place) so the cost
  // reduction below yields the weighted inertia contribution.
  raft::linalg::map(
    handle,
    minClusterAndDistance,
    [=] __device__(const raft::KeyValuePair<IndexT, DataT> kvp, DataT wt) {
      raft::KeyValuePair<IndexT, DataT> res;
      res.value = kvp.value * wt;
      res.key = kvp.key;
      return res;
    },
    raft::make_const_mdspan(minClusterAndDistance),
    batch_weights);

  // Step 3: reduce the weighted distances to this batch's cost and add it
  // into the running clustering cost, entirely on device (no host sync).
  auto batch_cost = raft::make_device_scalar<DataT>(handle, DataT{0});
  computeClusterCost(
    handle, minClusterAndDistance, workspace, batch_cost.view(), raft::value_op{}, raft::add_op{});
  raft::linalg::add(clustering_cost.data_handle(),
                    clustering_cost.data_handle(),
                    batch_cost.data_handle(),
                    1,
                    stream);
}

} // namespace cuvs::cluster::kmeans::detail
Loading
Loading