diff --git a/cpp/src/cluster/detail/kmeans.cuh b/cpp/src/cluster/detail/kmeans.cuh index 5a35f203b3..79f20385f6 100644 --- a/cpp/src/cluster/detail/kmeans.cuh +++ b/cpp/src/cluster/detail/kmeans.cuh @@ -138,13 +138,14 @@ void kmeansPlusPlus(raft::resources const& handle, raft::linalg::norm(handle, X, L2NormX.view()); } - raft::random::RngState rng(params.rng_state.seed, params.rng_state.type); - std::mt19937 gen(params.rng_state.seed); + std::mt19937_64 gen_64(params.rng_state.seed); + uint64_t gpu_seed = gen_64(); + raft::random::RngState rng(gpu_seed, params.rng_state.type); std::uniform_int_distribution<> dis(0, n_samples - 1); // <<< Step-1 >>>: C <-- sample a point uniformly at random from X auto initialCentroid = raft::make_device_matrix_view( - X.data_handle() + dis(gen) * n_features, 1, n_features); + X.data_handle() + dis(gen_64) * n_features, 1, n_features); int n_clusters_picked = 1; // store the chosen centroid in the buffer diff --git a/cpp/src/cluster/detail/kmeans_mg.cuh b/cpp/src/cluster/detail/kmeans_mg.cuh index fdec2bdd73..fe984b2fe3 100644 --- a/cpp/src/cluster/detail/kmeans_mg.cuh +++ b/cpp/src/cluster/detail/kmeans_mg.cuh @@ -58,6 +58,7 @@ static cuvs::cluster::kmeans::params default_params; template void initRandom(const raft::resources& handle, const cuvs::cluster::kmeans::params& params, + std::mt19937_64& gen_64, raft::device_matrix_view X, raft::device_matrix_view centroids) { @@ -96,8 +97,9 @@ void initRandom(const raft::resources& handle, auto centroidsSampledInRank = raft::make_device_matrix(handle, nCentroidsSampledInRank, n_features); + uint64_t gpu_seed = gen_64(); cuvs::cluster::kmeans::shuffle_and_gather( - handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, params.rng_state.seed); + handle, X, centroidsSampledInRank.view(), nCentroidsSampledInRank, gpu_seed); std::vector displs(n_ranks); std::exclusive_scan(nCentroidsElementsToReceiveFromRank.begin(), @@ -130,6 +132,7 @@ void initRandom(const raft::resources& handle, template void initKMeansPlusPlus(const raft::resources& handle, const cuvs::cluster::kmeans::params& params, + std::mt19937_64& gen_64, raft::device_matrix_view X, raft::device_matrix_view centroidsRawData, rmm::device_uvector& workspace) @@ -144,7 +147,6 @@ void initKMeansPlusPlus(const raft::resources& handle, auto n_clusters = params.n_clusters; auto metric = params.metric; - raft::random::RngState rng(params.rng_state.seed, raft::random::GeneratorType::GenPhilox); // <<<< Step-1 >>> : C <- sample a point uniformly at random from X // 1.1 - Select a rank r' at random from the available n_rank ranks with a @@ -157,9 +159,8 @@ void initKMeansPlusPlus(const raft::resources& handle, // Choose rp on rank 0 and broadcast to all ranks to guarantee agreement int rp = 0; if (my_rank == KMEANS_COMM_ROOT) { - std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_rank - 1); - rp = dis(gen); + rp = dis(gen_64); } { rmm::device_scalar rp_d(stream); @@ -182,10 +183,9 @@ void initKMeansPlusPlus(const raft::resources& handle, // 1.2 - Rank r' samples a point uniformly at random from the local dataset // X which will be used as the initial centroid for kmeans++ if (my_rank == rp) { - std::mt19937 gen(params.rng_state.seed); std::uniform_int_distribution<> dis(0, n_samples - 1); - int cIdx = dis(gen); + int cIdx = dis(gen_64); auto centroidsView = raft::make_device_matrix_view( X.data_handle() + cIdx * n_features, 1, n_features); @@ -316,6 +316,9 @@ void initKMeansPlusPlus(const raft::resources& handle, // <<<< Step-4 >>> : Sample each point x in X independently and identify new // potentialCentroids + uint64_t gpu_seed; + gpu_seed = gen_64(); + raft::random::RngState rng(gpu_seed, params.rng_state.type); raft::random::uniform( handle, rng, uniformRands.data_handle(), uniformRands.extent(0), (DataT)0, (DataT)1); cuvs::cluster::kmeans::SamplingOp select_op(psi, @@ -404,16 +407,17 @@ void initKMeansPlusPlus(const raft::resources& handle, // seed they should generate the same potentialCentroids auto const_centroids = raft::make_device_matrix_view( potentialCentroids.data_handle(), potentialCentroids.extent(0), potentialCentroids.extent(1)); + auto params_copy = params; + params_copy.rng_state.seed = gen_64(); cuvs::cluster::kmeans::init_plus_plus( - handle, params, const_centroids, centroidsRawData, workspace); + handle, params_copy, const_centroids, centroidsRawData, workspace); auto inertia = raft::make_host_scalar(0); auto n_iter = raft::make_host_scalar(0); auto weight_view = raft::make_device_vector_view(weight.data_handle(), weight.extent(0)); - cuvs::cluster::kmeans::params params_copy = params; - params_copy.rng_state = default_params.rng_state; - + // Update the seed one more time + params_copy.rng_state.seed = gen_64(); cuvs::cluster::kmeans::fit_main(handle, params_copy, const_centroids, @@ -436,10 +440,10 @@ void initKMeansPlusPlus(const raft::resources& handle, // generate `n_random_clusters` centroids cuvs::cluster::kmeans::params rand_params = params; - rand_params.rng_state = default_params.rng_state; + rand_params.rng_state.seed = gen_64(); rand_params.init = cuvs::cluster::kmeans::params::InitMethod::Random; rand_params.n_clusters = n_random_clusters; - initRandom(handle, rand_params, X, centroidsRawData); + initRandom(handle, rand_params, gen_64, X, centroidsRawData); // copy centroids generated during kmeans|| iteration to the buffer raft::copy( @@ -514,6 +518,11 @@ void fit(const raft::resources& handle, auto n_clusters = params.n_clusters; auto metric = params.metric; + const int my_rank = comm.get_rank(); + const int n_ranks = comm.get_size(); + + std::mt19937_64 gen_64(params.rng_state.seed + (uint64_t(my_rank) << 32)); + auto weight = raft::make_device_vector(handle, n_samples); if (sample_weight) { raft::copy(handle, weight.view(), sample_weight.value()); @@ -529,11 +538,11 @@ void fit(const raft::resources& handle, CUVS_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers by randomly choosing from the " "input data.\n"); - initRandom(handle, params, X, centroids); + initRandom(handle, params, gen_64, X, centroids); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) { // default method to initialize is kmeans++ CUVS_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers using k-means++ algorithm.\n"); - initKMeansPlusPlus(handle, params, X, centroids, workspace); + initKMeansPlusPlus(handle, params, gen_64, X, centroids, workspace); } else if (params.init == cuvs::cluster::kmeans::params::InitMethod::Array) { CUVS_LOG_KMEANS(handle, "KMeans.fit: initialize cluster centers from the ndarray array input "