# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Alexander Tong <alexander.tong@yale.edu>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
import warnings
import numpy as np
from ..utils import list_to_array
from ..backend import get_backend
def sinkhorn(
a,
b,
M,
reg,
method="sinkhorn",
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
**kwargs,
):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`
**Choosing a Sinkhorn solver**
By default, and when using a regularization parameter that is not too small,
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
:py:func:`ot.bregman.sinkhorn_stabilized` solver, which will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on iterating over
the value of the regularization (and using warm starts), sometimes leads to
better solutions. Note that the greedy version of the sinkhorn,
:py:func:`ot.bregman.greenkhorn`, can also lead to a speedup, and the screening
version of the sinkhorn, :py:func:`ot.bregman.screenkhorn`, aims at providing a
fast approximation of the Sinkhorn problem. For GPU usage and gradient
computation with a small number of iterations we strongly recommend the
:py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
numerical problems.
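
For instance, any of the solvers above can be selected through the ``method``
argument; a small, hedged illustration on toy data:

>>> import ot
>>> a, b = [.5, .5], [.5, .5]
>>> M = [[0., 1.], [1., 0.]]
>>> G = ot.sinkhorn(a, b, M, 0.05, method='sinkhorn_log')
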
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples weights in the target domain; compute sinkhorn with multiple targets
and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
(return the OT loss and dual variables in the log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
method used for the solver either 'sinkhorn','sinkhorn_log',
'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
Initialization of the dual potentials. If provided, the dual potentials should
be given (that is, the logarithm of the u, v sinkhorn scaling vectors)
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
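
With ``log=True`` the solver also returns a log dict; for the default method
its keys follow :py:func:`ot.bregman.sinkhorn_knopp` (a hedged sketch):

>>> G, log = ot.sinkhorn(a, b, M, 1, log=True)
>>> u, v = log['u'], log['v']  # Sinkhorn scaling vectors
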
.. _references-sinkhorn:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
of Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.
.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
A., & Peyré, G. (2019, April). Interpolating between optimal transport
and MMD using Sinkhorn divergences. In The 22nd International Conference
on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
:ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
:ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
"""
if method.lower() == "sinkhorn":
return sinkhorn_knopp(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "sinkhorn_log":
return sinkhorn_log(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "greenkhorn":
return greenkhorn(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
)
elif method.lower() == "sinkhorn_stabilized":
return sinkhorn_stabilized(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
warmstart=warmstart,
verbose=verbose,
log=log,
warn=warn,
**kwargs,
)
elif method.lower() == "sinkhorn_epsilon_scaling":
return sinkhorn_epsilon_scaling(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
warmstart=warmstart,
verbose=verbose,
log=log,
warn=warn,
**kwargs,
)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn2(
a,
b,
M,
reg,
method="sinkhorn",
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=False,
warmstart=None,
**kwargs,
):
r"""
Solve the entropic regularization optimal transport problem and return the loss
The function solves the following optimization problem:
.. math::
W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
the entropic contribution).
.. note:: This function is backend-compatible and will work on arrays
from all compatible backends.
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`
**Choosing a Sinkhorn solver**
By default, and when using a regularization parameter that is not too small,
the default sinkhorn solver should be enough. If you need to use a small
regularization to get sharper OT matrices, you should use the
:py:func:`ot.bregman.sinkhorn_stabilized` solver, which will avoid numerical
errors. This last solver can be very slow in practice and might not even
converge to a reasonable OT matrix in a finite time. This is why
:py:func:`ot.bregman.sinkhorn_epsilon_scaling`, which relies on iterating over
the value of the regularization (and using warm starts), sometimes leads to
better solutions. Note that the greedy version of the sinkhorn,
:py:func:`ot.bregman.greenkhorn`, can also lead to a speedup, and the screening
version of the sinkhorn, :py:func:`ot.bregman.screenkhorn`, aims at providing a
fast approximation of the Sinkhorn problem. For GPU usage and gradient
computation with a small number of iterations we strongly recommend the
:py:func:`ot.bregman.sinkhorn_log` solver, which does not need to check for
numerical problems.
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples weights in the target domain; compute sinkhorn with multiple targets
and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
(return the OT loss and dual variables in the log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
method : str
method used for the solver either 'sinkhorn','sinkhorn_log',
'sinkhorn_stabilized', see those function for specific parameters
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
Initialization of the dual potentials. If provided, the dual potentials should
be given (that is, the logarithm of the u, v sinkhorn scaling vectors)
Returns
-------
W : float or array-like, shape (n_hists,)
Optimal transportation loss for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn2(a, b, M, 1)
0.26894142136999516
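
The returned value is the transport cost :math:`\langle \gamma^*, \mathbf{M} \rangle_F`
of the plan computed by :py:func:`ot.sinkhorn`, as this hedged check illustrates:

>>> import numpy as np
>>> bool(np.allclose(ot.sinkhorn2(a, b, M, 1),
...                  np.sum(ot.sinkhorn(a, b, M, 1) * np.array(M))))
True
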
.. _references-sinkhorn2:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
Optimal Transport, Advances in Neural Information
Processing Systems (NIPS) 26, 2013
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
for Entropy Regularized Transport Problems.
arXiv preprint arXiv:1610.06519.
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.
.. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
algorithms for optimal transport via Sinkhorn iteration,
Advances in Neural Information Processing Systems (NIPS) 31, 2017
.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
Trouvé, A., & Peyré, G. (2019, April).
Interpolating between optimal transport and MMD using Sinkhorn
divergences. In The 22nd International Conference on Artificial
Intelligence and Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
:ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
"""
M, a, b = list_to_array(M, a, b)
nx = get_backend(M, a, b)
if len(b.shape) < 2:
if method.lower() == "sinkhorn":
res = sinkhorn_knopp(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "sinkhorn_log":
res = sinkhorn_log(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "sinkhorn_stabilized":
res = sinkhorn_stabilized(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
warmstart=warmstart,
verbose=verbose,
log=log,
warn=warn,
**kwargs,
)
else:
raise ValueError("Unknown method '%s'." % method)
if log:
return nx.sum(M * res[0]), res[1]
else:
return nx.sum(M * res)
else:
if method.lower() == "sinkhorn":
return sinkhorn_knopp(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "sinkhorn_log":
return sinkhorn_log(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warn=warn,
warmstart=warmstart,
**kwargs,
)
elif method.lower() == "sinkhorn_stabilized":
return sinkhorn_stabilized(
a,
b,
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
warmstart=warmstart,
verbose=verbose,
log=log,
warn=warn,
**kwargs,
)
else:
raise ValueError("Unknown method '%s'." % method)
def sinkhorn_knopp(
a,
b,
M,
reg,
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
**kwargs,
):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp
matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples weights in the target domain; compute sinkhorn with multiple targets
and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
(return the OT loss and dual variables in the log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
Initialization of the dual potentials. If provided, the dual potentials should
be given (that is, the logarithm of the u, v sinkhorn scaling vectors)
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
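
A minimal NumPy sketch of the underlying scaling iterations, for illustration
only (``reg`` is fixed to 1 here and the variable names are ad hoc):

>>> import numpy as np
>>> K = np.exp(-np.array(M) / 1.0)  # Gibbs kernel
>>> u, v = np.full(2, 0.5), np.full(2, 0.5)
>>> for _ in range(100):
...     v = np.array(b) / (K.T @ u)  # match the target marginal
...     u = np.array(a) / (K @ v)  # match the source marginal
>>> bool(np.allclose(u[:, None] * K * v[None, :], ot.sinkhorn(a, b, M, 1)))
True
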
.. _references-sinkhorn-knopp:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
of Optimal Transport, Advances in Neural Information
Processing Systems (NIPS) 26, 2013
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
"""
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if len(a) == 0:
a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
else:
n_hists = 0
if log:
log = {"err": []}
# we assume that no distances are null except those on the diagonal
if warmstart is None:
if n_hists:
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
u = nx.ones(dim_a, type_as=M) / dim_a
v = nx.ones(dim_b, type_as=M) / dim_b
else:
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
K = nx.exp(M / (-reg))
Kp = (1 / a).reshape(-1, 1) * K
err = 1
for ii in range(numItermax):
uprev = u
vprev = v
KtransposeU = nx.dot(K.T, u)
v = b / KtransposeU
u = 1.0 / nx.dot(Kp, v)
if (
nx.any(KtransposeU == 0)
or nx.any(nx.isnan(u))
or nx.any(nx.isnan(v))
or nx.any(nx.isinf(u))
or nx.any(nx.isinf(v))
):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn("Warning: numerical errors at iteration %d" % ii)
u = uprev
v = vprev
break
if ii % 10 == 0:
# we can speed up the process by checking the error only every
# 10th iteration
if n_hists:
tmp2 = nx.einsum("ik,ij,jk->jk", u, K, v)
else:
# compute the right marginal: tmp2 = (diag(u) K diag(v))^T 1
tmp2 = nx.einsum("i,ij,j->j", u, K, v)
err = nx.norm(tmp2 - b) # violation of marginal
if log:
log["err"].append(err)
if err < stopThr:
break
if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["niter"] = ii
log["u"] = u
log["v"] = v
if n_hists: # return only loss
res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M)
if log:
return res, log
else:
return res
else: # return OT matrix
if log:
return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
else:
return u.reshape((-1, 1)) * K * v.reshape((1, -1))
def sinkhorn_log(
a,
b,
M,
reg,
numItermax=1000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
**kwargs,
):
r"""
Solve the entropic regularization optimal transport problem in log space
and return the OT matrix
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
implementation from :ref:`[34] <references-sinkhorn-log>`
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples weights in the target domain; compute sinkhorn with multiple targets
and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
(return the OT loss and dual variables in the log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
Initialization of the dual potentials. If provided, the dual potentials should
be given (that is, the logarithm of the u, v sinkhorn scaling vectors)
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.sinkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
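
A minimal log-domain sketch of the updates, for illustration only (``reg`` is
fixed to 1 and ``f``, ``g`` are ad hoc names for the dual potentials; assumes
SciPy is available):

>>> import numpy as np
>>> from scipy.special import logsumexp
>>> Mr = -np.array(M) / 1.0
>>> f, g = np.zeros(2), np.zeros(2)
>>> for _ in range(100):
...     g = np.log(b) - logsumexp(Mr + f[:, None], axis=0)
...     f = np.log(a) - logsumexp(Mr + g[None, :], axis=1)
>>> bool(np.allclose(np.exp(Mr + f[:, None] + g[None, :]),
...                  ot.sinkhorn(a, b, M, 1)))
True
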
.. _references-sinkhorn-log:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013
.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
Trouvé, A., & Peyré, G. (2019, April). Interpolating between
optimal transport and MMD using Sinkhorn divergences. In The
22nd International Conference on Artificial Intelligence and
Statistics (pp. 2681-2690). PMLR.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
"""
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if len(a) == 0:
a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
if len(b) == 0:
b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)
# init data
dim_a = len(a)
dim_b = b.shape[0]
if len(b.shape) > 1:
n_hists = b.shape[1]
else:
n_hists = 0
# in case of multiple histograms
if n_hists > 1 and warmstart is None:
warmstart = [None] * n_hists
if n_hists:  # we do not want to use tensors, so we use a loop
lst_loss = []
lst_u = []
lst_v = []
for k in range(n_hists):
res = sinkhorn_log(
a,
b[:, k],
M,
reg,
numItermax=numItermax,
stopThr=stopThr,
verbose=verbose,
log=log,
warmstart=warmstart[k],
**kwargs,
)
if log:
lst_loss.append(nx.sum(M * res[0]))
lst_u.append(res[1]["log_u"])
lst_v.append(res[1]["log_v"])
else:
lst_loss.append(nx.sum(M * res))
res = nx.stack(lst_loss)
if log:
log = {
"log_u": nx.stack(lst_u, 1),
"log_v": nx.stack(lst_v, 1),
}
log["u"] = nx.exp(log["log_u"])
log["v"] = nx.exp(log["log_v"])
return res, log
else:
return res
else:
if log:
log = {"err": []}
Mr = -M / reg
# we assume that no distances are null except those on the diagonal
if warmstart is None:
u = nx.zeros(dim_a, type_as=M)
v = nx.zeros(dim_b, type_as=M)
else:
u, v = warmstart
def get_logT(u, v):
if n_hists:
return Mr[:, :, None] + u + v
else:
return Mr + u[:, None] + v[None, :]
loga = nx.log(a)
logb = nx.log(b)
err = 1
for ii in range(numItermax):
v = logb - nx.logsumexp(Mr + u[:, None], 0)
u = loga - nx.logsumexp(Mr + v[None, :], 1)
if ii % 10 == 0:
# we can speed up the process by checking the error only every
# 10th iteration
# compute the right marginal: tmp2 = (diag(u) K diag(v))^T 1
tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
err = nx.norm(tmp2 - b) # violation of marginal
if log:
log["err"].append(err)
if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
if err < stopThr:
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["niter"] = ii
log["log_u"] = u
log["log_v"] = v
log["u"] = nx.exp(u)
log["v"] = nx.exp(v)
return nx.exp(get_logT(u, v)), log
else:
return nx.exp(get_logT(u, v))
def greenkhorn(
a,
b,
M,
reg,
numItermax=10000,
stopThr=1e-9,
verbose=False,
log=False,
warn=True,
warmstart=None,
):
r"""
Solve the entropic regularization optimal transport problem and return the OT matrix
The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
which is a greedy version of the Sinkhorn-Knopp
algorithm :ref:`[2] <references-greenkhorn>`
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
samples weights in the target domain; compute sinkhorn with multiple targets
and a fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
(return the OT loss and dual variables in the log)
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
warmstart : tuple of arrays, shapes (dim_a,) and (dim_b,), optional
Initialization of the dual potentials. If provided, the dual potentials should
be given (that is, the logarithm of the u, v sinkhorn scaling vectors)
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.bregman.greenkhorn(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
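
Up to the stopping threshold, the returned plan satisfies the prescribed
marginals; a hedged sanity check:

>>> import numpy as np
>>> G = ot.bregman.greenkhorn(a, b, M, 1)
>>> bool(np.allclose(G.sum(axis=1), a, atol=1e-4))
True
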
.. _references-greenkhorn:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
of Optimal Transport, Advances in Neural Information
Processing Systems (NIPS) 26, 2013
.. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
approximation algorithms for optimal transport via Sinkhorn
iteration, Advances in Neural Information Processing
Systems (NIPS) 31, 2017
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
"""
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if nx.__name__ in ("jax", "tf"):
raise TypeError(
"JAX or TF arrays have been received. Greenkhorn is not "
"compatible with JAX or TF."
)
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
dim_a = a.shape[0]
dim_b = b.shape[0]
K = nx.exp(-M / reg)
if warmstart is None:
u = nx.full((dim_a,), 1.0 / dim_a, type_as=K)
v = nx.full((dim_b,), 1.0 / dim_b, type_as=K)
else:
u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
G = u[:, None] * K * v[None, :]
viol = nx.sum(G, axis=1) - a
viol_2 = nx.sum(G, axis=0) - b
stopThr_val = 1
if log:
log = dict()
log["u"] = u
log["v"] = v
for ii in range(numItermax):
i_1 = nx.argmax(nx.abs(viol))
i_2 = nx.argmax(nx.abs(viol_2))
m_viol_1 = nx.abs(viol[i_1])
m_viol_2 = nx.abs(viol_2[i_2])
stopThr_val = nx.maximum(m_viol_1, m_viol_2)
if m_viol_1 > m_viol_2:
old_u = u[i_1]
new_u = a[i_1] / nx.dot(K[i_1, :], v)
G[i_1, :] = new_u * K[i_1, :] * v
viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
viol_2 += K[i_1, :].T * (new_u - old_u) * v
u[i_1] = new_u
else:
old_v = v[i_2]
new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
G[:, i_2] = u * K[:, i_2] * new_v
# full marginal violations would be viol = G @ 1 - a and
# viol_2 = G.T @ 1 - b; they are updated incrementally below instead
viol += (-old_v + new_v) * K[:, i_2] * u
viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
v[i_2] = new_v
if stopThr_val <= stopThr:
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["n_iter"] = ii
log["u"] = u
log["v"] = v
if log:
return G, log
else:
return G
def sinkhorn_stabilized(
a,
b,
M,
reg,
numItermax=1000,
tau=1e3,
stopThr=1e-9,
warmstart=None,
verbose=False,
print_period=20,
log=False,
warn=True,
**kwargs,
):
r"""
Solve the entropic regularization OT problem with log stabilization
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
but with the log stabilization
proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in
:ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1).
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,)
samples weights in the target domain
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
for log scaling
warmstart : tuple of arrays, optional
if given, the starting values for the alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
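
The log dict exposes the final potentials under the ``'warmstart'`` key, which
can be fed back to restart the solver; a hedged usage sketch:

>>> G0, log0 = ot.bregman.sinkhorn_stabilized(a, b, M, 1, log=True)
>>> G1 = ot.bregman.sinkhorn_stabilized(a, b, M, 1, warmstart=log0['warmstart'])
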
.. _references-sinkhorn-stabilized:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
for Entropy Regularized Transport Problems.
arXiv preprint arXiv:1610.06519.
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
"""
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# test if multiple targets
if len(b.shape) > 1:
n_hists = b.shape[1]
a = a[:, None]
else:
n_hists = 0
# init data
dim_a = len(a)
dim_b = len(b)
if log:
log = {"err": []}
# we assume that no distances are null except those on the diagonal
if warmstart is None:
alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
if n_hists:
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
u /= dim_a
v /= dim_b
def get_K(alpha, beta):
"""log space computation"""
return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg)
def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
return nx.exp(
-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg
+ nx.log(u.reshape((dim_a, 1)))
+ nx.log(v.reshape((1, dim_b)))
)
K = get_K(alpha, beta)
transp = K
err = 1
for ii in range(numItermax):
uprev = u
vprev = v
# sinkhorn update
v = b / (nx.dot(K.T, u))
u = a / (nx.dot(K, v))
# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
alpha, beta = (
alpha + reg * nx.max(nx.log(u), 1),
beta + reg * nx.max(nx.log(v)),
)
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
u = nx.ones(dim_a, type_as=M) / dim_a
v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)
if ii % print_period == 0:
# we can speed up the process by checking the error only every
# print_period iterations
if n_hists:
err_u = nx.max(nx.abs(u - uprev))
err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
err_v = nx.max(nx.abs(v - vprev))
err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log["err"].append(err)
if verbose:
if ii % (print_period * 20) == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
if err <= stopThr:
break
if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn("Numerical errors at iteration %d" % ii)
u = uprev
v = vprev
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
logu = alpha / reg + nx.log(u)
logv = beta / reg + nx.log(v)
log["n_iter"] = ii
log["logu"] = logu
log["logv"] = logv
log["alpha"] = alpha + reg * nx.log(u)
log["beta"] = beta + reg * nx.log(v)
log["warmstart"] = (log["alpha"], log["beta"])
if n_hists:
res = nx.stack(
[
nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
for i in range(n_hists)
]
)
return res, log
else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
res = nx.stack(
[
nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
for i in range(n_hists)
]
)
return res
else:
return get_Gamma(alpha, beta, u, v)
def sinkhorn_epsilon_scaling(
a,
b,
M,
reg,
numItermax=100,
epsilon0=1e4,
numInnerItermax=100,
tau=1e3,
stopThr=1e-9,
warmstart=None,
verbose=False,
print_period=10,
log=False,
warn=True,
**kwargs,
):
r"""
Solve the entropic regularization optimal transport problem with log
stabilization and epsilon scaling.
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
- :math:`\Omega` is the entropic regularization term
:math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
but with the log stabilization
proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,)
samples in the target domain
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
for log scaling
warmstart : tuple of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations in the inner log-stabilized sinkhorn
epsilon0 : float, optional
first epsilon regularization value (then exponentially decreased down to reg)
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't converge.
Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if ``log`` is True in the parameters
Examples
--------
>>> import ot
>>> a=[.5, .5]
>>> b=[.5, .5]
>>> M=[[0., 1.], [1., 0.]]
>>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])
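
Internally, the regularization follows the exponentially decreasing schedule
:math:`(\epsilon_0 - \mathrm{reg}) e^{-n} + \mathrm{reg}`; a hedged sketch of
its first values:

>>> import numpy as np
>>> regs = [(1e4 - 1.) * np.exp(-n) + 1. for n in range(5)]  # epsilon0=1e4, reg=1
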
.. _references-sinkhorn-epsilon-scaling:
References
----------
.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
See Also
--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT
"""
a, b, M = list_to_array(a, b, M)
nx = get_backend(M, a, b)
if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
# init data
dim_a = len(a)
dim_b = len(b)
# relative numerical precision with 64 bits
numItermin = 35
numItermax = max(numItermin, numItermax)  # ensure that the last value is exact
ii = 0
if log:
log = {"err": []}
# we assume that no distances are null except those on the diagonal
if warmstart is None:
alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart
def get_reg(n):  # exponentially decreasing schedule
return (epsilon0 - reg) * np.exp(-n) + reg
err = 1
for ii in range(numItermax):
regi = get_reg(ii)
G, logi = sinkhorn_stabilized(
a,
b,
M,
regi,
numItermax=numInnerItermax,
stopThr=stopThr,
warmstart=(alpha, beta),
verbose=False,
print_period=20,
tau=tau,
log=True,
)
alpha = logi["alpha"]
beta = logi["beta"]
if ii % (print_period) == 0:  # epsilon nearly converged
# we can speed up the process by checking the error only every
# print_period iterations
transp = G
err = (
nx.norm(nx.sum(transp, axis=0) - b) ** 2
+ nx.norm(nx.sum(transp, axis=1) - a) ** 2
)
if log:
log["err"].append(err)
if verbose:
if ii % (print_period * 10) == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
if err <= stopThr and ii > numItermin:
break
else:
if warn:
warnings.warn(
"Sinkhorn did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["alpha"] = alpha
log["beta"] = beta
log["warmstart"] = (log["alpha"], log["beta"])
log["niter"] = ii
return G, log
else:
return G