diff --git a/algorithms/linfa-nn/src/distance.rs b/algorithms/linfa-nn/src/distance.rs index de4af9531..0272730c0 100644 --- a/algorithms/linfa-nn/src/distance.rs +++ b/algorithms/linfa-nn/src/distance.rs @@ -118,6 +118,27 @@ impl Distance for LpDist { } } +/// Wasserstein or [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance +#[cfg_attr( + feature = "serde", + derive(Serialize, Deserialize), + serde(crate = "serde_crate") +)] +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct EarthMoverDist; +impl Distance for EarthMoverDist { + #[inline] + fn distance(&self, a: ArrayView, b: ArrayView) -> F { + let mut cumulative_diff = F::zero(); + let mut emd = F::zero(); + Zip::from(&a).and(&b).for_each(|&a, &b| { + cumulative_diff += a - b; + emd += cumulative_diff.abs() + }); + emd + } +} + /// Computes a similarity matrix with gaussian kernel and scaling parameter `eps` /// /// The generated matrix is a upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations @@ -146,7 +167,7 @@ pub fn to_gaussian_similarity( #[cfg(test)] mod test { use approx::assert_abs_diff_eq; - use ndarray::arr1; + use ndarray::{arr1, arr2}; use super::*; @@ -157,6 +178,7 @@ mod test { has_autotraits::(); has_autotraits::(); has_autotraits::>(); + has_autotraits::(); } fn dist_test>(dist: D, result: f64) { @@ -204,4 +226,45 @@ mod test { fn lp_dist() { dist_test(LpDist(3.3), 4.635); } + + #[test] + fn emd_dist() { + dist_test(EarthMoverDist, 4.2); + + let dist = EarthMoverDist; + let a = arr1(&[0.5, 0.5]); + let b = arr1(&[0.3, 0.7]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.2, epsilon = 1e-5); + assert_abs_diff_eq!(dist.rdist_to_dist(dist.dist_to_rdist(ab)), ab); + + let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let b = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.0, epsilon = 1e-5); + + let a = arr1(&[0.3, 0.2, 0.1, 0.15, 0.25]); + let b = arr1(&[0.1, 0.2, 0.1, 0.15, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.8, epsilon = 1e-5); + + let a = arr1(&[0.3, 0.2, 0.15, 0.10, 0.25]); + let b = arr1(&[0.1, 0.2, 0.05, 0.20, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.9, epsilon = 1e-5); + + let a = arr1(&[0.35, 0.15, 0.15, 0.10, 0.25]); + let b = arr1(&[0.1, 0.20, 0.05, 0.20, 0.45]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.95, epsilon = 1e-5); + + let a = arr2(&[[0.3, 0.2, 0.15, 0.10, 0.25], [0.35, 0.15, 0.15, 0.10, 0.25]]); + let b = arr2(&[[0.1, 0.2, 0.05, 0.20, 0.45], [0.1, 0.20, 0.05, 0.20, 0.45]]); + let ab = dist.distance(a.view(), b.view()); + assert_abs_diff_eq!(ab, 0.9 + 0.95, epsilon = 1e-5); + + let a = arr1(&[f64::INFINITY, 6.6]); + let b = arr1(&[4.4, f64::NEG_INFINITY]); + assert!(dist.distance(a.view(), b.view()).is_infinite()); + } }