Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 64 additions & 1 deletion algorithms/linfa-nn/src/distance.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,27 @@ impl<F: Float> Distance<F> for LpDist<F> {
}
}

/// Wasserstein or [Earth Mover's](https://en.wikipedia.org/wiki/Earth_mover%27s_distance) distance
///
/// The distance is accumulated from the running sum of the element-wise
/// differences between the two inputs: after every element, the absolute value
/// of that running sum is added to the result. The value therefore depends on
/// the order of the elements and not only on their magnitudes.
///
/// NOTE(review): this corresponds to the one-dimensional earth mover's distance
/// between histograms with unit bin spacing and equal total mass; confirm that
/// this is the intended interpretation for multi-dimensional inputs.
#[cfg_attr(
feature = "serde",
derive(Serialize, Deserialize),
serde(crate = "serde_crate")
)]
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EarthMoverDist;
impl<F: Float> Distance<F> for EarthMoverDist {
    /// Computes the earth mover's distance between `a` and `b`.
    ///
    /// The elements of both views are visited pairwise in logical (row-major)
    /// order while a running sum of the differences `a - b` is maintained. The
    /// returned value is the sum of the absolute values that this running sum
    /// takes after each element.
    ///
    /// # Panics
    ///
    /// Panics if `a` and `b` do not have the same shape.
    #[inline]
    fn distance<D: Dimension>(&self, a: ArrayView<F, D>, b: ArrayView<F, D>) -> F {
        // `iter().zip()` would silently stop at the shorter input, so the shape
        // is checked up front to keep mismatched inputs a hard error.
        assert_eq!(
            a.shape(),
            b.shape(),
            "EarthMoverDist requires both inputs to have the same shape"
        );

        // The result depends on the order in which elements are visited.
        // `Zip` leaves that order unspecified (it may follow the memory
        // layout), whereas `iter()` always yields elements in logical order, so
        // the distance does not change with the layout of the views.
        let mut cumulative_diff = F::zero();
        let mut emd = F::zero();
        for (&x, &y) in a.iter().zip(b.iter()) {
            cumulative_diff += x - y;
            emd += cumulative_diff.abs();
        }
        emd
    }
}

/// Computes a similarity matrix with gaussian kernel and scaling parameter `eps`
///
/// The generated matrix is an upper triangular matrix with dimension NxN (number of observations) and contains the similarity between all permutations of observations
Expand Down Expand Up @@ -146,7 +167,7 @@ pub fn to_gaussian_similarity<F: Float>(
#[cfg(test)]
mod test {
use approx::assert_abs_diff_eq;
use ndarray::arr1;
use ndarray::{arr1, arr2};

use super::*;

Expand All @@ -157,6 +178,7 @@ mod test {
has_autotraits::<L2Dist>();
has_autotraits::<LInfDist>();
has_autotraits::<LpDist<f64>>();
has_autotraits::<EarthMoverDist>();
}

fn dist_test<D: Distance<f64>>(dist: D, result: f64) {
Expand Down Expand Up @@ -204,4 +226,45 @@ mod test {
fn lp_dist() {
    // Run the shared distance check with a non-integer exponent.
    let metric = LpDist(3.3);
    let expected = 4.635;
    dist_test(metric, expected);
}

#[test]
fn emd_dist() {
    // Shared check used by every distance in this module.
    dist_test(EarthMoverDist, 4.2);

    let dist = EarthMoverDist;

    // Two-bin case, also used to check the rdist <-> dist round trip.
    let p = arr1(&[0.5, 0.5]);
    let q = arr1(&[0.3, 0.7]);
    let pq = dist.distance(p.view(), q.view());
    assert_abs_diff_eq!(pq, 0.2, epsilon = 1e-5);
    assert_abs_diff_eq!(dist.rdist_to_dist(dist.dist_to_rdist(pq)), pq);

    // One-dimensional five-bin cases: (first input, second input, expected distance).
    let cases: [(&[f64], &[f64], f64); 4] = [
        (
            &[0.3, 0.2, 0.1, 0.15, 0.25],
            &[0.3, 0.2, 0.1, 0.15, 0.25],
            0.0,
        ),
        (
            &[0.3, 0.2, 0.1, 0.15, 0.25],
            &[0.1, 0.2, 0.1, 0.15, 0.45],
            0.8,
        ),
        (
            &[0.3, 0.2, 0.15, 0.10, 0.25],
            &[0.1, 0.2, 0.05, 0.20, 0.45],
            0.9,
        ),
        (
            &[0.35, 0.15, 0.15, 0.10, 0.25],
            &[0.1, 0.20, 0.05, 0.20, 0.45],
            0.95,
        ),
    ];
    for &(lhs, rhs, expected) in cases.iter() {
        let lhs = arr1(lhs);
        let rhs = arr1(rhs);
        let got = dist.distance(lhs.view(), rhs.view());
        assert_abs_diff_eq!(got, expected, epsilon = 1e-5);
    }

    // Two stacked rows: the expected value is the sum of the per-row distances.
    let p = arr2(&[[0.3, 0.2, 0.15, 0.10, 0.25], [0.35, 0.15, 0.15, 0.10, 0.25]]);
    let q = arr2(&[[0.1, 0.2, 0.05, 0.20, 0.45], [0.1, 0.20, 0.05, 0.20, 0.45]]);
    let pq = dist.distance(p.view(), q.view());
    assert_abs_diff_eq!(pq, 0.9 + 0.95, epsilon = 1e-5);

    // Infinite inputs must yield an infinite distance.
    let p = arr1(&[f64::INFINITY, 6.6]);
    let q = arr1(&[4.4, f64::NEG_INFINITY]);
    assert!(dist.distance(p.view(), q.view()).is_infinite());
}
}
Loading