
[Cleanup] Combine Batched and Regular KMeans Impl#2015

Open
tarang-jain wants to merge 54 commits into rapidsai:main from tarang-jain:combine-batch

Conversation

@tarang-jain
Contributor

@tarang-jain tarang-jain commented Apr 10, 2026

Combine batched and regular k-means implementations

  • Unified the batched (host-data) and regular (device-data) k-means fit into a single kmeans_fit template that works with both host and device mdspans via batch_load_iterator
  • Unified the device and host initialization paths in init_centroids
  • Removed the inertia_check parameter; inertia-based convergence checking now always runs. Zero clustering cost (a perfect fit) logs a warning instead of asserting. This is needed because spectral clustering can cause all points to converge on the cluster centroids themselves.
  • Added init_size parameter to control how many samples are drawn for KMeansPlusPlus initialization. Defaults to n_samples for device data and min(3 * n_clusters, n_samples) for host data
  • Replaced per-iteration centroid raft::copy with std::swap of buffer pointers
  • For streaming fit, precompute data norms once and cache them: host norms cached to a host buffer on the first iteration and copied back for subsequent iterations. process_batch no longer computes norms internally
  • Replaced raw cudaPointerGetAttributes call with raft::memory_type_from_pointer
  • Replaced cub::DeviceReduce::Sum calls with raft::linalg::mapThenSumReduce
  • Guarded weight normalization against overflow: apply (w / wt_sum) * n_samples via a composed op instead of precomputing a scale, so very small wt_sum values don't produce inf
  • Renamed checkWeight to weightSum and made it mdspan-based with an Accessor template: device reduce for device weights, host loop for host weights. Callers apply the scaling themselves
  • Eliminated batch_sums / batch_counts scratch buffers by accumulating directly into centroid_sums / weight_per_cluster via reset_sums=false in reduce_rows_by_key / reduce_cols_by_key, removing two per-batch raft::linalg::add kernels
  • Removed dead update_centroids helpers (both the detail and public template) — no remaining callers after the fit_main consolidation
  • Perf: removed multiple raft::sync_stream calls and added a CUDA event to record whether the convergence criterion is met. The convergence check is now done on device. Average per-iteration time with the mandatory inertia check now matches previous benchmarks, even those where the inertia check was disabled.

C Tests

This PR adds C tests for KMeans, which were previously missing. Here we test both the old version and the new one (i.e., the breaking change).

Benchmarks:

With mandatory early stopping. The batch size is chosen so that we fill up 90% of the available GPU memory (95830 MiB)
HW:
GPU:
NVIDIA H100 NVL (CUDA 13.0)
CPU:

Architecture:             x86_64
  CPU op-mode(s):         32-bit, 64-bit
  Address sizes:          52 bits physical, 57 bits virtual
  Byte Order:             Little Endian
CPU(s):                   256
  On-line CPU(s) list:    0-255
Vendor ID:                AuthenticAMD
  Model name:             AMD EPYC 9554 64-Core Processor 
================================================================================
 SUMMARY
================================================================================
  n_clusters     batch_size  fit_time(s)        inertia   n_iter
----------------------------------------------------------------
      10,000     29,120,352      1584.72     2.8677e+08       30
      20,000     29,120,352      2907.34     2.7368e+08       31
      30,000     29,101,305      4254.43     2.6617e+08       31
      40,000     29,092,704      5836.12     2.6086e+08       32
      50,000     29,083,488      7107.04     2.5680e+08       31

Breaking Change

This PR is a breaking change to the C++ API because the inertia_check param is removed. The breaking changes to the C ABI will be applied in 26.08.

@copy-pr-bot

copy-pr-bot Bot commented Apr 10, 2026

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@tarang-jain tarang-jain self-assigned this Apr 10, 2026
@tarang-jain tarang-jain added improvement Improves an existing functionality non-breaking Introduces a non-breaking change cpp labels Apr 10, 2026
@tarang-jain tarang-jain marked this pull request as ready for review April 14, 2026 01:10
@tarang-jain tarang-jain requested review from a team as code owners April 14, 2026 01:10
Comment thread c/include/cuvs/cluster/kmeans.h
Contributor

@viclafargue viclafargue left a comment


Thanks! Here are some comments.


auto minClusterAndDistance = raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(
handle, streaming_batch_size);
auto L2NormBatch = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
Contributor


params.streaming_batch_size = 0 by default in the data-on-device case, but nothing prevents a user from setting a value. This would allocate an L2NormBatch smaller than n_samples, which would cause OOB writes (and later reads) during norm computation.

We should probably guard this with a check :
RAFT_EXPECTS(streaming_batch_size == n_samples || !data_on_device, ...)

Contributor Author

@tarang-jain tarang-jain Apr 21, 2026


I have updated this so that for device arrays, we simply ignore the streaming_batch_size and use the entire dataset always.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
Comment on lines +661 to +663
auto init_sample =
raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
Contributor


Suggested change
auto init_sample =
raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
if (init_sample_size == n_samples && data_on_device) {
  auto init_sample_const = raft::make_device_matrix_view<const DataT, IndexT>(
    X.data_handle(), n_samples, n_features);
  // pass directly to kmeansPlusPlus / initScalableKMeansPlusPlus
} else {
  auto init_sample =
    raft::make_device_matrix<DataT, IndexT>(handle, init_sample_size, n_features);
  raft::matrix::sample_rows(handle, random_state, X, init_sample.view());
  // pass init_sample to kmeansPlusPlus / initScalableKMeansPlusPlus
}

If init_size = 0 in the data on device path, we basically double memory use by copying the dataset over. Let's skip this by creating a view on the dataset.

Contributor Author


I completely skipped the sampling for the device path. That is how it was being done earlier. The init size is only used if the data is on host.

Comment on lines +731 to +732
auto batch_workspace = rmm::device_uvector<char>(
current_batch_sz, stream, raft::resource::get_workspace_resource(handle));
Contributor


Every call to process_batch allocates both this workspace and the device scalar below. Both buffers could be allocated outside the process_batch function.

Contributor Author


Moved the workspace buffer allocation outside the process_batch function.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
raft::matrix::sample_rows(handle, random_state, X, centroidsRawData);
} else if (iter_params.init == cuvs::cluster::kmeans::params::InitMethod::KMeansPlusPlus) {
IndexT default_init_size =
data_on_device ? n_samples : std::min(static_cast<IndexT>(3 * n_clusters), n_samples);
Contributor


Unlikely to be an actual issue, but n_clusters could be cast before the multiplication to avoid any risk of integer overflow.

Contributor Author


I have gotten rid of the batching for the device path. So when a user sets a batch size for device mdspan, we just set it to n_samples and warn the user. We should definitely not be creating a new buffer just for the init sample if we can accommodate the entire input matrix on device already.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
Comment on lines +876 to +881
DataT curClusteringCost = DataT{0};
raft::copy(&curClusteringCost, clustering_cost.data_handle(), 1, stream);
raft::resource::sync_stream(handle, stream);

if (curClusteringCost == DataT{0}) {
RAFT_LOG_WARN("Zero clustering cost detected: all points coincide with their centroids.");
Contributor


Going from ASSERT to RAFT_LOG_WARN may indeed be useful for the spectral clustering case. However, removing the inertia_check option forces the sync at every iteration. Do we truly need to drop this option?

Member


I'm not sure we need the log and an assert might be better?

Early stopping (aka skipping iterations) is ultimately going to be the best way to extract perf here. Whether it's by explicitly computing inertia or just looking at the residuals of the centroids from the prior iteration.

Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Contributor Author

@tarang-jain tarang-jain Apr 20, 2026


Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Until the iteration has completed, the CPU should not start the next iteration. So all the operations on the GPU stream must complete to finish the iteration.

Contributor Author


Seems like inertia check / residuals could be done on gpu if we had to in order to avoid syncing so we would only need to sync in the final iteration, right?

Yes, but this is throwing an error in the spectral clustering case wherein all the points converge on the centroids themselves. This is happening in one of the spectral tests and an assertion here is leading to an error, where instead it should simply return those centroids directly.

Contributor Author


I'm not sure we need the log and an assert might be better?

Therefore, I had to change it to a warning instead of an assertion. Earlier those spectral tests were skipping the inertia check which was avoiding the assertion.

Comment on lines +638 to +646
} else {
std::vector<DataT> h_weights(n_samples);
auto d_view = raft::make_device_vector_view<const DataT, IndexT>(weight_ptr, n_samples);
auto h_view = raft::make_host_vector_view<DataT, IndexT>(h_weights.data(), n_samples);
raft::copy(handle, h_view, d_view);
raft::resource::sync_stream(handle);
for (IndexT i = 0; i < n_samples; ++i) {
wt_sum += h_weights[i];
}
Contributor


In the data-on-device case it would be much faster to sum-reduce with cub::DeviceReduce::Sum or raft::linalg::reduce. The summation would also have better precision since it is done in a tree fashion (rounding error grows as O(log N) rather than O(N)).

Contributor Author


When it's device-accessible, I have changed that to a raft::linalg::mapThenSumReduce. I have also removed this function and directly updated checkWeight (renamed to weightSum). We do the scaling after the weight sum is computed.

double oversampling_factor,
int batch_samples,
int batch_centroids,
bool inertia_check,
Contributor


Maybe add a comment saying the field is present but deprecated.

Contributor Author


I don't think it's necessary to add a comment here in the .pxd. The C header already has that information, and this file will be updated along with the C headers / src files.

Comment thread cpp/src/cluster/detail/kmeans.cuh Outdated
Comment thread c/include/cuvs/cluster/kmeans.h Outdated
@copy-pr-bot

copy-pr-bot Bot commented Apr 28, 2026

This pull request requires additional validation before any workflows can run on NVIDIA's runners.

Pull request vetters can view their responsibilities here.

Contributors can view more details about this message here.

@tarang-jain
Contributor Author

/ok to test b1c034e

@tarang-jain tarang-jain requested a review from a team as a code owner April 29, 2026 17:17
@tarang-jain
Contributor Author

/ok to test 73293cf

Member

@dantegd dantegd left a comment


Just had a minor suggestion, not blocking

Comment on lines 115 to +123
/**
* If true, check inertia during iterations for early convergence.
* Number of samples to randomly draw for the KMeansPlusPlus initialization
* step. A random subset of this size is used for centroid seeding.
* When set to 0 the default depends on the data location:
* - Device data: n_samples (use the full dataset).
* - Host data: min(3 * n_clusters, n_samples).
* Default: 0.
*/
bool inertia_check = false;
int64_t init_size = 0;
Member


I softly agree with this rabbit comment; it would be worth adding a note in the PR description and code. C++ callers using the C++ API directly will fail to compile, not warn. That's a real source-API break worth at least a release-notes line.

@tarang-jain
Contributor Author

@dantegd I have updated the PR desc. Since this PR is marked as breaking, it will automatically be mentioned in CHANGELOG.md by the bot, right?

cluster_centers, impl->n_lists(), impl->dim());
if (impl->metric() == distance::DistanceType::CosineExpanded) {
raft::linalg::row_normalize<raft::linalg::L2Norm>(handle, centers_const_view, centers_view);
}
Contributor Author


This normalization was lost in #2001. So this PR adds it again.

rapids-bot Bot pushed a commit to rapidsai/cuml that referenced this pull request Apr 30, 2026
Depends on rapidsai/cuvs#2015. Inertia checking is being made mandatory and rapidsai/cuvs#2015 is a breaking change. This PR is needed to prevent compilation failures.

Authors:
  - Tarang Jain (https://github.com/tarang-jain)

Approvers:
  - Jim Crist-Harif (https://github.com/jcrist)
  - Anupam (https://github.com/aamijar)
  - Victor Lafargue (https://github.com/viclafargue)

URL: #8033
@tarang-jain
Contributor Author

/ok to test e28c200

Member

@dantegd dantegd left a comment


cmake review

Comment on lines +13 to +18
static void fill_matrix_tensor(DLManagedTensor* t,
void* data,
int64_t* shape,
DLDeviceType device_type,
uint8_t code,
uint8_t bits)
Contributor


The common usage for C tests seems to be #include "../../src/core/interop.hpp" and use cuvs::core::to_dlpack to simplify the interaction with DLManagedTensor. Can it be used here too, instead of fill_matrix_tensor and fill_vector_tensor?

Contributor


Merge run_kmeans_c and kmeans_c into the same file.
The total line count won't be so big once the deprecated paths are removed.

Comment on lines +644 to +645
raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features),
raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features));
Contributor


Suggested change
raft::make_device_vector_view(centroidsRawData.data_handle(), n_clusters * n_features),
raft::make_device_vector_view(centroids.data_handle(), n_clusters * n_features));
centroidsRawData.view(),
centroids.view());

Comment on lines +683 to +686
rmm::device_uvector<DataT> centroid_buf_A(centroid_buf_size, stream);
rmm::device_uvector<DataT> centroid_buf_B(centroid_buf_size, stream);
DataT* cur_centroids_ptr = centroid_buf_A.data();
DataT* new_centroids_ptr = centroid_buf_B.data();
Contributor


Keep the cur / new naming instead of A/B

auto centroid_norms_buf = raft::make_device_vector<DataT, IndexT>(handle, n_clusters);
auto clustering_cost = raft::make_device_scalar<DataT>(handle, DataT{0});

rmm::device_uvector<char> batch_workspace(streaming_batch_size, stream);
Contributor


Can the workspace be reused? (either the given workspace or a local workspace)


bool need_compute_norms = metric == cuvs::distance::DistanceType::L2Expanded ||
metric == cuvs::distance::DistanceType::L2SqrtExpanded;
bool use_norm_cache = need_compute_norms && !data_on_device;
Contributor


On the device path the norms will be recomputed max_iter * n_init times. It might be faster to cache the norms there as well (on device preferably, since the dataset already fits there).

auto minClusterAndDistance = raft::make_device_vector<raft::KeyValuePair<IndexT, DataT>, IndexT>(
handle, streaming_batch_size);
auto L2NormBatch = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
auto batch_weights_buf = raft::make_device_vector<DataT, IndexT>(handle, streaming_batch_size);
Contributor


Currently the weights are recomputed every iteration.
For the device path this vector can contain pre-normalized weights, and the weight normalization can be skipped.

if (weight_ptr != nullptr) {
weight_batches.emplace(weight_ptr, n_samples, 1, streaming_batch_size, stream);
} else {
raft::matrix::fill(handle, batch_weights_buf.view(), DataT{1});
Contributor


Fill with 1/n_samples and skip weight normalization


for (n_current_iter = 1; n_current_iter <= iter_params.max_iter; ++n_current_iter) {
if (n_current_iter > 1) {
RAFT_CUDA_TRY(cudaEventSynchronize(convergence_event));
Contributor


I don't think there would be any difference between using cudaEventSynchronize and a sync_stream at the end of the loop? The only operation that happens in between is the loop increment?

Comment on lines +937 to +941
cuvs::cluster::kmeans::cluster_cost(handle,
batch_data_view,
centroids_const,
raft::make_host_scalar_view(&batch_cost),
batch_sw);
Contributor


The cluster_cost could be modified to accept an optional X_norm to avoid recomputing it. That function also synchronizes each time it is called, while it would be better to accumulate cluster_cost on device and transfer+sync once at the end.
Maybe it's not in the scope of this PR but wanted to point it out.


Labels

breaking Introduces a breaking change cpp improvement Improves an existing functionality


5 participants