---
title: "Computing the Wasserstein distance"
bibliography: ../inst/REFERENCES.bib
output: 
  rmarkdown::html_vignette:
    toc: true
    toc_depth: 2
    number_sections: false
    fig_caption: true
    mathjax: default
vignette: >
  %\VignetteIndexEntry{Computing the Wasserstein distance}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
```{r init, eval=TRUE, echo=FALSE}
if (!requireNamespace("mlbench", quietly = TRUE)) {
  knitr::opts_chunk$set(eval = FALSE)
  message("Install 'mlbench' to run the code in this vignette.")
}
```
```{r setup, include=FALSE, echo=FALSE}
knitr::opts_chunk$set(cache = TRUE)
library(mlbench)
library(T4transport)
set.seed(10)
```

# Introduction


The Wasserstein distance [@villani_2003_TopicsOptimalTransportation] provides a natural and geometrically meaningful way to compare probability distributions. Unlike traditional dissimilarities such as Kullback–Leibler divergence or total variation, it reflects the minimal "cost" of transporting one distribution to another, with respect to an underlying geometry of the sample space.

Formally, the \( p \)-Wasserstein distance between two probability measures \( \mu \) and \( \nu \) on a metric space \( (\mathcal{X}, d) \) is defined as:

\[
W_p(\mu, \nu) = \left( \inf_{\gamma \in \Gamma(\mu, \nu)} \int_{\mathcal{X} \times \mathcal{X}} d(x, y)^p \, d\gamma(x, y) \right)^{1/p},
\]

where \( \Gamma(\mu, \nu) \) denotes the set of all couplings (i.e., joint distributions) with marginals \( \mu \) and \( \nu \). This is known as the Kantorovich formulation [@kantorovitch_1958_TranslocationMasses], which differs from the original formulation by @monge_1781_MemoireTheorieDeblais yet yields the same distance under mild conditions.
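
For empirical measures with finitely many atoms, the infimum above reduces to a finite-dimensional problem. As a sketch, assume (this notation is introduced here for illustration) that \( \mu = \sum_{i=1}^{n} a_i \delta_{x_i} \) and \( \nu = \sum_{j=1}^{m} b_j \delta_{y_j} \), with nonnegative weights that each sum to one. A coupling is then an \( n \times m \) matrix \( P \), and

\[
W_p(\mu, \nu)^p = \min_{P \in \mathbb{R}_{\geq 0}^{n \times m}} \sum_{i=1}^{n} \sum_{j=1}^{m} P_{ij} \, d(x_i, y_j)^p
\quad \text{subject to} \quad
\sum_{j=1}^{m} P_{ij} = a_i, \qquad \sum_{i=1}^{n} P_{ij} = b_j .
\]

Both the objective and the constraints are linear in \( P \), so this is a linear program in \( n \times m \) variables. For samples without explicit weights, the natural choice is \( a_i = 1/n \) and \( b_j = 1/m \).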


# Toy example

We use two artificial datasets, `cassini` and `smiley`, from the [mlbench](https://CRAN.R-project.org/package=mlbench) package to demonstrate 
the usage of our package. Let's generate small samples of 
cardinality $n=100$ each. For visual clarity, the two datasets are first standardized to have 
zero mean and unit variance in each dimension, and the second one is then 
translated by +5 in the $x$-direction.

```{r toy_code, echo=TRUE, eval=TRUE, cache=FALSE, fig.alt="Two datasets: Cassini and Smiley"}
# load the library
library(mlbench)

# generate two datasets
data1 = mlbench::mlbench.cassini(n=100)$x
data2 = mlbench::mlbench.smiley(n=100)$x

# normalize the datasets
data1 = as.matrix(scale(data1))
data2 = as.matrix(scale(data2))

# translate the second dataset
data2[,1] = data2[,1] + 5

# plot the datasets
plot(data1, col="blue", pch=19, cex=0.5, main="Two datasets", 
     xlim=c(-2, 7), xlab="x", ylab="y")
points(data2[,1], data2[,2], col="red", cex=0.5, pch=19)
```

As shown in the figure, the two datasets have quite different shapes. Moreover, 
the horizontal translation applied to the second dataset (`smiley`) separates them even further.

# How to compute?

Under the Kantorovich formulation, computing the Wasserstein distance between two empirical measures amounts to solving a linear program. In **T4transport**, the `wasserstein()` function does exactly that. Let's see how 
to use it, taking the case of order $p=2$.

```{r compute1, echo=TRUE, eval=TRUE, cache=TRUE, fig.alt="Wasserstein distance computation"}
# call the function
output = wasserstein(data1, data2, p=2)

# print the output
print(paste0("2-wasserstein distance: ",round(output$distance, 4)))
```

The computed Wasserstein distance of order 2 between the two empirical measures is `r round(output$distance,4)`. Another benefit of the Kantorovich formulation is that the solver also returns 
an optimal coupling matrix that matches points from the two sets. The attained coupling matrix $\hat{\Gamma}$ is shown below.

```{r compute2a, echo=FALSE, eval=TRUE, fig.alt="Optimal coupling matrix", fig.align="center"}
par(pty="s")
## --- Plot 1: Optimal coupling matrix
P = output$plan
image(
  x = 1:ncol(P), 
  y = 1:nrow(P), 
  z = t(P)[, nrow(P):1],         # transpose + flip so that row 1 of P is drawn at the top
  col = gray.colors(100, start = 1, end = 0), 
  xlab = "Target Index",         # columns of P index the second dataset
  ylab = "Source Index",         # rows of P index the first dataset
  main = "Optimal Coupling",
  axes = FALSE
)
axis(1, at = 1:ncol(P), labels = 1:ncol(P), las = 2, cex.axis = 0.6, tick=FALSE)
axis(2, at = 1:nrow(P), labels = rev(1:nrow(P)), las = 2, cex.axis = 0.6, tick=FALSE)
```

Furthermore, this coupling can be shown within the original scatterplot of the two datasets 
by considering the bipartite graph representation. 

```{r compute2b, echo=FALSE, eval=TRUE, fig.alt="Optimal coupling scatterplot", fig.align="center"}
## --- Plot 2: Bipartite graph
plot(data1, col="blue", pch=19, cex=0.5, main="Bipartite Graph", 
     xlim=c(-2, 7), xlab="x", ylab="y")
points(data2[,1], data2[,2], col="red", cex=0.5, pch=19)
maxP = max(P)
multiplier = 0.5
for (i in 1:nrow(data1)){
  for (j in 1:nrow(data2)){
    if (P[i,j] > 0){
      lines(x = c(data1[i, 1], data2[j, 1]),
            y = c(data1[i, 2], data2[j, 2]),
            col = "gray40",
            lwd = multiplier * P[i, j] / maxP)
    }
  }
}
```

In the figure, edges connecting the dots represent the optimal coupling between the two empirical measures. The thickness of the edges is proportional to the amount of mass transported between the points, which can be interpreted as the "flow" in the transportation problem.
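
To make the notion of an optimal coupling concrete, the following sketch computes one by brute force for two tiny point clouds. It is not how `wasserstein()` works internally; it relies on the fact that, for two samples of equal size with uniform weights, an optimal coupling can be chosen as a permutation matrix scaled by $1/n$. The data and the helper `all_perms()` are made up for illustration.

```r
# Tiny equal-size point clouds with uniform weights 1/n (illustrative data)
X <- rbind(c(0, 0), c(1, 0), c(0, 1))
Y <- rbind(c(5, 1), c(5, 0), c(6, 0))   # X shifted by (5, 0), rows shuffled
n <- nrow(X)

# squared Euclidean cost matrix: C[i, j] = ||X_i - Y_j||^2
C <- outer(seq_len(n), seq_len(n),
           Vectorize(function(i, j) sum((X[i, ] - Y[j, ])^2)))

# hypothetical helper: enumerate all permutations (feasible only for very small n)
all_perms <- function(v) {
  if (length(v) <= 1) return(list(v))
  out <- list()
  for (i in seq_along(v)) {
    for (rest in all_perms(v[-i])) {
      out[[length(out) + 1]] <- c(v[i], rest)
    }
  }
  out
}

# average transport cost of each one-to-one matching, and the cheapest one
perms <- all_perms(seq_len(n))
costs <- vapply(perms, function(s) mean(C[cbind(seq_len(n), s)]), numeric(1))
best  <- perms[[which.min(costs)]]

# coupling matrix: mass 1/n moves from source i to target best[i]
P <- matrix(0, n, n)
P[cbind(seq_len(n), best)] <- 1 / n

# 2-Wasserstein distance implied by the coupling
w2 <- sqrt(sum(P * C))
w2
```

Each row and each column of `P` sums to $1/n$, which is exactly the marginal constraint on a coupling. Because `Y` is a pure translation of `X`, the resulting distance equals the length of the shift.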

# Alternative input

It is often the case that we only have the cross distances between the atoms of two empirical measures, rather than 
the atoms themselves. The `wasserstein()` function above assumes that the atoms are points in 
Euclidean space, yet in practice the distances or 
dissimilarities may come from a user-defined metric. The `wassersteinD()` function is the choice in 
such scenarios.

```{r computeD, echo=TRUE, eval=TRUE, cache=TRUE, fig.alt="Wasserstein distance computation with distances"}
# compute the cross distance with a helper function
cross_dist <- function(X, Y) {
  X2 <- rowSums(X^2)
  Y2 <- rowSums(Y^2)
  # squared distances; clamp at 0 to guard against tiny negatives from round-off
  D2 <- outer(X2, Y2, "+") - 2 * tcrossprod(X, Y)
  sqrt(pmax(D2, 0))
}
cdist = cross_dist(data1, data2)

# call the function
crossed = wassersteinD(cdist, p=2)

# print the output
print(paste0("2-wasserstein distance: ",round(crossed$distance, 4)))
```

Supplying the cross-distance matrix gives a 2-Wasserstein distance of 
`r round(crossed$distance,4)`, which matches the value we obtained before 
(`r round(output$distance,4)`). 
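
The same pattern applies to non-Euclidean ground metrics. As a minimal sketch, the hypothetical helper `cross_dist_l1()` below builds a Manhattan (L1) cross-distance matrix in base R; the toy matrices are made up, and only the construction of the matrix is shown.

```r
# hypothetical helper: Manhattan (L1) cross distances between rows of X and Y
cross_dist_l1 <- function(X, Y) {
  stopifnot(ncol(X) == ncol(Y))
  D <- matrix(0, nrow(X), nrow(Y))
  for (k in seq_len(ncol(X))) {
    # add the absolute difference along coordinate k
    D <- D + abs(outer(X[, k], Y[, k], "-"))
  }
  D
}

# illustrative atoms: 2 points in X, 3 points in Y
X <- rbind(c(0, 0), c(1, 2))
Y <- rbind(c(3, 4), c(0, 1), c(1, 1))

# rows index the atoms of X, columns index the atoms of Y
D <- cross_dist_l1(X, Y)
D
```

A matrix of this form would then take the place of `cdist` in the `wassersteinD()` call above, with `p` chosen to suit the metric.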

# Disclaimer

Computing the Wasserstein distance with a linear programming (LP) solver is the classical approach 
based on the Kantorovich formulation, and it is typically included as a core method in OT libraries. 
However, this approach does not scale well to large data: for measures with $n$ and $m$ atoms, the LP involves $n\times m$ variables, so it requires $\mathcal{O}(nm)$ memory and, when $n = m$, roughly $\mathcal{O}(n^3)$ time up to logarithmic factors. 
For this reason, the package also provides alternative algorithms, and we hope you enjoy exploring them.
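
A back-of-the-envelope sketch makes the memory cost tangible. The helper `plan_memory_gib()` is hypothetical (it is not part of **T4transport**) and assumes the coupling matrix is stored densely as 8-byte doubles; a real solver needs additional working memory on top of this.

```r
# hypothetical helper: memory for a dense n x m coupling matrix, in GiB
plan_memory_gib <- function(n, m = n) {
  # convert to double first so that large integer inputs do not overflow
  as.numeric(n) * as.numeric(m) * 8 / 1024^3
}

# memory needed for the plan alone as the sample size grows
sizes <- c(1e2, 1e3, 1e4, 1e5)
data.frame(n = sizes, GiB = signif(plan_memory_gib(sizes), 3))
```

Under these assumptions, the plan for the $n = 100$ example in this vignette is negligible, whereas at $n = m = 10^5$ the dense matrix alone takes roughly 75 GiB.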


# References

<div id="refs"></div>

<!-- pkgdown::build_articles(); pkgdown::preview_site() -->