diff --git a/algorithms/linfa-clustering/benches/k_means.rs b/algorithms/linfa-clustering/benches/k_means.rs index b72982917..8997d6c2f 100644 --- a/algorithms/linfa-clustering/benches/k_means.rs +++ b/algorithms/linfa-clustering/benches/k_means.rs @@ -5,7 +5,7 @@ use criterion::{ use linfa::benchmarks::config; use linfa::prelude::*; use linfa::DatasetBase; -use linfa_clustering::{IncrKMeansError, KMeans, KMeansInit}; +use linfa_clustering::{IncrKMeansError, KMeans, KMeansAlgorithm, KMeansInit}; use linfa_datasets::generate; use ndarray::Array2; use ndarray_rand::RandomExt; @@ -36,33 +36,41 @@ impl Drop for Stats { fn k_means_bench(c: &mut Criterion) { let mut rng = Xoshiro256Plus::seed_from_u64(40); let cluster_sizes = [(100, 4), (400, 10), (3000, 10)]; + let algorithms = [KMeansAlgorithm::Lloyd, KMeansAlgorithm::Hamerly]; let n_features = 3; - let mut benchmark = c.benchmark_group("naive_k_means"); + let mut benchmark = c.benchmark_group("k_means"); config::set_default_benchmark_configs(&mut benchmark); benchmark.plot_config(PlotConfiguration::default().summary_scale(AxisScale::Logarithmic)); - for &(cluster_size, n_clusters) in &cluster_sizes { - let rng = &mut rng; - let centroids = - Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng); - let dataset = DatasetBase::from(generate::blobs(cluster_size, ¢roids, rng)); - let mut stats = Stats::default(); + for &algorithm in &algorithms { + for &(cluster_size, n_clusters) in &cluster_sizes { + let rng = &mut rng; + let centroids = + Array2::random_using((n_clusters, n_features), Uniform::new(-30., 30.), rng); + let dataset = DatasetBase::from(generate::blobs(cluster_size, ¢roids, rng)); + let mut stats = Stats::default(); - benchmark.bench_function( - BenchmarkId::new("naive_k_means", format!("{n_clusters}x{cluster_size}")), - |bencher| { - bencher.iter(|| { - let m = KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone())) - .init_method(KMeansInit::KMeansPlusPlus) - .max_n_iterations(black_box(1000)) - .tolerance(black_box(1e-3)) - .fit(&dataset) - .unwrap(); - stats.add(m.inertia()); - }); - }, - ); + benchmark.bench_function( + BenchmarkId::new( + "k_means", + format!("{algorithm:?}:{n_clusters}x{cluster_size}"), + ), + |bencher| { + bencher.iter(|| { + let m = + KMeans::params_with_rng(black_box(n_clusters), black_box(rng.clone())) + .init_method(KMeansInit::KMeansPlusPlus) + .algorithm(algorithm) + .max_n_iterations(black_box(1000)) + .tolerance(black_box(1e-3)) + .fit(&dataset) + .unwrap(); + stats.add(m.inertia()); + }); + }, + ); + } } benchmark.finish(); diff --git a/algorithms/linfa-clustering/src/k_means/algorithm.rs b/algorithms/linfa-clustering/src/k_means/algorithm.rs index b537af64f..57e65bac0 100644 --- a/algorithms/linfa-clustering/src/k_means/algorithm.rs +++ b/algorithms/linfa-clustering/src/k_means/algorithm.rs @@ -2,11 +2,11 @@ use std::cmp::Ordering; use std::fmt::Debug; use crate::k_means::{KMeansParams, KMeansValidParams}; -use crate::IncrKMeansError; use crate::{k_means::errors::KMeansError, KMeansInit}; +use crate::{IncrKMeansError, KMeansAlgorithm, KMeansParamsError}; use linfa::{prelude::*, DatasetBase, Float}; use linfa_nn::distance::{Distance, L2Dist}; -use ndarray::{Array1, Array2, ArrayBase, Axis, Data, DataMut, Ix1, Ix2, Zip}; +use ndarray::{Array1, Array2, ArrayBase, ArrayView2, Axis, Data, DataMut, Ix1, Ix2, Zip}; use ndarray_rand::rand::{Rng, SeedableRng}; use rand_xoshiro::Xoshiro256Plus; @@ -29,12 +29,15 @@ use serde_crate::{Deserialize, Serialize}; /// /// We 
provide a modified version of the _standard algorithm_ (also known as Lloyd's Algorithm),
 /// called m_k-means, which uses a slightly modified update step to avoid problems with empty
-/// clusters. We also provide an incremental version of the algorithm that runs on smaller batches
-/// of input data.
+/// clusters. In addition to Lloyd's algorithm, we also provide Hamerly's accelerated algorithm,
+/// which produces the same results but skips many distance computations using the triangle
+/// inequality. We also provide an incremental version of the algorithm that runs on smaller
+/// batches of input data.
 ///
 /// More details on the algorithm can be found in the next section or
 /// [here](https://en.wikipedia.org/wiki/K-means_clustering). Details on m_k-means can be found
 /// [here](https://www.researchgate.net/publication/228414762_A_Modified_k-means_Algorithm_to_Avoid_Empty_Clusters).
+/// Details on Hamerly's algorithm can be found [here](https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf).
 ///
 /// ## Standard algorithm
 ///
@@ -54,6 +57,27 @@ use serde_crate::{Deserialize, Serialize};
 /// euclidean distance between the old and the new clusters is below `tolerance` or
 /// we exceed the `max_n_iterations`).
 ///
+/// ## Hamerly's algorithm
+///
+/// Hamerly's algorithm is an exact accelerated variant of Lloyd's algorithm: given the same
+/// initial centroids it converges to the same final centroids, but usually in a fraction of the
+/// distance computations. For every observation it maintains an upper bound on the distance to
+/// its currently assigned centroid and a lower bound on the distance to the closest
+/// non-assigned centroid. At each iteration, these bounds together with the inter-centroid
+/// distances are used to cheaply prove that an observation cannot have changed cluster, in
+/// which case the exact distance is not recomputed at all.
+///
+/// Hamerly is typically faster than Lloyd when clusters are reasonably well separated and the
+/// number of clusters is moderate; when clusters overlap heavily or `n_clusters` is very large,
+/// the bookkeeping overhead can outweigh the savings. Hamerly requires a true metric for its
+/// triangle-inequality bounds to hold, so any custom distance function used with it must satisfy
+/// the metric axioms (`L2Dist`, `L1Dist` and `LInfDist` all qualify).
+///
+/// The algorithm variant is selected on [`KMeansParams`](crate::KMeansParams) via
+/// [`algorithm`](crate::KMeansParams::algorithm) with a [`KMeansAlgorithm`](crate::KMeansAlgorithm)
+/// value. Lloyd is the default; pass `KMeansAlgorithm::Hamerly` to opt in. Hamerly only applies
+/// to standard batch `fit`: the incremental `fit_with` path supports only Lloyd and rejects a
+/// `Hamerly` configuration with an error.
+///
 /// ## Incremental Algorithm
 ///
 /// In addition to the standard algorithm, we also provide an incremental version of K-means known
@@ -216,20 +240,61 @@ impl<F: Float, D: Distance<F>> KMeans<F, D> {
     }
 }
 
-impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
-    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
-{
-    type Object = KMeans<F, D>;
-
-    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
-    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
+impl<F: Float, R: Rng + Clone, D: Distance<F>> KMeansValidParams<F, R, D> {
+    /// Fit KMeans using Hamerly's accelerated algorithm.
     ///
-    /// An instance of `KMeans` is returned.
-    ///
-    fn fit(
+    /// Uses triangle inequality to skip unnecessary distance computations.
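+    ///
+    /// A minimal opt-in sketch from the caller's side (hypothetical data; the builder
+    /// calls mirror the public params API):
+    ///
+    /// ```ignore
+    /// use linfa::traits::Fit;
+    /// use linfa::DatasetBase;
+    /// use linfa_clustering::{KMeans, KMeansAlgorithm};
+    /// use ndarray::array;
+    ///
+    /// let data = DatasetBase::from(array![[0.0, 0.0], [0.1, 0.1], [10.0, 10.0]]);
+    /// let model = KMeans::params(2)
+    ///     .algorithm(KMeansAlgorithm::Hamerly) // same fixpoint as Lloyd, fewer distance evals
+    ///     .fit(&data)
+    ///     .expect("KMeans fitted");
+    /// assert_eq!(model.centroids().nrows(), 2);
+    /// ```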
+    /// Reference: <https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf>
+    fn fit_hamerly<DA: Data<Elem = F>, T>(
         &self,
         dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
-    ) -> Result<Self::Object, KMeansError> {
+    ) -> Result<KMeans<F, D>, KMeansError> {
+        let mut rng = self.rng().clone();
+        let observations = dataset.records().view();
+        let mut min_inertia = F::infinity();
+        let mut best_centroids = None;
+        let mut best_memberships = None;
+
+        for _ in 0..self.n_runs() {
+            let centroids =
+                self.init_method()
+                    .run(self.dist_fn(), self.n_clusters(), observations, &mut rng);
+            let mut hamerly = HamerlyAlgorithm::new(self.dist_fn(), observations, centroids);
+
+            let mut n_iter = 0;
+            let inertia = loop {
+                // No need to reassign observations on first iteration
+                if n_iter > 0 {
+                    hamerly.reassign_observations();
+                }
+                n_iter += 1;
+
+                let update = hamerly.recompute_centroids();
+
+                if update.convergence_dist < self.tolerance() || n_iter == self.max_n_iterations() {
+                    break hamerly.inertia();
+                }
+
+                hamerly.update_bounds(&update.distances_moved);
+            };
+
+            if inertia < min_inertia {
+                min_inertia = inertia;
+                let (centroids, memberships) = hamerly.into_parts();
+                best_centroids = Some(centroids);
+                best_memberships = Some(memberships);
+            }
+        }
+
+        let memberships = best_memberships.unwrap_or_else(|| Array1::zeros(dataset.nsamples()));
+        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
+    }
+
+    /// Fit KMeans with Lloyd's algorithm.
+    fn fit_lloyd<DA: Data<Elem = F>, T>(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+    ) -> Result<KMeans<F, D>, KMeansError> {
         let mut rng = self.rng().clone();
         let observations = dataset.records().view();
         let n_samples = dataset.nsamples();
@@ -274,6 +339,16 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
             }
         }
 
+        self.get_kmeans_result(dataset, min_inertia, best_centroids, memberships)
+    }
+
+    fn get_kmeans_result<DA: Data<Elem = F>, T>(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+        min_inertia: F,
+        best_centroids: Option<Array2<F>>,
+        memberships: Array1<usize>,
+    ) -> Result<KMeans<F, D>, KMeansError> {
         match best_centroids {
             Some(centroids) => {
                 let mut cluster_count = Array1::zeros(self.n_clusters());
@@ -292,6 +367,250 @@ impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
     }
 }
 
+impl<F: Float, R: Rng + Clone, DA: Data<Elem = F>, T, D: Distance<F>>
+    Fit<ArrayBase<DA, Ix2>, T, KMeansError> for KMeansValidParams<F, R, D>
+{
+    type Object = KMeans<F, D>;
+
+    /// Given an input matrix `observations`, with shape `(n_observations, n_features)`,
+    /// `fit` identifies `n_clusters` centroids based on the training data distribution.
+    ///
+    /// An instance of `KMeans` is returned.
+    fn fit(
+        &self,
+        dataset: &DatasetBase<ArrayBase<DA, Ix2>, T>,
+    ) -> Result<Self::Object, KMeansError> {
+        match self.algorithm() {
+            KMeansAlgorithm::Lloyd => self.fit_lloyd(dataset),
+            KMeansAlgorithm::Hamerly => self.fit_hamerly(dataset),
+        }
+    }
+}
+
+struct CentroidUpdate<F> {
+    distances_moved: Array1<F>,
+    convergence_dist: F,
+}
+
+/// Encapsulates all state and logic for a single Hamerly K-means run.
+struct HamerlyAlgorithm<'a, F: Float, D: Distance<F>> {
+    /// Distance metric used for all point-to-centroid comparisons.
+    dist_fn: &'a D,
+    /// Input data matrix, shape `(n_observations, n_features)`.
+    observations: ArrayView2<'a, F>,
+    /// Current centroid positions, shape `(n_clusters, n_features)`.
+    centroids: Array2<F>,
+    /// Cluster index assigned to each observation.
+    memberships: Array1<usize>,
+    /// Per-observation upper bound on the distance to its assigned centroid.
+    upper_bounds: Array1<F>,
+    /// Per-observation lower bound on the distance to the nearest non-assigned centroid.
+    lower_bounds: Array1<F>,
+    /// Number of observations currently assigned to each centroid.
+    centroid_counts: Array1<usize>,
+    /// Running coordinate sum of observations per centroid, shape `(n_clusters, n_features)`.
+    centroid_sums: Array2<F>,
+    /// Memberships before reassignment
+    prev_memberships: Array1<usize>,
+}
+
+impl<'a, F: Float, D: Distance<F>> HamerlyAlgorithm<'a, F, D> {
+    fn new(dist_fn: &'a D, observations: ArrayView2<'a, F>, centroids: Array2<F>) -> Self {
+        let n_observations = observations.nrows();
+        let mut memberships = Array1::zeros(n_observations);
+        let mut upper_bounds = Array1::zeros(n_observations);
+        let mut lower_bounds = Array1::zeros(n_observations);
+
+        Zip::from(observations.rows())
+            .and(&mut memberships)
+            .and(&mut upper_bounds)
+            .and(&mut lower_bounds)
+            .par_for_each(|obs, membership, upper, lower| {
+                let (idx, closest_dist, second_dist) =
+                    two_closest_centroids(dist_fn, &centroids, &obs);
+                *membership = idx;
+                *upper = closest_dist;
+                *lower = second_dist;
+            });
+
+        let mut centroid_counts: Array1<usize> = Array1::zeros(centroids.nrows());
+        let mut centroid_sums = Array2::zeros(centroids.dim());
+        for (obs, &m) in observations.rows().into_iter().zip(memberships.iter()) {
+            centroid_counts[m] += 1;
+            let mut row = centroid_sums.row_mut(m);
+            row += &obs;
+        }
+
+        let prev_memberships = Array1::zeros(n_observations);
+
+        Self {
+            dist_fn,
+            observations,
+            centroids,
+            memberships,
+            upper_bounds,
+            lower_bounds,
+            centroid_counts,
+            centroid_sums,
+            prev_memberships,
+        }
+    }
+
+    fn nearest_inter_centroid_distances(&self) -> Array1<F> {
+        let mut dists = Array1::zeros(self.centroids.nrows());
+        for (i, centroid) in self.centroids.rows().into_iter().enumerate() {
+            let (_, _, second_dist) =
+                two_closest_centroids(self.dist_fn, &self.centroids, &centroid);
+            dists[i] = second_dist;
+        }
+        dists
+    }
+
+    fn reassign_observations(&mut self) {
+        let nearest_center_dists = self.nearest_inter_centroid_distances();
+        let centroids = &self.centroids;
+        let observations = self.observations;
+        let dist_fn = self.dist_fn;
+
+        Zip::from(observations.rows())
+            .and(&mut self.memberships)
+            .and(&mut self.upper_bounds)
+            .and(&mut self.lower_bounds)
+            .and(&mut self.prev_memberships)
+            .par_for_each(|obs, membership, upper, lower, prev_slot| {
+                let current = *membership;
+                *prev_slot = current;
+                let threshold = F::max(nearest_center_dists[current] / F::cast(2), *lower);
+
+                if *upper > threshold {
+                    *upper = dist_fn.distance(obs.view(), centroids.row(current).view());
+
+                    if *upper > threshold {
+                        let (idx, closest_dist, second_dist) =
+                            two_closest_centroids(dist_fn, centroids, &obs);
+                        *membership = idx;
+                        *upper = closest_dist;
+                        *lower = second_dist;
+                    }
+                }
+            });
+
+        for (i, (&old_membership, &new_membership)) in self
+            .prev_memberships
+            .iter()
+            .zip(self.memberships.iter())
+            .enumerate()
+        {
+            if old_membership != new_membership {
+                let observation = self.observations.row(i);
+                self.centroid_counts[old_membership] -= 1;
+                self.centroid_counts[new_membership] += 1;
+                let mut old_centroid_sum = self.centroid_sums.row_mut(old_membership);
+                old_centroid_sum -= &observation;
+                let mut new_centroid_sum = self.centroid_sums.row_mut(new_membership);
+                new_centroid_sum += &observation;
+            }
+        }
+    }
+
+    /// Recomputes centroids from accumulated centroid sums and counts.
+    fn recompute_centroids(&mut self) -> CentroidUpdate<F> {
+        // m_k-means trick: the old centroid is treated as an extra point in each
+        // cluster, as in the Lloyd implementation
+        let mut new_centroids = &self.centroid_sums + &self.centroids;
+        Zip::from(new_centroids.rows_mut())
+            .and(&self.centroid_counts)
+            .for_each(|mut centroid_sum, &n_members| {
+                // + 1 because we have added old centroid as an extra point
+                centroid_sum /= F::cast(n_members + 1);
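+                // An empty cluster (n_members == 0) therefore divides the old
+                // centroid by 1, keeping it in place rather than producing NaN.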
+            });
+
+        let mut distances_moved = Array1::zeros(self.centroids.nrows());
+        Zip::from(&mut distances_moved)
+            .and(self.centroids.rows())
+            .and(new_centroids.rows())
+            .for_each(|d, old, new| *d = self.dist_fn.distance(old, new));
+
+        let convergence_dist = self
+            .dist_fn
+            .distance(self.centroids.view(), new_centroids.view());
+        self.centroids = new_centroids;
+
+        CentroidUpdate {
+            distances_moved,
+            convergence_dist,
+        }
+    }
+
+    fn update_bounds(&mut self, distances_moved: &Array1<F>) {
+        let (farthest_moved_idx, second_farthest_moved_idx) = two_farthest_indices(distances_moved);
+        Zip::from(&self.memberships)
+            .and(&mut self.upper_bounds)
+            .and(&mut self.lower_bounds)
+            .par_for_each(|&centroid_idx, upper, lower| {
+                *upper += distances_moved[centroid_idx];
+                if centroid_idx == farthest_moved_idx {
+                    *lower -= distances_moved[second_farthest_moved_idx];
+                } else {
+                    *lower -= distances_moved[farthest_moved_idx];
+                }
+            });
+    }
+
+    fn inertia(&self) -> F {
+        compute_inertia(
+            self.dist_fn,
+            self.observations,
+            &self.memberships,
+            &self.centroids,
+        )
+    }
+
+    fn into_parts(self) -> (Array2<F>, Array1<usize>) {
+        (self.centroids, self.memberships)
+    }
+}
+
+/// Returns the indices of the two centroids that moved the farthest.
+///
+/// For fewer than two elements the second index duplicates the first; callers
+/// only read `second_farthest` when an observation's own centroid is the
+/// farthest mover, which cannot happen when there is only one centroid.
+fn two_farthest_indices<F: Float>(distances: &Array1<F>) -> (usize, usize) {
+    if distances.len() < 2 {
+        return (0, 0);
+    }
+    let (mut farthest, mut second_farthest) = if distances[1] >= distances[0] {
+        (1, 0)
+    } else {
+        (0, 1)
+    };
+    for i in 2..distances.len() {
+        if distances[i] >= distances[farthest] {
+            second_farthest = farthest;
+            farthest = i;
+        } else if distances[i] > distances[second_farthest] {
+            second_farthest = i;
+        }
+    }
+    (farthest, second_farthest)
+}
+
+/// Computes total inertia: the sum of reduced distances (`rdistance`, e.g. the
+/// squared distance under `L2Dist`) from each observation to its assigned centroid.
+fn compute_inertia<F: Float, D: Distance<F>>(
+    dist_fn: &D,
+    observations: ArrayView2<F>,
+    memberships: &Array1<usize>,
+    centroids: &Array2<F>,
+) -> F {
+    observations
+        .rows()
+        .into_iter()
+        .zip(memberships.iter())
+        .map(|(obs, &m)| dist_fn.rdistance(obs.view(), centroids.row(m).view()))
+        .fold(F::zero(), |acc, d| acc + d)
+}
+
 impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distance<F> + Debug>
     FitWith<'a, ArrayBase<DA, Ix2>, T, IncrKMeansError<KMeans<F, D>>> for KMeansValidParams<F, R, D>
@@ -306,11 +625,23 @@ impl<'a, F: Float + Debug, R: Rng + Clone, DA: Data<Elem = F>, T, D: 'a + Distan
     /// `None`, then it's initialized using the specified initialization algorithm. The return
     /// value consists of the updated model and a `bool` value that indicates whether the algorithm
     /// has converged.
+    ///
+    /// Only [`KMeansAlgorithm::Lloyd`](crate::KMeansAlgorithm::Lloyd) is supported here: the
+    /// Mini-Batch path always uses Lloyd's update. Configuring
+    /// [`KMeansAlgorithm::Hamerly`](crate::KMeansAlgorithm::Hamerly) and then calling
+    /// `fit_with` returns [`KMeansParamsError::IncrementalHamerly`], because Hamerly's
+    /// per-observation bounds rely on a persistent dataset across iterations and cannot
+    /// amortise across independent Mini-Batch batches.
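+    ///
+    /// Sketch of the failure mode, mirroring the `fit_with_rejects_hamerly` test
+    /// (`params` and `dataset` assumed set up as there):
+    ///
+    /// ```ignore
+    /// let params = params.algorithm(KMeansAlgorithm::Hamerly);
+    /// assert!(matches!(
+    ///     params.fit_with(None, &dataset),
+    ///     Err(IncrKMeansError::InvalidParams(KMeansParamsError::IncrementalHamerly)),
+    /// ));
+    /// ```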
     fn fit_with(
         &self,
         model: Self::ObjectIn,
         dataset: &'a DatasetBase<ArrayBase<DA, Ix2>, T>,
     ) -> Result<Self::ObjectOut, IncrKMeansError<KMeans<F, D>>> {
+        if *self.algorithm() == KMeansAlgorithm::Hamerly {
+            return Err(IncrKMeansError::InvalidParams(
+                KMeansParamsError::IncrementalHamerly,
+            ));
+        }
         let observations = dataset.records().view();
         let n_samples = dataset.nsamples();
 
@@ -531,7 +862,7 @@ pub(crate) fn update_min_dists<F: Float, D: Distance<F>>(
     });
 }
 
-// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
+/// Efficient combination of `update_cluster_memberships` and `update_min_dists`.
 pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
     dist_fn: &D,
     centroids: &ArrayBase<impl Data<Elem = F> + Sync, Ix2>,
@@ -549,6 +880,44 @@ pub(crate) fn update_memberships_and_dists<F: Float, D: Distance<F>>(
     });
 }
 
+/// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
+/// return the index of the closest centroid (the index of the corresponding row in `centroids`)
+/// together with the distances to the closest and second-closest centroids.
+///
+/// Uses `distance` (not `rdistance`) because Hamerly's triangle-inequality bounds
+/// only hold under a true metric; do not "optimize" this to squared distance.
+fn two_closest_centroids<F: Float, D: Distance<F>>(
+    dist_fn: &D,
+    // (n_centroids, n_features)
+    centroids: &ArrayBase<impl Data<Elem = F>, Ix2>,
+    // (n_features)
+    observation: &ArrayBase<impl Data<Elem = F>, Ix1>,
+) -> (usize, F, F) {
+    if centroids.nrows() == 1 {
+        return (0, F::cast(0), F::cast(0));
+    }
+    let first_centroid = centroids.row(0);
+    let second_centroid = centroids.row(1);
+    let dist1 = dist_fn.distance(observation.view(), first_centroid.view());
+    let dist2 = dist_fn.distance(observation.view(), second_centroid.view());
+
+    let mut closest_index = if dist1 < dist2 { 0 } else { 1 };
+    let mut closest_distance = if dist1 < dist2 { dist1 } else { dist2 };
+    let mut second_closest_distance = if dist1 < dist2 { dist2 } else { dist1 };
+
+    for (centroid_index, centroid) in centroids.rows().into_iter().skip(2).enumerate() {
+        let distance = dist_fn.distance(observation.view(), centroid.view());
+        if closest_distance <= distance && distance < second_closest_distance {
+            second_closest_distance = distance;
+        } else if distance < closest_distance {
+            second_closest_distance = closest_distance;
+            closest_index = centroid_index + 2; // We skipped 2 centroids
+            closest_distance = distance;
+        }
+    }
+    (closest_index, closest_distance, second_closest_distance)
+}
+
 /// Given a matrix of centroids with shape (n_centroids, n_features) and an observation,
 /// return the index of the closest centroid (the index of the corresponding row in `centroids`).
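+/// Unlike `two_closest_centroids` above, only an argmin is needed here, so the
+/// cheaper order-preserving `rdistance` would yield the same index; see the
+/// metric note above before changing either helper.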
pub(crate) fn closest_centroid>( @@ -593,6 +962,7 @@ mod tests { fn autotraits() { fn has_autotraits() {} has_autotraits::>(); + has_autotraits::(); has_autotraits::(); has_autotraits::(); has_autotraits::>(); @@ -831,6 +1201,22 @@ mod tests { ); } + #[test] + fn fit_with_rejects_hamerly() { + let rng = Xoshiro256Plus::seed_from_u64(45); + let params = KMeans::params_with_rng(2, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(array![[0., 0.], [10., 10.]])); + let data = DatasetBase::from(array![[1., 1.], [11., 11.]]); + let err = params + .fit_with(None, &data) + .expect_err("Hamerly + fit_with must be rejected"); + assert!(matches!( + err, + IncrKMeansError::InvalidParams(KMeansParamsError::IncrementalHamerly) + )); + } + #[test] fn test_tolerance() { let rng = Xoshiro256Plus::seed_from_u64(45); @@ -861,6 +1247,517 @@ mod tests { .expect("KMeans fitted"); } + fn sort_centroids(c: &Array2) -> Array2 { + let mut rows: Vec> = c.rows().into_iter().map(|r| r.to_vec()).collect(); + rows.sort_by(|a, b| { + for (x, y) in a.iter().zip(b.iter()) { + match x.partial_cmp(y) { + Some(std::cmp::Ordering::Equal) => continue, + Some(ord) => return ord, + None => continue, + } + } + std::cmp::Ordering::Equal + }); + let flat: Vec = rows.into_iter().flatten().collect(); + Array2::from_shape_vec((c.nrows(), c.ncols()), flat).unwrap() + } + + fn hamerly_lloyd_equivalence>(dist_fn: D, init: KMeansInit) { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(6, rng.clone(), dist_fn.clone()) + .n_runs(3) + .algorithm(KMeansAlgorithm::Lloyd) + .init_method(init.clone()) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(6, rng.clone(), dist_fn) + .n_runs(3) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init) + .fit(&dataset) + .expect("Hamerly fitted"); + + assert_eq!(model_lloyd.centroids().nrows(), 6); + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-4 + ); + assert_abs_diff_eq!( + sort_centroids(model_lloyd.centroids()), + sort_centroids(model_hamerly.centroids()), + epsilon = 1e-4 + ); + } + + #[test] + fn hamerly_lloyd_equivalence_random_l2() { + hamerly_lloyd_equivalence(L2Dist, KMeansInit::Random); + } + + #[test] + fn hamerly_lloyd_equivalence_plusplus_l2() { + hamerly_lloyd_equivalence(L2Dist, KMeansInit::KMeansPlusPlus); + } + + fn hamerly_lloyd_equivalence_para>(dist_fn: D) { + // KMeansPara uses Rayon parallelism and is non-deterministic across concurrent test + // runs. Pre-compute centroids deterministically and pass them as Precomputed so + // both Lloyd and Hamerly start from the same initial centroids. 
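+        // (With Precomputed centroids the two fits below no longer consume the RNG
+        // for initialization, so the comparison is deterministic.)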
+ let mut rng = Xoshiro256Plus::seed_from_u64(99); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + let init = KMeansInit::Precomputed(KMeansInit::KMeansPlusPlus.run( + &dist_fn, + 6, + dataset.records().view(), + &mut rng, + )); + hamerly_lloyd_equivalence(dist_fn, init); + } + + #[test] + fn hamerly_lloyd_equivalence_para_l2() { + hamerly_lloyd_equivalence_para(L2Dist); + } + + #[test] + fn hamerly_lloyd_equivalence_random_l1() { + hamerly_lloyd_equivalence(L1Dist, KMeansInit::Random); + } + + #[test] + fn hamerly_lloyd_equivalence_plusplus_l1() { + hamerly_lloyd_equivalence(L1Dist, KMeansInit::KMeansPlusPlus); + } + + #[test] + fn hamerly_lloyd_equivalence_para_l1() { + hamerly_lloyd_equivalence_para(L1Dist); + } + + #[test] + fn test_two_closest_centroids_l2() { + let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10); + assert_abs_diff_eq!(second, f64::sqrt(82.0), epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_l1() { + let centroids = array![[0.0, 0.0], [10.0, 0.0], [0.0, 10.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L1Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, 2.0, epsilon = 1e-10); + assert_abs_diff_eq!(second, 10.0, epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_single() { + let centroids = array![[5.0, 5.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 0); + assert_abs_diff_eq!(closest, 0.0); + assert_abs_diff_eq!(second, 0.0); + } + + #[test] + fn test_two_closest_centroids_obs_is_centroid() { + let centroids = array![[0.0, 0.0], [3.0, 4.0], [10.0, 0.0]]; + let obs = array![3.0, 4.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + assert_eq!(idx, 1); + assert_abs_diff_eq!(closest, 0.0, epsilon = 1e-10); + assert_abs_diff_eq!(second, 5.0, epsilon = 1e-10); + } + + #[test] + fn test_two_closest_centroids_equidistant() { + let centroids = array![[2.0, 0.0], [0.0, 2.0]]; + let obs = array![1.0, 1.0]; + let (idx, closest, second) = two_closest_centroids(&L2Dist, ¢roids, &obs); + // When equidistant, index 1 is chosen because `if dist1 < dist2` is false + assert_eq!(idx, 1); + assert_abs_diff_eq!(closest, f64::sqrt(2.0), epsilon = 1e-10); + assert_abs_diff_eq!(second, f64::sqrt(2.0), epsilon = 1e-10); + } + + #[test] + fn test_two_farthest_indices() { + // Distinct values + assert_eq!(two_farthest_indices(&array![1.0, 5.0, 3.0, 2.0]), (1, 2)); + + // All equal: repeated >= swaps chain through all indices + assert_eq!(two_farthest_indices(&array![3.0, 3.0, 3.0]), (2, 1)); + + // Two elements + assert_eq!(two_farthest_indices(&array![2.0, 7.0]), (1, 0)); + assert_eq!(two_farthest_indices(&array![7.0, 2.0]), (0, 1)); + + // Largest at end + assert_eq!(two_farthest_indices(&array![8.0, 1.0, 2.0, 9.0]), (3, 0)); + + // Largest at start: second must be the actual runner-up + assert_eq!(two_farthest_indices(&array![9.0, 1.0, 2.0, 8.0]), (0, 3)); + + // Single element degenerates to (0, 0) + assert_eq!(two_farthest_indices(&array![1.0]), (0, 0)); + } + + #[test] + fn 
test_recompute_centroids() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [0.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + // m_k-means: new = (sums + old) / (counts + 1) = [8/4, 12/4], [15/3, 30/3] + hamerly.centroid_sums = array![[8.0, 12.0], [15.0, 30.0]]; + hamerly.centroid_counts = array![3_usize, 2]; + hamerly.recompute_centroids(); + assert_abs_diff_eq!( + hamerly.centroids, + array![[2.0, 3.0], [5.0, 10.0]], + epsilon = 1e-10 + ); + + // Empty cluster: (0 + old) / (0 + 1) = old, so the centroid is preserved. + let centroids2 = array![[7.0, 7.0], [0.0, 0.0]]; + let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + hamerly2.centroid_sums = array![[0.0, 0.0], [15.0, 30.0]]; + hamerly2.centroid_counts = array![0_usize, 2]; + hamerly2.recompute_centroids(); + assert_abs_diff_eq!( + hamerly2.centroids, + array![[7.0, 7.0], [5.0, 10.0]], + epsilon = 1e-10 + ); + } + + #[test] + fn test_recompute_centroids_distances_moved() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [10.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + // m_k-means: new = (sums + old) / (counts + 1) = [2/2, 0/2], [20/2, 6/2] + // → [1.0, 0.0] and [10.0, 3.0], moved 1.0 and 3.0 respectively + hamerly.centroid_sums = array![[2.0, 0.0], [10.0, 6.0]]; + hamerly.centroid_counts = array![1_usize, 1]; + let update = hamerly.recompute_centroids(); + assert_abs_diff_eq!(update.distances_moved, array![1.0, 3.0], epsilon = 1e-10); + + // No movement + let centroids2 = array![[5.0, 5.0], [10.0, 10.0]]; + let mut hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + hamerly2.centroid_sums = array![[5.0, 5.0], [10.0, 10.0]]; + hamerly2.centroid_counts = array![1_usize, 1]; + let update2 = hamerly2.recompute_centroids(); + assert_abs_diff_eq!(update2.distances_moved, array![0.0, 0.0], epsilon = 1e-10); + } + + #[test] + fn test_nearest_inter_centroid_distances() { + let obs = array![[0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [3.0, 0.0], [0.0, 4.0]]; + let hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + let dists = hamerly.nearest_inter_centroid_distances(); + assert_abs_diff_eq!(dists, array![3.0, 3.0, 4.0], epsilon = 1e-10); + + // Two centroids: symmetric + let centroids2 = array![[0.0, 0.0], [5.0, 0.0]]; + let hamerly2 = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids2); + let dists2 = hamerly2.nearest_inter_centroid_distances(); + assert_abs_diff_eq!(dists2, array![5.0, 5.0], epsilon = 1e-10); + } + + #[test] + fn test_hamerly_strategy_new() { + let obs = array![[0.0, 0.0], [1.0, 0.0], [10.0, 10.0]]; + let centroids = array![[0.0, 0.0], [10.0, 10.0]]; + let hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + assert_eq!(hamerly.memberships, array![0_usize, 0, 1]); + assert_eq!(hamerly.centroid_counts, array![2_usize, 1]); + assert_abs_diff_eq!( + hamerly.centroid_sums, + array![[1.0, 0.0], [10.0, 10.0]], + epsilon = 1e-10 + ); + } + + #[test] + fn test_update_bounds_oracle() { + let obs = array![[0.0, 0.0], [10.0, 0.0], [0.0, 0.0]]; + let centroids = array![[0.0, 0.0], [10.0, 0.0]]; + let mut hamerly = HamerlyAlgorithm::new(&L2Dist, obs.view(), centroids); + hamerly.memberships = array![0_usize, 1, 0]; + hamerly.upper_bounds = array![5.0, 3.0, 4.0]; + hamerly.lower_bounds = array![2.0, 1.0, 3.0]; + let distances_moved = array![1.0, 0.5]; + hamerly.update_bounds(&distances_moved); + assert_abs_diff_eq!(hamerly.upper_bounds, 
array![6.0, 3.5, 5.0], epsilon = 1e-10); + assert_abs_diff_eq!(hamerly.lower_bounds, array![1.5, 0.0, 2.5], epsilon = 1e-10); + } + + #[test] + fn test_compute_inertia() { + let obs = array![[0.0, 0.0], [3.0, 4.0]]; + let memberships = array![0_usize, 0]; + let centroids = array![[1.0, 1.0]]; + let inertia = compute_inertia(&L2Dist, obs.view(), &memberships, ¢roids); + // rdistance: (0-1)^2+(0-1)^2 + (3-1)^2+(4-1)^2 = 2 + 13 = 15 + assert_abs_diff_eq!(inertia, 15.0, epsilon = 1e-10); + } + + fn test_n_runs_hamerly>(dist_fn: D) { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + + for init in &[ + KMeansInit::Random, + KMeansInit::KMeansPlusPlus, + KMeansInit::KMeansPara, + ] { + let dataset = DatasetBase::from(data.clone()); + let model = KMeans::params_with(3, rng.clone(), dist_fn.clone()) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init.clone()) + .fit(&dataset) + .expect("KMeans fitted"); + let clusters = model.predict(dataset); + let inertia = calc_inertia!( + dist_fn, + model.centroids(), + clusters.records, + clusters.targets + ); + let total_dist = model.transform(&clusters.records.view()).sum(); + assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5); + + let single_cluster: usize = model.predict(&data.row(0)); + assert_abs_diff_eq!(single_cluster, clusters.targets[0]); + + let dataset2 = DatasetBase::from(clusters.records().clone()); + let model2 = KMeans::params_with(3, rng.clone(), dist_fn.clone()) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(init.clone()) + .fit(&dataset2) + .expect("KMeans fitted"); + let clusters2 = model2.predict(dataset2); + let inertia2 = calc_inertia!( + dist_fn, + model2.centroids(), + clusters2.records, + clusters2.targets + ); + let total_dist2 = model2.transform(&clusters2.records.view()).sum(); + assert_abs_diff_eq!(inertia2, total_dist2, epsilon = 1e-5); + + if *init == KMeansInit::Random { + assert!(inertia2 <= inertia); + } + } + } + + #[test] + fn test_n_runs_hamerly_l2dist() { + test_n_runs_hamerly(L2Dist); + } + + #[test] + fn test_n_runs_hamerly_l1dist() { + test_n_runs_hamerly(L1Dist); + } + + #[test] + fn test_hamerly_precomputed_centroids() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![ + [0.0, 0.0], + [1.0, 0.0], + [0.0, 1.0], + [10.0, 10.0], + [11.0, 10.0], + [10.0, 11.0] + ]; + let init_centroids = array![[0.0, 0.0], [10.0, 10.0]]; + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(2, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Lloyd) + .init_method(KMeansInit::Precomputed(init_centroids.clone())) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(2, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(init_centroids)) + .fit(&dataset) + .expect("Hamerly fitted"); + + assert_abs_diff_eq!( + model_lloyd.centroids(), + model_hamerly.centroids(), + epsilon = 1e-1 + ); + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-1 + ); + } + + #[test] + fn test_hamerly_single_cluster() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + 
.algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[4.0, 5.0]], epsilon = 1e-4); + } + + #[test] + fn test_hamerly_n_clusters_eq_n_samples() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[1.0, 2.0], [10.0, 20.0], [-5.0, -5.0], [100.0, 0.0]]; + let dataset = DatasetBase::from(data.clone()); + let model = KMeans::params_with_rng(4, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(data)) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_single_observation() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[3.0, 7.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[3.0, 7.0]], epsilon = 1e-10); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_identical_data() { + let rng = Xoshiro256Plus::seed_from_u64(42); + let data = array![[5.0, 5.0], [5.0, 5.0], [5.0, 5.0], [5.0, 5.0]]; + let dataset = DatasetBase::from(data); + let model = KMeans::params_with_rng(1, rng) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[5.0, 5.0]], epsilon = 1e-10); + assert_abs_diff_eq!(model.inertia(), 0.0, epsilon = 1e-10); + } + + #[test] + fn test_hamerly_high_dimensionality() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let data: Array2 = Array::random_using((200, 50), Uniform::new(-100., 100.), &mut rng); + let dataset = DatasetBase::from(data); + + let model_lloyd = KMeans::params_with(5, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Lloyd) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("Lloyd fitted"); + let model_hamerly = KMeans::params_with(5, rng.clone(), L2Dist) + .n_runs(1) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("Hamerly fitted"); + + assert_abs_diff_eq!( + model_lloyd.inertia(), + model_hamerly.inertia(), + epsilon = 1e-5 + ); + assert_abs_diff_eq!( + model_lloyd.centroids(), + model_hamerly.centroids(), + epsilon = 1e-5 + ); + } + + #[test] + fn test_hamerly_max_n_iterations() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + let _model = KMeans::params_with(6, rng.clone(), L2Dist) + .n_runs(1) + .max_n_iterations(5) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Random) + .fit(&dataset) + .expect("KMeans fitted"); + } + + #[test] + fn test_hamerly_tolerance() { + let rng = Xoshiro256Plus::seed_from_u64(45); + let data = DatasetBase::from(array![[1., 1.], [11., 11.]]); + let model = KMeans::params_with_rng(1, rng) + .tolerance(8.5) + .algorithm(KMeansAlgorithm::Hamerly) + .init_method(KMeansInit::Precomputed(array![[0., 0.]])) + .fit(&data) + .expect("KMeans fitted"); + assert_abs_diff_eq!(model.centroids(), &array![[4., 4.]], epsilon = 1e-1); + } + + #[test] + fn test_hamerly_predict_transform_consistency() { + let mut rng = Xoshiro256Plus::seed_from_u64(42); + let xt = Array::random_using(100, 
Uniform::new(0., 1.0), &mut rng).insert_axis(Axis(1)); + let yt = function_test_1d(&xt); + let data = concatenate(Axis(1), &[xt.view(), yt.view()]).unwrap(); + let dataset = DatasetBase::from(data); + + let model = KMeans::params_with(3, rng.clone(), L2Dist) + .algorithm(KMeansAlgorithm::Hamerly) + .fit(&dataset) + .expect("Hamerly fitted"); + + let clusters = model.predict(dataset); + assert!(clusters.targets.iter().all(|&c| c < 3)); + + let inertia = calc_inertia!( + L2Dist, + model.centroids(), + clusters.records, + clusters.targets + ); + let total_dist = model.transform(&clusters.records.view()).sum(); + assert_abs_diff_eq!(inertia, total_dist, epsilon = 1e-5); + } + fn fittable, (), KMeansError>>(_: T) {} #[test] fn thread_rng_fittable() { diff --git a/algorithms/linfa-clustering/src/k_means/errors.rs b/algorithms/linfa-clustering/src/k_means/errors.rs index bcc26b569..d675ea8c5 100644 --- a/algorithms/linfa-clustering/src/k_means/errors.rs +++ b/algorithms/linfa-clustering/src/k_means/errors.rs @@ -11,6 +11,11 @@ pub enum KMeansParamsError { Tolerance, #[error("max_n_iterations cannot be 0")] MaxIterations, + #[error( + "only KMeansAlgorithm::Lloyd is supported by fit_with (Mini-Batch K-means); \ + Hamerly requires a persistent dataset across iterations and cannot be used incrementally" + )] + IncrementalHamerly, } /// An error when modeling a KMeans algorithm diff --git a/algorithms/linfa-clustering/src/k_means/hyperparams.rs b/algorithms/linfa-clustering/src/k_means/hyperparams.rs index 52b0e2a93..84b8e0650 100644 --- a/algorithms/linfa-clustering/src/k_means/hyperparams.rs +++ b/algorithms/linfa-clustering/src/k_means/hyperparams.rs @@ -1,4 +1,4 @@ -use crate::KMeansParamsError; +use crate::{KMeansAlgorithm, KMeansParamsError}; use super::init::KMeansInit; use linfa::prelude::*; @@ -35,6 +35,8 @@ pub struct KMeansValidParams> { rng: R, /// Distance metric used in the centroid assignment step dist_fn: D, + /// Algorithm variant used for the assignment step + algorithm: KMeansAlgorithm, } #[derive(Clone, Debug, PartialEq)] @@ -75,6 +77,7 @@ impl> KMeansParams { init: KMeansInit::KMeansPlusPlus, rng, dist_fn, + algorithm: KMeansAlgorithm::Lloyd, }) } @@ -101,6 +104,17 @@ impl> KMeansParams { self.0.init = init; self } + + /// Select the variant used for the assignment step. + /// + /// See [`KMeansAlgorithm`] for the available variants and when to prefer each. + /// Defaults to [`KMeansAlgorithm::Lloyd`]. This setting only affects batch `fit`; + /// `fit_with` (Mini-Batch K-means) always uses Lloyd's update and will reject + /// `Hamerly` with [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly). + pub fn algorithm(mut self, algorithm: KMeansAlgorithm) -> Self { + self.0.algorithm = algorithm; + self + } } impl> ParamGuard for KMeansParams { @@ -166,6 +180,11 @@ impl> KMeansValidParams { pub fn dist_fn(&self) -> &D { &self.dist_fn } + + /// The [`KMeansAlgorithm`] variant used by batch `fit` for the assignment step. 
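+    ///
+    /// `KMeansAlgorithm` is `Copy`, so dereference to obtain an owned value, as the
+    /// `fit_with` guard does with `*self.algorithm()`.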
+    pub fn algorithm(&self) -> &KMeansAlgorithm {
+        &self.algorithm
+    }
 }
 
 #[cfg(test)]
diff --git a/algorithms/linfa-clustering/src/k_means/init.rs b/algorithms/linfa-clustering/src/k_means/init.rs
index 723bd1d8f..714496756 100644
--- a/algorithms/linfa-clustering/src/k_means/init.rs
+++ b/algorithms/linfa-clustering/src/k_means/init.rs
@@ -34,6 +34,50 @@ pub enum KMeansInit<F: Float> {
     KMeansPara,
 }
 
+#[cfg_attr(
+    feature = "serde",
+    derive(Serialize, Deserialize),
+    serde(crate = "serde_crate")
+)]
+#[derive(Clone, Copy, Debug, PartialEq, Eq)]
+#[non_exhaustive]
+/// Specifies the algorithm used for the KMeans assignment step.
+///
+/// Both variants minimise the same objective and, given identical initial centroids,
+/// converge to the same result. They only differ in how the assignment step is computed.
+/// Select a variant via [`KMeansParams::algorithm`](crate::KMeansParams::algorithm).
+///
+/// This setting only applies to batch `fit`. The incremental Mini-Batch K-means path
+/// (`fit_with`) always uses Lloyd's update, and configuring `Hamerly` alongside
+/// `fit_with` is rejected with
+/// [`KMeansParamsError::IncrementalHamerly`](crate::KMeansParamsError::IncrementalHamerly).
+pub enum KMeansAlgorithm {
+    /// Standard Lloyd's algorithm (also known as the "naive" algorithm).
+    ///
+    /// On every iteration, computes the distance from each observation to every centroid
+    /// to determine the closest one. Simple and predictable; work per iteration is
+    /// `O(n_observations * n_clusters * n_features)`.
+    ///
+    /// Default variant. Works with any [`Distance`](linfa_nn::distance::Distance).
+    Lloyd,
+    /// Hamerly's accelerated algorithm.
+    ///
+    /// Uses the triangle inequality together with per-observation upper/lower distance
+    /// bounds to skip most distance computations once the algorithm has stabilised.
+    /// Produces the same result as Lloyd's algorithm given the same initial centroids,
+    /// and is typically substantially faster for well-separated clusters with a moderate
+    /// number of centroids. For heavily overlapping clusters or very large `n_clusters`
+    /// the bookkeeping overhead can make Lloyd a better choice.
+    ///
+    /// Because the bounds rely on the triangle inequality, the supplied distance
+    /// function must be a true metric. `L2Dist`, `L1Dist` and `LInfDist` satisfy this.
+    ///
+    /// Only supported in batch `fit`; not available for Mini-Batch `fit_with`.
+    ///
+    /// Reference: <https://cs.baylor.edu/~hamerly/papers/sdm_2010.pdf>
+    Hamerly,
+}
+
 impl<F: Float> KMeansInit<F> {
     /// Runs the chosen initialization routine
     pub(crate) fn run<R: Rng, D: Distance<F>>(