User guide
==========
In the following we provide some pointers about which functions and classes
to use for different problems related to optimal transport (OT) and machine
learning. We refer when we can to concrete examples in the documentation that
are also available as notebooks on the POT Github.
.. note::
For a good introduction to numerical optimal transport we refer the reader
to `the book `_ by Peyré and Cuturi
[15]_. For a more detailed introduction to OT and how it can be used
in ML applications we refer the reader to the following `OTML tutorial
`_.
.. note::
Since version 0.8, POT provides a backend to automatically solve some OT
problems independently from the toolbox used by the user (numpy/torch/jax).
We provide a discussion about which functions are compatible in section
`Backend section `_ .
Why Optimal Transport ?
-----------------------
When to use OT
^^^^^^^^^^^^^^
Optimal Transport (OT) is a mathematical problem introduced by Gaspard Monge in
1781 that aims at finding the most efficient way to move mass between
distributions. The cost of moving a unit of mass between two positions is called
the ground cost and the objective is to minimize the overall cost of moving one
mass distribution onto another one. The optimization problem can be expressed
for two distributions :math:`\mu_s` and :math:`\mu_t` as
.. math::
\min_{m, m \# \mu_s = \mu_t} \int c(x,m(x))d\mu_s(x) ,
where :math:`c(\cdot,\cdot)` is the ground cost and the constraint
:math:`m \# \mu_s = \mu_t` ensures that :math:`\mu_s` is completely transported to :math:`\mu_t`.
This problem is particularly difficult to solve because of this constraint and
has been replaced in practice (on discrete distributions) by a
linear program that is easier to solve. It corresponds to the Kantorovitch formulation
where the Monge mapping :math:`m` is replaced by a joint distribution
(the OT matrix introduced in the next section, see :ref:`kantorovitch_solve`).
From the optimization problem above we can see that there are two main aspects
to the OT solution that can be used in practical applications:
- The optimal value (Wasserstein distance): Measures similarity between distributions.
- The optimal mapping (Monge mapping, OT matrix): Finds correspondences between distributions.
In the first case, OT is used to measure similarity between distributions
(or datasets) through the Wasserstein distance (the optimal value of the
problem). In the second case, one is interested in the way the mass
is moved between the distributions (the mapping). This mapping can then be used
to transfer knowledge between distributions.
Wasserstein distance between distributions
""""""""""""""""""""""""""""""""""""""""""
OT is often used to measure similarity between distributions, especially
when they do not share the same support. When the supports of the
distributions are disjoint, OT-based Wasserstein distances compare favorably to
popular f-divergences including the Kullback-Leibler and Jensen-Shannon
divergences and the Total Variation distance. What is particularly interesting
for data science applications is that one can compute meaningful sub-gradients
of the Wasserstein distance. For these reasons it has become a very efficient tool
for machine learning applications that need to measure and optimize similarity
between empirical distributions.
Numerous contributions make use of this approach in the machine learning (ML)
literature. For example OT was used for training `Generative
Adversarial Networks (GANs) `_
in order to overcome the vanishing gradient problem. It has also
been used to find `discriminant `_ or
`robust `_ subspaces for a dataset. The
Wasserstein distance has also been used to measure `similarity between word
embeddings of documents `_ or
between `signals
`_
or `spectra `_.
OT for mapping estimation
"""""""""""""""""""""""""
A very interesting aspect of the OT problem is the OT mapping itself. When
computing optimal transport between discrete distributions, one output is the OT
matrix, which provides correspondences between the samples of the two
distributions.
This correspondence is estimated with respect to the OT criterion and is found
in a non-supervised way, which makes it very interesting on problems of transfer
between datasets. It has been used to perform
`color transfer between images `_ or in
the context of `domain adaptation `_.
More recent applications include the use of an extension of OT (Gromov-Wasserstein)
to find correspondences between languages in `word embeddings
`_.
When to use POT
^^^^^^^^^^^^^^^
The main objective of POT is to provide OT solvers for the rapidly growing area
of OT in the context of machine learning. To this end we implement a number of
solvers that have been proposed in research papers. Doing so we aim to promote
reproducible research and foster novel developments.
One very important aspect of POT is its ability to be easily extended. For
instance we provide a very generic OT solver :any:`ot.optim.cg` that can solve
OT problems with any smooth/continuous regularization term, making it
particularly practical for research purposes. Note that this generic solver has
been used to solve both graph Laplacian regularized OT and Gromov
Wasserstein [30]_.
When not to use POT
"""""""""""""""""""
While POT has, to the best of our knowledge, one of the most efficient exact OT
solvers, it has not been designed to handle large scale OT problems. For all the
generic solvers the cost matrix and the OT matrix have to be computed and
stored, so the memory cost of an OT problem is always :math:`\mathcal{O}(n^2)`,
which can be prohibitive for very large scale problems. The exact solver is of time
complexity :math:`\mathcal{O}(n^3\log(n))` and the Sinkhorn solver has been
proven to be nearly :math:`\mathcal{O}(n^2)`, which is still too expensive for very
large scale problems.
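To make the quadratic memory cost concrete, the following sketch estimates the size of one dense matrix stored as 64-bit floats. The helper name :code:`dense_matrix_gib` is hypothetical and only illustrates the arithmetic; it is not part of POT.

```python
def dense_matrix_gib(n_source, n_target, bytes_per_entry=8):
    """Memory in GiB of a dense n_source x n_target matrix (8 bytes = float64)."""
    return n_source * n_target * bytes_per_entry / 2**30


# A generic solver stores at least the cost matrix and the OT matrix,
# so the real footprint is roughly twice the figure printed here.
for n in (1_000, 10_000, 100_000):
    print(f"n = {n:>7}: {dense_matrix_gib(n, n):10.3f} GiB per dense matrix")
```

With this estimate, a problem with 100 000 samples on each side already needs tens of GiB per matrix, which is why dense solvers stop being practical at that scale.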
If you need to solve OT with a large number of samples, we provide a "lazy" memory-efficient implementation of Sinkhorn in pure
Python and using `GeomLoss `_. This
implementation is compatible with PyTorch and can handle a large number of
samples. Another approach to estimate the Wasserstein distance for a very large
number of samples is to use the trick from `Wasserstein GAN
`_ that solves the problem
in the dual with a neural network estimating the dual variable. Note that in this
case you are only solving an approximation of the Wasserstein distance because
the 1-Lipschitz constraint on the dual cannot be enforced exactly (it is approximated
through filter thresholding or regularization). Finally note that in order to
avoid solving large scale OT problems, a number of recent approaches minimize
the expected Wasserstein distance on minibatches, which is different from the
Wasserstein distance but has better computational and
`statistical properties `_.
Optimal transport and Wasserstein distance
------------------------------------------
.. note::
In POT, most functions that solve OT or regularized OT problems have two
versions that return the OT matrix or the value of the optimal solution. For
instance :any:`ot.emd` returns the OT matrix and :any:`ot.emd2` returns the
Wasserstein distance. This approach has been implemented in practice for all
solvers that return an OT matrix (even Gromov-Wasserstein).
.. _kantorovitch_solve:
Solving optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^
The optimal transport problem between discrete distributions is often expressed
as
.. math::
\gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
where:
- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
- :math:`a` and :math:`b` are histograms on the simplex (positive, sum to 1) that represent the weights of each sample in the source and target distributions.
Solving the linear program above can be done using the function :any:`ot.emd`
that will return the optimal transport matrix :math:`\gamma^*`:
.. code:: python
# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
# unified API
T = ot.solve(M, a, b).plan # exact linear program
# classical API
T = ot.emd(a, b, M) # exact linear program
The method implemented for solving the OT problem is the network simplex,
implemented in C++ from [1]_. It has a complexity of :math:`O(n^3)` but the
solver is quite efficient in practice and exploits the sparsity of the solution.
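For intuition on very small problems, the linear program can be solved by brute force when both histograms are uniform and have the same number of bins: in that case an optimal plan can always be chosen as a permutation matrix scaled by :math:`1/n`. The sketch below is pure Python and does not call POT. The name :code:`brute_force_uniform_ot` is hypothetical, and its factorial cost makes it an illustration only, not a substitute for :any:`ot.emd`.

```python
from itertools import permutations


def brute_force_uniform_ot(M):
    """Exact OT plan and value for uniform weights on a square cost matrix M."""
    n = len(M)
    # Enumerate every one-to-one assignment and keep the cheapest one.
    best_perm = min(
        permutations(range(n)),
        key=lambda perm: sum(M[i][perm[i]] for i in range(n)),
    )
    plan = [[0.0] * n for _ in range(n)]
    for i, j in enumerate(best_perm):
        plan[i][j] = 1.0 / n  # each source bin sends all its mass to one target bin
    value = sum(plan[i][j] * M[i][j] for i in range(n) for j in range(n))
    return plan, value


M = [[4.0, 1.0, 3.0],
     [2.0, 0.0, 5.0],
     [3.0, 2.0, 2.0]]
plan, value = brute_force_uniform_ot(M)
for row in plan:
    print(row)
print("OT value:", value)
```

The returned plan has at most :math:`n` non-zero entries, which illustrates the sparsity of exact OT solutions mentioned above.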
.. minigallery:: ot.emd, ot.solve
:add-heading: Examples of use for :any:`ot.emd`
:heading-level: "
Computing Wasserstein distance
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
The value of the OT solution is often more interesting than the OT matrix:
.. math::
OT(a,b) = \min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
It can be computed from an already estimated OT matrix with
:code:`np.sum(T*M)` or directly with the function :any:`ot.emd2`.
.. code:: python
# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
# Wasserstein distance / EMD value with unified API
W = ot.solve(M, a, b, return_matrix=False).value
# with classical API
W = ot.emd2(a, b, M)
Note that the well known `Wasserstein distance
`_ between distributions a and
b is defined as
.. math::
W_p(a,b)=(\min_{\gamma \in \mathbb{R}_+^{m\times n}} \sum_{i,j}\gamma_{i,j}\|x_i-y_j\|_p^p)^\frac{1}{p}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
This means that if you want to compute the :math:`W_2` you need to compute the
square root of :any:`ot.emd2` when providing
:code:`M = ot.dist(xs, xt)`, that uses the squared euclidean distance by default. Computing
the :math:`W_1` Wasserstein distance can be done directly with :any:`ot.emd2`
when providing :code:`M = ot.dist(xs, xt, metric='euclidean')` to use the Euclidean
distance.
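The difference between the two cases can be checked by hand on a toy problem. The sketch below assumes a single source point, so the only feasible plan is known without any solver, and evaluates the OT value with the Euclidean and the squared Euclidean cost. It is pure Python, and the helper :code:`transport_cost` is a hypothetical name for the sum that :code:`np.sum(T*M)` computes.

```python
import math

# One source point carrying all the mass, two target points with mass 0.5 each.
# With a single source point the only feasible plan is T = [[0.5, 0.5]].
x_s = [(0.0, 0.0)]
x_t = [(1.0, 0.0), (0.0, 3.0)]
T = [[0.5, 0.5]]


def transport_cost(T, M):
    """Sum of T[i][j] * M[i][j], the value of the linear OT objective."""
    return sum(t * m for row_t, row_m in zip(T, M) for t, m in zip(row_t, row_m))


M_euclidean = [[math.dist(p, q) for q in x_t] for p in x_s]
M_sqeuclidean = [[math.dist(p, q) ** 2 for q in x_t] for p in x_s]

w1 = transport_cost(T, M_euclidean)               # W_1: the value itself
w2 = math.sqrt(transport_cost(T, M_sqeuclidean))  # W_2: square root of the value
print("W_1 =", w1, " W_2 =", w2)
```

Here the two distances differ because :math:`W_2` averages squared displacements before taking the root.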
.. minigallery:: ot.emd2, ot.solve
:add-heading: Examples of use for :any:`ot.emd2`
:heading-level: "
Special cases
^^^^^^^^^^^^^
Note that the OT problem and the corresponding Wasserstein distance can in some
special cases be computed very efficiently.
For instance when the samples are in 1D, then the OT problem can be solved in
:math:`O(n\log(n))` by simply sorting the samples. In this case we provide the
functions :any:`ot.emd_1d` and :any:`ot.emd2_1d` that return respectively the OT
matrix and value. Note that since the solution is very sparse the :code:`sparse`
parameter of :any:`ot.emd_1d` allows for solving and returning the solution for
very large problems. Note that in order to compute directly the :math:`W_p`
Wasserstein distance in 1D we provide the function :any:`ot.wasserstein_1d` that
takes :code:`p` as a parameter.
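The sorting argument can be sketched in a few lines for the simplest setting: two samples of equal size with uniform weights, where the optimal plan matches the sorted samples in order. The function below is a pure Python illustration under those assumptions, and its name :code:`wasserstein_1d_uniform` is hypothetical; it does not handle weights or samples of different sizes.

```python
def wasserstein_1d_uniform(x, y, p=1):
    """W_p between two 1D samples of equal size with uniform weights."""
    if len(x) != len(y):
        raise ValueError("this sketch assumes samples of equal size")
    xs, ys = sorted(x), sorted(y)
    # In 1D the monotone (sorted) matching is optimal for costs |x - y|^p, p >= 1.
    cost = sum(abs(u - v) ** p for u, v in zip(xs, ys)) / len(xs)
    return cost ** (1.0 / p)


x = [0.0, 3.0, 1.0]
y = [5.0, 2.0, 4.0]
w1 = wasserstein_1d_uniform(x, y, p=1)
w2 = wasserstein_1d_uniform(x, y, p=2)
print("W_1 =", w1, " W_2 =", w2)
```

The cost is dominated by the two sorts, which is where the :math:`O(n\log(n))` complexity comes from.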
Another special case for estimating OT and Monge mapping is between Gaussian
distributions. In this case there exists a closed-form solution given in Remark
2.29 in [15]_ and the Monge mapping is an affine function that can
also be computed from the covariances and means of the source and target
distributions. When the finite sample dataset is assumed to be Gaussian,
we provide :any:`ot.gaussian.bures_wasserstein_mapping` that returns the parameters of the
Monge mapping.
All those special cases are accessible with the unified API of POT through the
function :any:`ot.solve_sample` with the parameter :code:`method` that selects
the method used to solve the problem (with :code:`method='1D'` or :code:`method='gaussian'`).
Regularized Optimal Transport
-----------------------------
Recent developments have shown the interest of regularized OT both in terms of
computational and statistical properties.
We address in this section the regularized OT problems that can be expressed as
.. math::
\gamma^* = arg\min_{\gamma \in \mathbb{R}_+^{m\times n}} \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + \lambda\Omega(\gamma)
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
where :
- :math:`M\in\mathbb{R}_+^{m\times n}` is the metric cost matrix defining the cost to move mass from bin :math:`a_i` to bin :math:`b_j`.
- :math:`a` and :math:`b` are histograms (positive, sum to 1) that represent the weights of each sample in the source and target distributions.
- :math:`\Omega` is the regularization term.
We discuss in the following specific algorithms that can be used depending on
the regularization term.
Entropic regularized OT
^^^^^^^^^^^^^^^^^^^^^^^
This is the most common regularization used for optimal transport. It has been
proposed in the ML community by Marco Cuturi in his seminal paper [2]_. This
regularization has the following expression
.. math::
\Omega(\gamma)=\sum_{i,j}\gamma_{i,j}\log(\gamma_{i,j})
The use of the regularization term above in the optimization problem has a very
strong impact. First it makes the problem smooth which leads to new optimization
procedures such as the well known Sinkhorn algorithm [2]_ or L-BFGS (see
:any:`ot.smooth` ). Next it makes the problem
strictly convex meaning that there will be a unique solution. Finally the
solution of the resulting optimization problem can be expressed as:
.. math::
\gamma_\lambda^*=\text{diag}(u)K\text{diag}(v)
where :math:`u` and :math:`v` are vectors and :math:`K=\exp(-M/\lambda)` where
the :math:`\exp` is taken component-wise. In order to solve the optimization
problem, one can use an alternating projection algorithm called Sinkhorn-Knopp
that can be very efficient for large values of regularization.
The Sinkhorn-Knopp algorithm is implemented in :any:`ot.sinkhorn` and
:any:`ot.sinkhorn2` that return respectively the OT matrix and the value of the
linear term. Note that the regularization parameter :math:`\lambda` in the
equation above is given to those functions with the parameter :code:`reg`.
.. code:: python
# unified API
P = ot.solve(M, a, b, reg=1).plan # OT Sinkhorn matrix
loss = ot.solve(M, a, b, reg=1).value # OT Sinkhorn value
# classical API
P = ot.sinkhorn(a, b, M, reg=1) # OT Sinkhorn matrix
loss = ot.sinkhorn2(a, b, M, reg=1) # OT Sinkhorn value
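As a minimal sketch of what the Sinkhorn-Knopp iterations do, the pure Python function below alternately rescales the rows and columns of :math:`K=\exp(-M/\lambda)` so that the plan :math:`\text{diag}(u)K\text{diag}(v)` matches both marginals. It is not POT's implementation: the name :code:`sinkhorn_knopp_sketch`, the fixed iteration count and the absence of a stopping criterion are simplifying assumptions.

```python
import math


def sinkhorn_knopp_sketch(a, b, M, reg, n_iter=1000):
    """Plain Sinkhorn-Knopp scaling; returns the plan diag(u) K diag(v)."""
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        # Rescale rows so that the plan has row marginal a ...
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # ... then rescale columns so that it has column marginal b.
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


a = [0.5, 0.5]
b = [0.25, 0.75]
M = [[0.0, 1.0],
     [1.0, 0.0]]
P = sinkhorn_knopp_sketch(a, b, M, reg=1.0)
row_sums = [sum(row) for row in P]
col_sums = [sum(col) for col in zip(*P)]
print("row marginals:", row_sums)
print("column marginals:", col_sums)
```

Every entry of the resulting plan is strictly positive, which is the densifying effect of the entropic term. For small values of :code:`reg` the entries of :math:`K` underflow, which is the numerical issue the stabilized and log-domain variants address.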
More details about the algorithms used are given in the following note.
.. note::
The main function to solve entropic regularized OT is :any:`ot.sinkhorn`.
This function is a wrapper and the parameter :code:`method` allows you to select
the actual algorithm used to solve the problem:
+ :code:`method='sinkhorn'` calls :any:`ot.bregman.sinkhorn_knopp` the
classic algorithm [2]_.
+ :code:`method='sinkhorn_log'` calls :any:`ot.bregman.sinkhorn_log` the
sinkhorn algorithm in log space [2]_ that is more stable but can be
slower in numpy since `logsumexp` is not implemented in parallel.
It is the recommended solver for applications that require
differentiability with a small number of iterations.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.bregman.sinkhorn_stabilized` the
log stabilized version of the algorithm [9]_.
+ :code:`method='sinkhorn_epsilon_scaling'` calls
:any:`ot.bregman.sinkhorn_epsilon_scaling` the epsilon scaling version
of the algorithm [9]_.
+ :code:`method='greenkhorn'` calls :any:`ot.bregman.greenkhorn` the
greedy Sinkhorn version of the algorithm [22]_.
+ :code:`method='screenkhorn'` calls :any:`ot.bregman.screenkhorn` the
screening sinkhorn version of the algorithm [26]_.
In addition to all those variants of Sinkhorn, we have another
implementation solving the problem in the smooth dual or semi-dual in
:any:`ot.smooth`. This solver uses the :any:`scipy.optimize.minimize`
function to solve the smooth problem with :code:`L-BFGS-B` algorithm. To use
this solver, use functions :any:`ot.smooth.smooth_ot_dual` or
:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='kl'` to
choose entropic/Kullback-Leibler regularization.
**Choosing a Sinkhorn solver**
By default and when using a regularization parameter that is not too small
the default Sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
:any:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:any:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
of the regularization (and using warm start) sometimes leads to better
solutions. Note that the greedy version of Sinkhorn
:any:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
version of Sinkhorn :any:`ot.bregman.screenkhorn` aims at providing a
fast approximation of the Sinkhorn problem. For use on GPU and gradient
computation with a small number of iterations we strongly recommend the
:any:`ot.bregman.sinkhorn_log` solver, which does not need to check for
numerical problems.
Recently Genevay et al. [23]_ introduced the Sinkhorn divergence that builds on entropic
regularization to compute a fast and differentiable geometric divergence between
empirical distributions. Note that we provide a function that computes directly
(with no need to precompute the :code:`M` matrix)
the Sinkhorn divergence for empirical distributions in
:any:`ot.bregman.empirical_sinkhorn_divergence`. Similarly one can compute the
OT matrix and loss for empirical distributions with respectively
:any:`ot.bregman.empirical_sinkhorn` and :any:`ot.bregman.empirical_sinkhorn2`.
Finally note that we also provide in :any:`ot.stochastic` several implementations
of stochastic solvers for entropic regularized OT [18]_ [19]_. Those pure Python
implementations are not optimized for speed but provide a robust implementation
of algorithms in [18]_ [19]_.
.. minigallery:: ot.sinkhorn ot.sinkhorn2
:add-heading: Examples of use for Sinkhorn algorithm
:heading-level: "
Other regularizations
^^^^^^^^^^^^^^^^^^^^^
While entropic OT is the most common and favored in practice, there exist other
kinds of regularizations. We provide in POT two specific solvers for other
regularization terms, namely quadratic regularization and group Lasso
regularization. We also provide in :any:`ot.optim` two generic solvers
that allow solving OT with any smooth regularization in practice.
Quadratic regularization
""""""""""""""""""""""""
The first general regularization term we can solve is the quadratic
regularization of the form
.. math::
\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}^2
This regularization term has an effect similar to entropic regularization by
densifying the OT matrix, yet it keeps some sort of sparsity that is lost with
entropic regularization as soon as :math:`\lambda>0` [17]_. This problem can be
solved with POT using solvers from :any:`ot.smooth`, more specifically
functions :any:`ot.smooth.smooth_ot_dual` or
:any:`ot.smooth.smooth_ot_semi_dual` with parameter :code:`reg_type='l2'` to
choose the quadratic regularization.
.. minigallery:: ot.smooth.smooth_ot_dual ot.smooth.smooth_ot_semi_dual ot.optim.cg
:add-heading: Examples of use of quadratic regularization
:heading-level: "
Group Lasso regularization
""""""""""""""""""""""""""
Another regularization that has been used in recent years [5]_ is the group Lasso
regularization
.. math::
\Omega(\gamma)=\sum_{j,G\in\mathcal{G}} \|\gamma_{G,j}\|_q^p
where :math:`\mathcal{G}` contains non-overlapping groups of lines in the OT
matrix. This regularization proposed in [5]_ promotes sparsity at the group level and for
instance will force target samples to get mass from a small number of groups.
Note that the exact OT solution is already sparse so this regularization does
not make sense if it is not combined with entropic regularization. Depending on
the choice of :code:`p` and :code:`q`, the problem can be solved with different
approaches. When :code:`q=1` and :code:`p`__. We provide in
:any:`ot.da` several solvers for smooth Monge mapping estimation and domain
adaptation from discrete distributions.
Monge Mapping estimation
^^^^^^^^^^^^^^^^^^^^^^^^
We now discuss several approaches that are implemented in POT to estimate or
approximate a Monge mapping from finite distributions.
First note that when the source and target distributions are assumed to be Gaussian
distributions, there exists a closed-form solution for the mapping, which is an
affine function [14]_ of the form :math:`T(x)=Ax+b` . In this case we provide the function
:any:`ot.gaussian.bures_wasserstein_mapping` that returns the operator :math:`A` and vector
:math:`b`. Note that if the number of samples is too small there is a parameter
:code:`reg` that provides a regularization for the covariance matrix estimation.
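In one dimension the closed form is easy to write down: with means :math:`m_s, m_t` and standard deviations :math:`\sigma_s, \sigma_t`, the affine map has :math:`A=\sigma_t/\sigma_s` and :math:`b=m_t-Am_s`. The sketch below estimates these quantities from samples in pure Python. The name :code:`gaussian_mapping_1d` and the use of the population standard deviation are assumptions of this illustration, which does not reproduce the multivariate computation or the covariance regularization of :any:`ot.gaussian.bures_wasserstein_mapping`.

```python
import statistics


def gaussian_mapping_1d(xs, xt):
    """Slope A and offset b of the affine map between 1D Gaussian fits."""
    m_s, m_t = statistics.fmean(xs), statistics.fmean(xt)
    s_s, s_t = statistics.pstdev(xs), statistics.pstdev(xt)
    A = s_t / s_s        # matches the spread of the target
    b = m_t - A * m_s    # matches the mean of the target
    return A, b


xs = [0.0, 1.0, 2.0, 3.0, 4.0]
xt = [10.0, 12.0, 14.0, 16.0, 18.0]
A, b = gaussian_mapping_1d(xs, xt)
mapped = [A * x + b for x in xs]
print("A =", A, " b =", b)
print("mapped source samples:", mapped)
```

After mapping, the source samples have the same empirical mean and standard deviation as the target samples.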
For a more general mapping estimation we also provide the barycentric mapping
proposed in [6]_. It is implemented in the class :any:`ot.da.EMDTransport` and
other transport-based classes in :any:`ot.da` . Those classes are discussed more
in the following but follow an interface similar to scikit-learn classes. Finally a
method proposed in [8]_ that estimates a continuous mapping approximating the
barycentric mapping is provided in :any:`ot.da.joint_OT_mapping_linear` for
linear mapping and :any:`ot.da.joint_OT_mapping_kernel` for non-linear mapping.
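To illustrate the barycentric mapping, the sketch below assumes an OT matrix :code:`T` is already available and sends each source sample to the average of the target samples weighted by the corresponding row of :code:`T`, which is the form the barycentric mapping takes for the squared Euclidean ground cost. It is pure Python, the function name :code:`barycentric_mapping` is hypothetical, and the plan used in the example is hand-written for illustration rather than computed by a solver.

```python
def barycentric_mapping(T, Xt):
    """Map each source sample to the T-weighted average of the target samples."""
    dim = len(Xt[0])
    mapped = []
    for row in T:
        mass = sum(row)  # mass carried by this source sample
        mapped.append([
            sum(t * x[d] for t, x in zip(row, Xt)) / mass
            for d in range(dim)
        ])
    return mapped


# Three target samples in 2D and a hand-written plan for two source samples.
Xt = [[0.0, 0.0], [2.0, 0.0], [0.0, 4.0]]
T = [[0.5, 0.0, 0.0],     # first source sample goes entirely to target 0
     [0.0, 0.25, 0.25]]   # second one is split between targets 1 and 2
Xs_mapped = barycentric_mapping(T, Xt)
print(Xs_mapped)
```

A source sample whose mass is split between several targets lands at their weighted mean, so the mapped samples lie in the convex hull of the target samples.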
.. minigallery:: ot.da.joint_OT_mapping_linear ot.da.joint_OT_mapping_kernel ot.gaussian.bures_wasserstein_mapping
:add-heading: Examples of Monge mapping estimation
:heading-level: "
Domain adaptation classes
^^^^^^^^^^^^^^^^^^^^^^^^^
The use of OT for domain adaptation (OTDA) was first proposed in [5]_, which also
introduced the group Lasso regularization. The main idea of OTDA is to estimate
a mapping of the samples between source and target distributions, which allows
transporting labeled source samples onto the target distribution with no labels.
We provide several classes based on :any:`ot.da.BaseTransport` that provide
several OT and mapping estimations. The interface of those classes is similar to
classifiers in scikit-learn. At initialization, several parameters such as
regularization parameter value can be set. Then one needs to estimate the
mapping with function :any:`ot.da.BaseTransport.fit`. Finally one can map the
samples from source to target with :any:`ot.da.BaseTransport.transform` and
from target to source with :any:`ot.da.BaseTransport.inverse_transform`.
Here is an example for class :any:`ot.da.EMDTransport`:
.. code:: python
ot_emd = ot.da.EMDTransport()
ot_emd.fit(Xs=Xs, Xt=Xt)
Xs_mapped = ot_emd.transform(Xs=Xs)
A list of the provided implementations is given in the following note.
.. note::
Here is a list of the OT mapping classes inheriting from
:any:`ot.da.BaseTransport`
* :any:`ot.da.EMDTransport`: Barycentric mapping with EMD transport
* :any:`ot.da.SinkhornTransport`: Barycentric mapping with Sinkhorn transport
* :any:`ot.da.SinkhornL1l2Transport`: Barycentric mapping with Sinkhorn +
group Lasso regularization [5]_
* :any:`ot.da.SinkhornLpl1Transport`: Barycentric mapping with Sinkhorn +
non convex group Lasso regularization [5]_
* :any:`ot.da.LinearTransport`: Linear mapping estimation between Gaussians
[14]_
* :any:`ot.da.MappingTransport`: Nonlinear mapping estimation [8]_
.. minigallery:: ot.da.SinkhornTransport ot.da.LinearTransport
:add-heading: Examples of the use of OTDA classes
:heading-level: "
Unbalanced and partial OT
-------------------------
Unbalanced optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Unbalanced OT is a relaxation of the entropy regularized OT problem where the violation of
the constraint on the marginals is added to the objective of the optimization
problem. The unbalanced OT metric between two unbalanced histograms a and b is defined as [25]_ [10]_:
.. math::
W_u(a, b) = \min_\gamma \quad \sum_{i,j}\gamma_{i,j}M_{i,j} + reg\cdot\Omega(\gamma) + reg_m KL(\gamma 1, a) + reg_m KL(\gamma^T 1, b)
s.t. \quad \gamma\geq 0
where KL is the Kullback-Leibler divergence. This formulation allows for
computing approximate mapping between distributions that do not have the same
amount of mass. Interestingly the problem can be solved with a generalization of
the Bregman projections algorithm [10]_. We provide a solver for unbalanced OT
in :any:`ot.unbalanced`. Computing the optimal transport
plan or the transport cost is similar to the balanced case. The Sinkhorn-Knopp
algorithm is implemented in :any:`ot.sinkhorn_unbalanced` and :any:`ot.sinkhorn_unbalanced2`
that return respectively the OT matrix and the value of the
linear term.
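The generalized scaling iterations differ from the balanced ones only by an exponent applied to each update. The sketch below uses the update commonly stated for KL-relaxed marginals with entropic regularization, :math:`u \leftarrow (a/(Kv))^{reg_m/(reg_m+reg)}` and similarly for :math:`v`. It is a pure Python illustration under that assumption, with a hypothetical name and a fixed iteration count, and is not POT's implementation.

```python
import math


def sinkhorn_unbalanced_sketch(a, b, M, reg, reg_m, n_iter=2000):
    """Generalized Sinkhorn scaling for KL-relaxed marginals (illustrative)."""
    n, m = len(a), len(b)
    K = [[math.exp(-M[i][j] / reg) for j in range(m)] for i in range(n)]
    fi = reg_m / (reg_m + reg)  # exponent below 1 softens the marginal constraints
    u, v = [1.0] * n, [1.0] * m
    for _ in range(n_iter):
        u = [(a[i] / sum(K[i][j] * v[j] for j in range(m))) ** fi for i in range(n)]
        v = [(b[j] / sum(K[i][j] * u[i] for i in range(n))) ** fi for j in range(m)]
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]


M = [[0.0, 1.0],
     [1.0, 0.0]]
# Histograms with different total mass (1 and 2): no balanced plan exists.
P = sinkhorn_unbalanced_sketch([0.5, 0.5], [1.0, 1.0], M, reg=1.0, reg_m=1.0)
print("transported mass:", sum(sum(row) for row in P))

# With a very large reg_m and equal masses the balanced marginals are recovered.
P_bal = sinkhorn_unbalanced_sketch([0.5, 0.5], [0.25, 0.75], M, reg=1.0, reg_m=1e6)
print("row marginals:", [sum(row) for row in P_bal])
```

The parameter :code:`reg_m` controls how strongly the marginals are enforced: as it grows the exponent tends to 1 and the iterations reduce to the balanced Sinkhorn-Knopp updates.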
.. note::
The main function to solve entropic regularized UOT is :any:`ot.sinkhorn_unbalanced`.
This function is a wrapper and the parameter :code:`method` helps you select
the actual algorithm used to solve the problem:
+ :code:`method='sinkhorn'` calls :any:`ot.unbalanced.sinkhorn_knopp_unbalanced`
the generalized Sinkhorn algorithm [25]_ [10]_.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.sinkhorn_stabilized_unbalanced`
the log stabilized version of the algorithm [10]_.
.. minigallery:: ot.sinkhorn_unbalanced ot.sinkhorn_unbalanced2 ot.unbalanced.sinkhorn_unbalanced
:add-heading: Examples of Unbalanced OT
:heading-level: "
Unbalanced Barycenters
^^^^^^^^^^^^^^^^^^^^^^
As with balanced distributions, we can define a barycenter of a set of
histograms with different masses as a Fréchet Mean:
.. math::
\min_{\mu} \quad \sum_{k} w_kW_u(\mu,\mu_k)
where :math:`W_u` is the unbalanced Wasserstein metric defined above. This problem
can also be solved using a generalized version of Sinkhorn's algorithm and it is
implemented in the main function :any:`ot.barycenter_unbalanced`.
.. note::
The main function to compute UOT barycenters is :any:`ot.barycenter_unbalanced`.
This function is a wrapper and the parameter :code:`method` helps you select
the actual algorithm used to solve the problem:
+ :code:`method='sinkhorn'` calls :any:`ot.unbalanced.barycenter_unbalanced_sinkhorn`
the generalized Sinkhorn algorithm [10]_.
+ :code:`method='sinkhorn_stabilized'` calls :any:`ot.unbalanced.barycenter_unbalanced_stabilized`
the log stabilized version of the algorithm [10]_.
.. minigallery:: ot.barycenter_unbalanced ot.unbalanced.barycenter_unbalanced
:add-heading: Examples of Unbalanced OT barycenters
:heading-level: "
Partial optimal transport
^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Partial OT is a variant of the optimal transport problem when only a fixed amount of mass m
is to be transported. The partial OT metric between two histograms a and b is defined as [28]_:
.. math::
\gamma = \arg\min_\gamma \langle \gamma, M \rangle_F
s.t.
\gamma\geq 0 \\
\gamma 1 \leq a\\
\gamma^T 1 \leq b\\
1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
Interestingly the problem can be cast into a regular OT problem by adding reservoir points
to which the surplus mass is sent [29]_. We provide a solver for partial OT
in :any:`ot.partial`. The exact resolution of the problem is computed in :any:`ot.partial.partial_wasserstein`
and :any:`ot.partial.partial_wasserstein2` that return respectively the OT matrix and the value of the
linear term. The entropic solution of the problem is computed in :any:`ot.partial.entropic_partial_wasserstein`
(see [3]_).
The partial Gromov-Wasserstein formulation of the problem
.. math::
GW = \min_\gamma \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*\gamma_{i,j}*\gamma_{k,l}
s.t.
\gamma\geq 0 \\
\gamma 1 \leq a\\
\gamma^T 1 \leq b\\
1^T \gamma^T 1 = m \leq \min\{\|a\|_1, \|b\|_1\}
is computed in :any:`ot.partial.partial_gromov_wasserstein` and in
:any:`ot.partial.entropic_partial_gromov_wasserstein` when considering the entropic
regularization of the problem.
.. minigallery:: ot.partial.partial_wasserstein ot.partial.partial_gromov_wasserstein
:add-heading: Examples of Partial OT
:heading-level: "
Gromov Wasserstein and extensions
---------------------------------
Gromov Wasserstein(GW)
^^^^^^^^^^^^^^^^^^^^^^
Gromov Wasserstein (GW) is a generalization of OT to distributions that do not lie in
the same space [13]_. In this case one cannot compute distance between samples
from the two distributions. [13]_ proposed instead to realign the metric spaces
by computing a transport between distance matrices. The Gromov Wasserstein
alignment between two distributions can be expressed as the one minimizing:
.. math::
GW = \min_\gamma \sum_{i,j,k,l} L(C1_{i,k},C2_{j,l})*\gamma_{i,j}*\gamma_{k,l}
s.t. \gamma 1 = a; \gamma^T 1= b; \gamma\geq 0
where :math:`C1` is the distance matrix between samples in the source
distribution and :math:`C2` the one between samples in the target,
:math:`L(C1_{i,k},C2_{j,l})` is a measure of similarity between
:math:`C1_{i,k}` and :math:`C2_{j,l}` often chosen as
:math:`L(C1_{i,k},C2_{j,l})=\|C1_{i,k}-C2_{j,l}\|^2`. The optimization problem
above is a non-convex quadratic program but we provide a solver that finds a
local minimum using conditional gradient in :any:`ot.gromov.gromov_wasserstein`.
There also exists an entropic regularized variant of GW that has been proposed in
[12]_ and we provide an implementation of their algorithm in
:any:`ot.gromov.entropic_gromov_wasserstein`.
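To make the objective concrete, the sketch below evaluates the GW sum with the squared loss for a given plan. It is a direct and deliberately naive :math:`O(n^2m^2)` transcription of the formula in pure Python, with a hypothetical function name; it only evaluates the objective and does not search for the minimizing plan.

```python
def gromov_wasserstein_loss(C1, C2, T):
    """Value of the GW objective with the squared loss for a given plan T."""
    n, m = len(C1), len(C2)
    return sum(
        (C1[i][k] - C2[j][l]) ** 2 * T[i][j] * T[k][l]
        for i in range(n) for j in range(m)
        for k in range(n) for l in range(m)
    )


# Source: points 0, 1, 3 on a line. Target: the same points listed in reverse
# order (3, 1, 0), so the two metric spaces are isometric.
C1 = [[0.0, 1.0, 3.0],
      [1.0, 0.0, 2.0],
      [3.0, 2.0, 0.0]]
C2 = [[0.0, 2.0, 3.0],
      [2.0, 0.0, 1.0],
      [3.0, 1.0, 0.0]]

third = 1.0 / 3.0
T_reversed = [[0.0, 0.0, third], [0.0, third, 0.0], [third, 0.0, 0.0]]
T_identity = [[third, 0.0, 0.0], [0.0, third, 0.0], [0.0, 0.0, third]]

loss_reversed = gromov_wasserstein_loss(C1, C2, T_reversed)
loss_identity = gromov_wasserstein_loss(C1, C2, T_identity)
print("loss with the isometric matching:", loss_reversed)
print("loss with the index-wise matching:", loss_identity)
```

Only the pairwise distances inside each space enter the computation, so the plan that realigns the two spaces reaches a zero loss even though the samples are never compared across spaces.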
.. minigallery:: ot.gromov.gromov_wasserstein ot.gromov.entropic_gromov_wasserstein ot.gromov.fused_gromov_wasserstein ot.gromov.gromov_wasserstein2
:add-heading: Examples of computation of GW, regularized GW and FGW
:heading-level: "
Gromov Wasserstein barycenters
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Note that similarly to Wasserstein distance GW allows for the definition of GW
barycenters that can be expressed as
.. math::
\min_{C\geq 0} \quad \sum_{k} w_k GW(C,Ck)
where :math:`Ck` is the distance matrix between samples in distribution
:math:`k`. Note that interestingly the barycenter is defined as a symmetric
positive matrix. We provide a block coordinate optimization procedure in
:any:`ot.gromov.gromov_barycenters` and
:any:`ot.gromov.entropic_gromov_barycenters` for non-regularized and regularized
barycenters respectively.
Finally note that recently a fusion between Wasserstein and GW, coined Fused
Gromov-Wasserstein (FGW), has been proposed [24]_.
It allows computing a similarity between objects that are only partly in
the same space. As such it can be used to measure similarity between labeled
graphs for instance and also provides computable barycenters.
The implementations of FGW and FGW barycenters are provided in functions
:any:`ot.gromov.fused_gromov_wasserstein` and :any:`ot.gromov.fgw_barycenters`.
.. minigallery:: ot.gromov.gromov_barycenters ot.gromov.fgw_barycenters
:add-heading: Examples of GW, regularized GW and FGW barycenters
:heading-level: "
Other applications
------------------
We discuss in the following several OT related problems and tools that have been
proposed in the OT and machine learning community.
Wasserstein Discriminant Analysis
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
Wasserstein Discriminant Analysis [11]_ is a generalization of `Fisher Linear Discriminant
Analysis `__ that
allows discrimination between classes that are not linearly separable. It
consists in finding a linear projector optimizing the following criterion
.. math::
P = \text{arg}\min_P \frac{\sum_i OT_e(\mu_i\#P,\mu_i\#P)}{\sum_{i,j\neq i}
OT_e(\mu_i\#P,\mu_j\#P)}
where :math:`\#` is the push-forward operator, :math:`OT_e` is the entropic OT
loss and :math:`\mu_i` is the
distribution of samples from class :math:`i`. :math:`P` is also constrained to
be in the Stiefel manifold. WDA can be solved in POT using function
:any:`ot.dr.wda`. It requires :code:`pymanopt` and
:code:`autograd` to be installed for manifold optimization and automatic differentiation
respectively. Note that we also provide the Fisher discriminant estimator in
:any:`ot.dr.fda` for easy comparison.
.. warning::
Note that due to the hard dependency on :code:`pymanopt` and
:code:`autograd`, :any:`ot.dr` is not imported by default. If you want to
use it you have to specifically import it with :code:`import ot.dr` .
.. minigallery:: ot.dr.wda
:add-heading: Examples of the use of WDA
:heading-level: "
Solving OT with Multiple backends on CPU/GPU
--------------------------------------------
.. _backends_section:
Since version 0.8, POT provides a backend that allows coding solvers
independently from the type of the input arrays. The idea is to provide the user
with a package that works seamlessly and returns a solution, for instance, as
PyTorch tensors when the function has PyTorch tensors as input.
How it works
^^^^^^^^^^^^
The aim of the backend is to use the same function independently of the type of
the input arrays.
For instance when executing the following code
.. code:: python
import ot
# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
T = ot.emd(a, b, M) # exact linear program
w = ot.emd2(a, b, M) # Wasserstein computation
the functions :any:`ot.emd` and :any:`ot.emd2` can take inputs of the type
:any:`numpy.array`, :any:`torch.tensor` or :any:`jax.numpy.array`. The output of
the function will be of the same type as the inputs and on the same device. When
possible, all computations are done on that device and the output is
differentiable with respect to the inputs of the function.
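The dispatch idea can be illustrated with a toy example. POT's actual mechanism
lives in :any:`ot.backend` and is not reproduced here; the sketch below only
shows the general pattern of writing a routine once against a small backend
interface so that the output type follows the input type. All class and
function names in it are hypothetical, and nested Python lists stand in for a
second array library.

```python
import math

import numpy as np


class NumpyBackend:
    """Toy backend operating on NumPy arrays."""

    array_type = np.ndarray

    def scale(self, x, s):
        return x * s

    def exp(self, x):
        return np.exp(x)


class ListBackend:
    """Toy backend operating on nested lists (stand-in for another library)."""

    array_type = list

    def scale(self, x, s):
        return [[value * s for value in row] for row in x]

    def exp(self, x):
        return [[math.exp(value) for value in row] for row in x]


def get_toy_backend(*arrays):
    """Return the first toy backend whose array type matches every input."""
    for backend in (NumpyBackend(), ListBackend()):
        if all(isinstance(arr, backend.array_type) for arr in arrays):
            return backend
    raise ValueError("All inputs must share one supported array type.")


def gibbs_kernel(M, reg):
    """Written once against the backend interface; output type follows input."""
    nx = get_toy_backend(M)
    return nx.exp(nx.scale(M, -1.0 / reg))


K_np = gibbs_kernel(np.array([[0.0, 1.0], [1.0, 0.0]]), reg=0.5)
K_list = gibbs_kernel([[0.0, 1.0], [1.0, 0.0]], reg=0.5)
print(type(K_np).__name__, type(K_list).__name__)  # prints: ndarray list
```

The routine :code:`gibbs_kernel` never names a concrete array library, which is
what lets the same code serve several input types.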
GPU acceleration
^^^^^^^^^^^^^^^^
The backends provide automatic computations/compatibility on GPU for most of the
POT functions.
Note that all solvers relying on the exact OT solver in C++ need to solve the
problem on CPU, which can incur some memory copy overhead and be far from optimal
when all other computations are done on GPU. They will still work on arrays on
GPU since the copy is done automatically.
Some of the functions that rely on the exact C++ solver are:
- :any:`ot.emd`, :any:`ot.emd2`
- :any:`ot.gromov_wasserstein`, :any:`ot.gromov_wasserstein2`
- :any:`ot.optim.cg`
List of compatible Backends
^^^^^^^^^^^^^^^^^^^^^^^^^^^
- `Numpy `_ (all functions and solvers)
- `Pytorch `_ (all outputs differentiable w.r.t. inputs)
- `Jax `_ (some functions are differentiable, some require a wrapper)
- `Tensorflow `_ (all outputs differentiable w.r.t. inputs)
- `Cupy `_ (no differentiation, GPU only)
The library automatically detects which backends are available for use. A backend
is instantiated lazily only when necessary to prevent unwarranted GPU memory allocations.
You can also disable the import of a specific backend library (e.g., to accelerate
loading of the `ot` library) using the environment variable `POT_BACKEND_DISABLE_<NAME>` with `<NAME>` in (TORCH,TENSORFLOW,CUPY,JAX).
For instance, to disable TensorFlow, set `export POT_BACKEND_DISABLE_TENSORFLOW=1`.
It's important to note that the `numpy` backend cannot be disabled.
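The same setting can be applied from Python instead of the shell. The sketch
below assumes that POT reads these variables when :code:`ot` is first imported,
so they have to be set before that import. The helper
:code:`disable_pot_backends` is hypothetical and not part of POT; the accepted
names are the ones listed in this guide.

```python
import os
import sys
import warnings

# Names listed in this guide; the helper itself is hypothetical, not part of POT.
KNOWN_BACKENDS = ("TORCH", "TENSORFLOW", "CUPY", "JAX")


def disable_pot_backends(*names):
    """Set POT_BACKEND_DISABLE_<NAME>=1 for each requested backend."""
    if "ot" in sys.modules:
        # Assumption: the variables are only read when `ot` is first imported.
        warnings.warn("ot is already imported; the setting may have no effect.")
    for name in names:
        key = name.upper()
        if key not in KNOWN_BACKENDS:
            raise ValueError(f"unknown backend name: {name!r}")
        os.environ[f"POT_BACKEND_DISABLE_{key}"] = "1"


disable_pot_backends("tensorflow", "cupy")
# `import ot` would follow here in a real script.
```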
FAQ
---
1. **How to solve a discrete optimal transport problem ?**
The solver for discrete OT is the function :py:mod:`ot.emd` that returns
the optimal transport matrix. If you want to solve a regularized OT problem you can
use :py:mod:`ot.sinkhorn`.
Here is a simple use case:
.. code:: python
import ot
# a and b are 1D histograms (sum to 1 and positive)
# M is the ground cost matrix
# reg is the entropic regularization strength (positive float)
T = ot.emd(a, b, M) # exact linear program
T_reg = ot.sinkhorn(a, b, M, reg) # entropic regularized OT
A more detailed example is available here:
:doc:`auto_examples/plot_OT_2D_samples`
2. **pip install POT fails with error : ImportError: No module named Cython.Build**
As discussed shortly in the README file. POT`__ for more
details.
3. **Why is Sinkhorn slower than EMD ?**
This might come from the choice of the regularization term. The speed of
convergence of Sinkhorn depends directly on this term [22]_. When the
regularization gets very small the problem tries to approximate the exact OT
which leads to slow convergence in addition to numerical problems. In other
words, for large regularization Sinkhorn will be very fast to converge, for
small regularization (when you need an OT matrix close to the true OT), it
might be quicker to use the EMD solver.
Also note that the numpy implementation of Sinkhorn can use parallel
computation depending on the configuration of your system, but a much larger
speedup can be obtained by using a GPU implementation since all operations
are matrix/vector products.
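The effect of the regularization on convergence can be observed with a plain
NumPy version of the Sinkhorn scaling iterations. This is a minimal sketch and
not POT's :py:mod:`ot.sinkhorn` implementation; the function name, the stopping
rule and the toy histograms are assumptions chosen for illustration.

```python
import numpy as np


def sinkhorn_iterations(a, b, M, reg, tol=1e-9, max_iter=10000):
    """Plain Sinkhorn scaling; returns the plan and the number of iterations."""
    K = np.exp(-M / reg)  # Gibbs kernel
    u = np.ones_like(a)
    for it in range(1, max_iter + 1):
        v = b / (K.T @ u)
        u = a / (K @ v)
        # Row marginals are exact after updating u, so check the columns only.
        err = np.abs(v * (K.T @ u) - b).sum()
        if err < tol:
            break
    return u[:, None] * K * v[None, :], it


# Toy problem (assumption): two bumps on a 1D grid, squared distance cost.
x = np.linspace(0.0, 1.0, 30)
a = np.exp(-((x - 0.3) ** 2) / (2 * 0.1 ** 2))
a /= a.sum()
b = np.exp(-((x - 0.7) ** 2) / (2 * 0.1 ** 2))
b /= b.sum()
M = (x[:, None] - x[None, :]) ** 2

T_large, iters_large_reg = sinkhorn_iterations(a, b, M, reg=1.0)
T_small, iters_small_reg = sinkhorn_iterations(a, b, M, reg=5e-3)
print("iterations with reg=1.0 :", iters_large_reg)
print("iterations with reg=5e-3:", iters_small_reg)
```

On this toy problem the small regularization needs more iterations to reach
the same tolerance, and for much smaller values the kernel
:code:`np.exp(-M / reg)` would eventually underflow, which is the numerical
problem mentioned above.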
References
----------
.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
December). `Displacement interpolation using Lagrangian mass transport
`__.
In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
.. [2] Cuturi, M. (2013). `Sinkhorn distances: Lightspeed computation of
optimal transport `__. In Advances
in Neural Information Processing Systems (pp. 2292-2300).
.. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
(2015). `Iterative Bregman projections for regularized transportation
problems `__. SIAM Journal on
Scientific Computing, 37(2), A1111-A1138.
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, `Optimal Transport
for Domain Adaptation `__, in IEEE
Transactions on Pattern Analysis and Machine Intelligence , vol.PP,
no.99, pp.1-1
.. [6] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
`Regularized discrete optimal
transport `__. SIAM Journal on
Imaging Sciences, 7(3), 1853-1882.
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). `Generalized
conditional gradient: analysis of convergence and
applications `__. arXiv preprint
arXiv:1510.06567.
.. [8] M. Perrot, N. Courty, R. Flamary, A. Habrard (2016), `Mapping
estimation for discrete optimal
transport `__,
Neural Information Processing Systems (NIPS).
.. [9] Schmitzer, B. (2016). `Stabilized Sparse Scaling Algorithms for
Entropy Regularized Transport
Problems `__. arXiv preprint
arXiv:1610.06519.
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
`Scaling algorithms for unbalanced transport
problems `__. arXiv preprint
arXiv:1607.05816.
.. [11] Flamary, R., Cuturi, M., Courty, N., & Rakotomamonjy, A. (2016).
`Wasserstein Discriminant
Analysis `__. arXiv preprint
arXiv:1608.08063.
.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
`Gromov-Wasserstein averaging of kernel and distance
matrices `__
International Conference on Machine Learning (ICML).
.. [13] Mémoli, Facundo (2011). `Gromov–Wasserstein distances and the
metric approach to object
matching `__.
Foundations of computational mathematics 11.4 : 417-487.
.. [14] Knott, M. and Smith, C. S. (1984). `On the optimal mapping of
distributions `__,
Journal of Optimization Theory and Applications Vol 43.
.. [15] Peyré, G., & Cuturi, M. (2018). `Computational Optimal
Transport `__ .
.. [16] Agueh, M., & Carlier, G. (2011). `Barycenters in the Wasserstein
space `__. SIAM
Journal on Mathematical Analysis, 43(2), 904-924.
.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). `Smooth and Sparse
Optimal Transport `__. Proceedings of
the Twenty-First International Conference on Artificial Intelligence and
Statistics (AISTATS).
.. [18] Genevay, A., Cuturi, M., Peyré, G. & Bach, F. (2016) `Stochastic
Optimization for Large-scale Optimal
Transport `__. Advances in Neural
Information Processing Systems (2016).
.. [19] Seguy, V., Bhushan Damodaran, B., Flamary, R., Courty, N., Rolet,
A. & Blondel, M. `Large-scale Optimal Transport and Mapping
Estimation `__. International
Conference on Learning Representation (2018)
.. [20] Cuturi, M. and Doucet, A. (2014) `Fast Computation of Wasserstein
Barycenters `__.
International Conference in Machine Learning
.. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A.,
Nguyen, A. & Guibas, L. (2015). `Convolutional wasserstein distances:
Efficient optimal transportation on geometric
domains `__. ACM
Transactions on Graphics (TOG), 34(4), 66.
.. [22] J. Altschuler, J. Weed, P. Rigollet, (2017) `Near-linear time
approximation algorithms for optimal transport via Sinkhorn
iteration `__,
Advances in Neural Information Processing Systems (NIPS) 31
.. [23] Genevay, A., Peyré, G., Cuturi, M., `Learning Generative Models with
Sinkhorn Divergences `__, Proceedings
of the Twenty-First International Conference on Artificial Intelligence
and Statistics, (AISTATS) 21, 2018
.. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
(2019). `Optimal Transport for structured data with application on
graphs `__ Proceedings
of the 36th International Conference on Machine Learning (ICML).
.. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
Learning with a Wasserstein Loss, Advances in Neural Information
Processing Systems (NIPS) 2015
.. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019). Screening Sinkhorn
Algorithm for Regularized Optimal Transport ,
Advances in Neural Information Processing Systems 33 (NeurIPS).
.. [28] Caffarelli, L. A., McCann, R. J. (2020). Free boundaries in optimal transport and
Monge-Ampere obstacle problems ,
Annals of mathematics, 673-730.
.. [29] Chapel, L., Alaya, M., Gasso, G. (2019). Partial Gromov-Wasserstein with
Applications on Positive-Unlabeled Learning ,
arXiv preprint arXiv:2002.08276.
.. [30] Flamary, Rémi, et al. "Optimal transport with Laplacian regularization:
Applications to domain adaptation and shape matching." NIPS Workshop on Optimal
Transport and Machine Learning OTML. 2014.
.. [31] Bonneel, Nicolas, et al. `Sliced and radon wasserstein barycenters of
measures
`_\
, Journal of Mathematical Imaging and Vision 51.1 (2015): 22-45
.. [32] Huang, M., Ma S., Lai, L. (2021). `A Riemannian Block Coordinate Descent Method for Computing the Projection Robust Wasserstein Distance `_\ , Proceedings of the 38th International Conference on Machine Learning (ICML).
.. [33] Kerdoncuff T., Emonet R., Marc S. `Sampled Gromov Wasserstein
`_\ , Machine
Learning Journal (MJL), 2021
.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., &
Peyré, G. (2019, April). `Interpolating between optimal transport and MMD
using Sinkhorn divergences
`_. In The 22nd
International Conference on Artificial Intelligence and Statistics (pp.
2681-2690). PMLR.
.. [35] Deshpande, I., Hu, Y. T., Sun, R., Pyrros, A., Siddiqui, N., Koyejo, S.,
& Schwing, A. G. (2019). `Max-sliced wasserstein distance and its use
for gans
`_.
In Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (pp. 10648-10656).
.. [36] Liutkus, A., Simsekli, U., Majewski, S., Durmus, A., & Stöter, F. R.
(2019, May). `Sliced-Wasserstein flows: Nonparametric generative modeling via
optimal transport and diffusions
`_. In International
Conference on Machine Learning (pp. 4104-4113). PMLR.
.. [37] Janati, H., Cuturi, M., Gramfort, A. `Debiased sinkhorn barycenters
`_ Proceedings of
the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020
.. [38] C. Vincent-Cuaz, T. Vayer, R. Flamary, M. Corneli, N. Courty, `Online
Graph Dictionary Learning `_\ ,
International Conference on Machine Learning (ICML), 2021.
.. [39] Gozlan, N., Roberto, C., Samson, P. M., & Tetali, P. (2017).
`Kantorovich duality for general transport costs and applications
`_.
Journal of Functional Analysis, 273(11), 3327-3405.
.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G., &
Weed, J. (2019, April). `Statistical optimal transport via factored
couplings `_. In
The 22nd International Conference on Artificial Intelligence and Statistics
(pp. 2454-2465). PMLR.
.. [41] Xu, H., Luo, D., & Carin, L. (2019). `Scalable Gromov-Wasserstein learning for graph partitioning and matching
`_\ , Advances in neural information processing systems, 32.