Skip to content
Open
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions cpp/src/cluster/detail/kmeans.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -138,13 +138,14 @@ void kmeansPlusPlus(raft::resources const& handle,
raft::linalg::norm<raft::linalg::L2Norm, raft::Apply::ALONG_ROWS>(handle, X, L2NormX.view());
}

raft::random::RngState rng(params.rng_state.seed, params.rng_state.type);
std::mt19937 gen(params.rng_state.seed);
std::mt19937_64 gen_64(params.rng_state.seed);
uint64_t gpu_seed = gen_64();
raft::random::RngState rng(gpu_seed, params.rng_state.type);
std::uniform_int_distribution<> dis(0, n_samples - 1);

// <<< Step-1 >>>: C <-- sample a point uniformly at random from X
auto initialCentroid = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + dis(gen) * n_features, 1, n_features);
X.data_handle() + dis(gen_64) * n_features, 1, n_features);
int n_clusters_picked = 1;

// store the chosen centroid in the buffer
Expand Down
37 changes: 23 additions & 14 deletions cpp/src/cluster/detail/kmeans_mg.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ static cuvs::cluster::kmeans::params default_params;
template <typename DataT, typename IndexT>
void initRandom(const raft::resources& handle,
const cuvs::cluster::kmeans::params& params,
std::mt19937_64& gen_64,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<DataT, IndexT> centroids)
{
Expand Down Expand Up @@ -96,8 +97,9 @@ void initRandom(const raft::resources& handle,
auto centroidsSampledInRank =
raft::make_device_matrix<DataT, IndexT>(handle, nCentroidsSampledInRank, n_features);

uint64_t gpu_seed = gen_64();
cuvs::cluster::kmeans::shuffle_and_gather(
handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed);
handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, gpu_seed);

std::vector<size_t> displs(n_ranks);
std::exclusive_scan(nCentroidsElementsToReceiveFromRank.begin(),
Expand Down Expand Up @@ -130,6 +132,7 @@ void initRandom(const raft::resources& handle,
template <typename DataT, typename IndexT>
void initKMeansPlusPlus(const raft::resources& handle,
const cuvs::cluster::kmeans::params& params,
std::mt19937_64& gen_64,
raft::device_matrix_view<const DataT, IndexT> X,
raft::device_matrix_view<DataT, IndexT> centroidsRawData,
rmm::device_uvector<char>& workspace)
Expand All @@ -144,7 +147,6 @@ void initKMeansPlusPlus(const raft::resources& handle,
auto n_clusters = params.n_clusters;
auto metric = params.metric;

raft::random::RngState rng(params.rng_state.seed, raft::random::GeneratorType::GenPhilox);

// <<<< Step-1 >>> : C <- sample a point uniformly at random from X
// 1.1 - Select a rank r' at random from the available n_rank ranks with a
Expand All @@ -157,9 +159,8 @@ void initKMeansPlusPlus(const raft::resources& handle,
// Choose rp on rank 0 and broadcast to all ranks to guarantee agreement
int rp = 0;
if (my_rank == KMEANS_COMM_ROOT) {
std::mt19937 gen(params.rng_state.seed);
std::uniform_int_distribution<> dis(0, n_rank - 1);
rp = dis(gen);
rp = dis(gen_64);
}
{
rmm::device_scalar<int> rp_d(stream);
Expand All @@ -182,10 +183,9 @@ void initKMeansPlusPlus(const raft::resources& handle,
// 1.2 - Rank r' samples a point uniformly at random from the local dataset
// X which will be used as the initial centroid for kmeans++
if (my_rank == rp) {
std::mt19937 gen(params.rng_state.seed);
std::uniform_int_distribution<> dis(0, n_samples - 1);

int cIdx = dis(gen);
int cIdx = dis(gen_64);
auto centroidsView = raft::make_device_matrix_view<const DataT, IndexT>(
X.data_handle() + cIdx * n_features, 1, n_features);

Expand Down Expand Up @@ -316,6 +316,9 @@ void initKMeansPlusPlus(const raft::resources& handle,

// <<<< Step-4 >>> : Sample each point x in X independently and identify new
// potentialCentroids
uint64_t gpu_seed;
gpu_seed = gen_64();
raft::random::RngState rng(gpu_seed, params.rng_state.type);
raft::random::uniform(
handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1);
cuvs::cluster::kmeans::SamplingOp<DataT, IndexT> select_op(psi,
Expand Down Expand Up @@ -404,16 +407,17 @@ void initKMeansPlusPlus(const raft::resources& handle,
// seed they should generate the same potentialCentroids
auto const_centroids = raft::make_device_matrix_view<const DataT, IndexT>(
potentialCentroids.data_handle(), potentialCentroids.extent(0), potentialCentroids.extent(1));
auto params_copy = params;
params_copy.rng_state.seed = gen_64();
cuvs::cluster::kmeans::init_plus_plus(
handle, params, const_centroids, centroidsRawData, workspace);
handle, params_copy, const_centroids, centroidsRawData, workspace);
Comment on lines 406 to +413
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is an issue. When initKMeansPlusPlus oversamples (potentialCentroids.extent(0) > n_clusters), a mini single-GPU KMeans is run locally on each rank to reduce the candidates down to n_clusters. Since potentialCentroids and weight are already identical across ranks (via prior allgatherv/allreduce), and no communication happens after this reduction, all ranks must use the same RNG seed to produce identical results. Currently, init_plus_plus is given a per-rank-divergent seed from gen_64(), which breaks this invariant.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can generate a seed at rank 0 and broadcast it to other ranks? Would that be an acceptable solution? What implications would a sync here will have on load balancing?

Copy link
Copy Markdown
Contributor

@viclafargue viclafargue Apr 20, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What implications would a sync here will have on load balancing?

At this step the workers just completed the allreduce operation. Broadcasting a seed should be relatively cheap to do.

We can generate a seed at rank 0 and broadcast it to other ranks? Would that be an acceptable solution?

I think that this solution could work. That said since earlier steps already introduced some randomness, we could theoretically perform this recluster step with a constant seed.

Also it would be great if we could double check that everything works as expected before we merge the PR. Especially that the KMeans algorithm actually starts with similar centroids initialization on all workers whatever the init mode chosen and whether there is a recluster step or not.


auto inertia = raft::make_host_scalar<DataT>(0);
auto n_iter = raft::make_host_scalar<IndexT>(0);
auto weight_view =
raft::make_device_vector_view<const DataT, IndexT>(weight.data_handle(), weight.extent(0));
cuvs::cluster::kmeans::params params_copy = params;
params_copy.rng_state = default_params.rng_state;

// Update the seed one more time
params_copy.rng_state.seed = gen_64();
cuvs::cluster::kmeans::fit_main<DataT, IndexT>(handle,
params_copy,
const_centroids,
Expand All @@ -436,10 +440,10 @@ void initKMeansPlusPlus(const raft::resources& handle,

// generate `n_random_clusters` centroids
cuvs::cluster::kmeans::params rand_params = params;
rand_params.rng_state = default_params.rng_state;
rand_params.rng_state.seed = gen_64();
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be misleading to set this rand_params.rng_state.seed field here as initRandom does not use it anymore, but uses the gen_64 argument instead.

rand_params.init = cuvs::cluster::kmeans::params::InitMethod::Random;
rand_params.n_clusters = n_random_clusters;
initRandom(handle, rand_params, X, centroidsRawData);
initRandom(handle, rand_params, gen_64, X, centroidsRawData);

// copy centroids generated during kmeans|| iteration to the buffer
raft::copy(
Expand Down Expand Up @@ -514,6 +518,11 @@ void fit(const raft::resources& handle,
auto n_clusters = params.n_clusters;
auto metric = params.metric;

const int my_rank = comm.get_rank();
const int n_ranks = comm.get_size();

std::mt19937_64 gen_64(params.rng_state.seed + (uint64_t(my_rank) << 32));

auto weight = raft::make_device_vector<DataT, IndexT>(handle, n_samples);
if (sample_weight) {
raft::copy(handle, weight.view(), sample_weight.value());
Expand All @@ -529,11 +538,11 @@ void fit(const raft::resources& handle,
CUVS_LOG_KMEANS(handle,
"KMeans.fit: initialize cluster centers by randomly choosing from the "
"input data.\n");
initRandom<DataT, IndexT>(handle, params, X, centroids);
initRandom<DataT, IndexT>(handle, params, gen_64, X, centroids);
} else if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) {
// default method to initialize is kmeans++
CUVS_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n");
initKMeansPlusPlus<DataT, IndexT>(handle, params, X, centroids, workspace);
initKMeansPlusPlus<DataT, IndexT>(handle, params, gen_64, X, centroids, workspace);
} else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) {
CUVS_LOG_KMEANS(handle,
"KMeans.fit: initialize cluster centers from the ndarray array input "
Expand Down
Loading