# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Alexander Tong <alexander.tong@yale.edu>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
import warnings
import numpy as np
from ..utils import list_to_array
from ..backend import get_backend
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
             verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    r"""Compute the entropic regularized OT plan between two histograms.

    This solves

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    with :math:`\mathbf{M}` the (`dim_a`, `dim_b`) cost matrix,
    :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` the
    entropic regularizer, and :math:`\mathbf{a}`, :math:`\mathbf{b}` the
    source/target weights (histograms, both summing to 1). The computation
    relies on the Sinkhorn-Knopp matrix scaling algorithm
    :ref:`[2] <references-sinkhorn>` or one of its variants, selected via
    ``method``.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    **Choosing a Sinkhorn solver**

    With a moderate regularization the default ``'sinkhorn'`` solver is
    usually sufficient. For a small ``reg`` (sharper OT matrices), prefer
    :py:func:`ot.bregman.sinkhorn_stabilized`, which avoids numerical
    errors; it can however be slow and may not converge in reasonable time,
    in which case :py:func:`ot.bregman.sinkhorn_epsilon_scaling` (iterating
    over decreasing regularization values with warm starts) sometimes gives
    better solutions. The greedy variant :py:func:`ot.bregman.greenkhorn`
    can bring a speedup and the screening variant
    :py:func:`ot.bregman.screenkhorn` provides a fast approximation. For
    GPU use and gradient computation with few iterations,
    :py:func:`ot.bregman.sinkhorn_log` is strongly recommended since it
    needs no numerical-stability checks.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain; with a matrix :math:`\mathbf{b}` the
        sinkhorn problem is solved for each column with a fixed
        :math:`\mathbf{M}` (returns OT losses + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        solver to use: 'sinkhorn', 'sinkhorn_log', 'greenkhorn',
        'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling'; see the
        corresponding functions for solver-specific parameters
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of the dual potentials, i.e. the logarithms of the
        u, v sinkhorn scaling vectors

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
    """
    # Dispatch table: map the (case-insensitive) method name to its solver.
    # All solvers share the same core keyword interface.
    solvers = {
        'sinkhorn': sinkhorn_knopp,
        'sinkhorn_log': sinkhorn_log,
        'greenkhorn': greenkhorn,
        'sinkhorn_stabilized': sinkhorn_stabilized,
        'sinkhorn_epsilon_scaling': sinkhorn_epsilon_scaling,
    }
    key = method.lower()
    if key not in solvers:
        raise ValueError("Unknown method '%s'." % method)
    # greenkhorn is the only solver that does not accept extra kwargs.
    extra = {} if key == 'greenkhorn' else kwargs
    return solvers[key](a, b, M, reg, numItermax=numItermax,
                        stopThr=stopThr, verbose=verbose, log=log,
                        warn=warn, warmstart=warmstart, **extra)
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
              stopThr=1e-9, verbose=False, log=False, warn=False, warmstart=None, **kwargs):
    r"""Compute the entropic regularized OT loss between two histograms.

    This solves

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    with :math:`\mathbf{M}` the (`dim_a`, `dim_b`) cost matrix,
    :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})` the
    entropic regularizer, and :math:`\mathbf{a}`, :math:`\mathbf{b}` the
    source/target weights (histograms, both summing to 1), and returns
    :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without the entropic
    contribution). The computation relies on the Sinkhorn-Knopp matrix
    scaling algorithm :ref:`[2] <references-sinkhorn2>`.

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    **Choosing a Sinkhorn solver**

    With a moderate regularization the default ``'sinkhorn'`` solver is
    usually sufficient. For a small ``reg`` (sharper OT matrices), prefer
    :py:func:`ot.bregman.sinkhorn_log`, which avoids numerical errors.
    The stabilized solver can be slow and may not converge in reasonable
    time, in which case :py:func:`ot.bregman.sinkhorn_epsilon_scaling`
    (iterating over decreasing regularization values with warm starts)
    sometimes gives better solutions. The greedy variant
    :py:func:`ot.bregman.greenkhorn` can bring a speedup and the screening
    variant :py:func:`ot.bregman.screenkhorn` aims at providing a fast
    approximation. For GPU use and gradient computation with few
    iterations, :py:func:`ot.bregman.sinkhorn_log` is strongly recommended
    since it needs no numerical-stability checks.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain; with a matrix :math:`\mathbf{b}` the
        sinkhorn problem is solved for each column with a fixed
        :math:`\mathbf{M}` (returns OT losses + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        solver to use: 'sinkhorn', 'sinkhorn_log' or 'sinkhorn_stabilized';
        see the corresponding functions for solver-specific parameters
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of the dual potentials, i.e. the logarithms of the
        u, v sinkhorn scaling vectors

    Returns
    -------
    W : (n_hists) float/array-like
        Optimal transportation loss for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn2(a, b, M, 1)
    0.26894142136999516

    .. _references-sinkhorn2:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
        algorithms for optimal transport via Sinkhorn iteration,
        Advances in Neural Information Processing Systems (NIPS) 31, 2017
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April).
        Interpolating between optimal transport and MMD using Sinkhorn
        divergences. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
    ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # Dispatch table over the supported solvers.
    solvers = {
        'sinkhorn': sinkhorn_knopp,
        'sinkhorn_log': sinkhorn_log,
        'sinkhorn_stabilized': sinkhorn_stabilized,
    }
    key = method.lower()
    if key not in solvers:
        raise ValueError("Unknown method '%s'." % method)

    res = solvers[key](a, b, M, reg, numItermax=numItermax,
                       stopThr=stopThr, verbose=verbose, log=log,
                       warn=warn, warmstart=warmstart, **kwargs)

    if len(b.shape) < 2:
        # Single histogram: the solvers return the OT plan, so contract it
        # with M to obtain the (unregularized) transport cost.
        if log:
            return nx.sum(M * res[0]), res[1]
        return nx.sum(M * res)
    # Multiple histograms: the solvers already return the vector of losses.
    return res
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
                   verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp
    matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors)

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-knopp:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # Empty weights default to uniform histograms.
    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    # A 2-D b means one sinkhorn problem per column (n_hists > 0).
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    # `log` switches from a bool flag to the dict that accumulates the run info.
    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        # uniform initialization of the scaling vectors
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
            v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
        else:
            u = nx.ones(dim_a, type_as=M) / dim_a
            v = nx.ones(dim_b, type_as=M) / dim_b
    else:
        # warmstart holds dual potentials (log of the scaling vectors)
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    # Gibbs kernel K = exp(-M / reg); Kp pre-divides by a to save one
    # division per iteration in the u-update.
    K = nx.exp(M / (-reg))

    Kp = (1 / a).reshape(-1, 1) * K

    err = 1
    for ii in range(numItermax):
        # keep the previous scalings so we can roll back on numerical failure
        uprev = u
        vprev = v
        # alternate scaling updates: v matches marginal b, then u matches a
        KtransposeU = nx.dot(K.T, u)
        v = b / KtransposeU
        u = 1. / nx.dot(Kp, v)

        if (nx.any(KtransposeU == 0)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Warning: numerical errors at iteration %d' % ii)
            u = uprev
            v = vprev
            break
        if ii % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            if n_hists:
                tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
            else:
                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.einsum('i,ij,j->j', u, K, v)
            err = nx.norm(tmp2 - b)  # violation of marginal
            if log:
                log['err'].append(err)

            if err < stopThr:
                break
            if verbose:
                if ii % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
    else:
        # for-else: only reached when numItermax was exhausted without break
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")
    if log:
        log['niter'] = ii
        log['u'] = u
        log['v'] = v

    if n_hists:  # return only loss
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix
        # plan = diag(u) K diag(v)
        if log:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
        else:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1))
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
                 log=False, warn=True, warmstart=None, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem in log space
    and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
    implementation from :ref:`[34] <references-sinkhorn-log>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors)

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-log:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April). Interpolating between
        optimal transport and MMD using Sinkhorn divergences. In The
        22nd International Conference on Artificial Intelligence and
        Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # Empty weights default to uniform histograms.
    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    # A 2-D b means one sinkhorn problem per column (n_hists > 0).
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    # in case of multiple histograms: one warmstart (or None) per histogram.
    # Bugfix: the list must also be built when n_hists == 1 (b of shape
    # (dim_b, 1)); testing `n_hists > 1` left warmstart as None and
    # `warmstart[k]` below raised a TypeError.
    if n_hists and warmstart is None:
        warmstart = [None] * n_hists

    if n_hists:  # we do not want to use tensors so we do a loop

        lst_loss = []
        lst_u = []
        lst_v = []

        for k in range(n_hists):
            # Bugfix: forward `warn` so that disabling the convergence
            # warning in the caller also disables it per histogram.
            res = sinkhorn_log(a, b[:, k], M, reg,
                               numItermax=numItermax, stopThr=stopThr,
                               verbose=verbose, log=log, warn=warn,
                               warmstart=warmstart[k], **kwargs)

            if log:
                lst_loss.append(nx.sum(M * res[0]))
                lst_u.append(res[1]['log_u'])
                lst_v.append(res[1]['log_v'])
            else:
                lst_loss.append(nx.sum(M * res))
        res = nx.stack(lst_loss)
        if log:
            log = {'log_u': nx.stack(lst_u, 1),
                   'log_v': nx.stack(lst_v, 1), }
            log['u'] = nx.exp(log['log_u'])
            log['v'] = nx.exp(log['log_v'])
            return res, log
        else:
            return res

    else:

        if log:
            log = {'err': []}

        # Gibbs kernel in log domain: log K = -M / reg
        Mr = - M / reg

        # we assume that no distances are null except those of the diagonal of
        # distances
        if warmstart is None:
            u = nx.zeros(dim_a, type_as=M)
            v = nx.zeros(dim_b, type_as=M)
        else:
            u, v = warmstart

        def get_logT(u, v):
            """Log of the transport plan for dual potentials (u, v)."""
            if n_hists:
                return Mr[:, :, None] + u + v
            else:
                return Mr + u[:, None] + v[None, :]

        loga = nx.log(a)
        logb = nx.log(b)

        err = 1
        for ii in range(numItermax):

            # numerically stable sinkhorn updates in log space
            v = logb - nx.logsumexp(Mr + u[:, None], 0)
            u = loga - nx.logsumexp(Mr + v[None, :], 1)

            if ii % 10 == 0:
                # we can speed up the process by checking for the error only all
                # the 10th iterations

                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
                err = nx.norm(tmp2 - b)  # violation of marginal
                if log:
                    log['err'].append(err)

                if verbose:
                    if ii % 200 == 0:
                        print(
                            '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                    print('{:5d}|{:8e}|'.format(ii, err))
                if err < stopThr:
                    break
        else:
            # for-else: only reached when numItermax was exhausted without break
            if warn:
                warnings.warn("Sinkhorn did not converge. You might want to "
                              "increase the number of iterations `numItermax` "
                              "or the regularization parameter `reg`.")

        if log:
            log['niter'] = ii
            log['log_u'] = u
            log['log_v'] = v
            log['u'] = nx.exp(u)
            log['v'] = nx.exp(v)

            return nx.exp(get_logT(u, v)), log

        else:
            return nx.exp(get_logT(u, v))
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
               log=False, warn=True, warmstart=None):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
    which is a stochastic version of the Sinkhorn-Knopp
    algorithm :ref:`[2] <references-greenkhorn>`

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors)

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.greenkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-greenkhorn:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013
    .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
        approximation algorithms for optimal transport via Sinkhorn
        iteration, Advances in Neural Information Processing
        Systems (NIPS) 31, 2017

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)
    # The algorithm mutates u, v, G, viol and viol_2 in place below, which
    # immutable-array backends (JAX, TensorFlow) cannot support.
    if nx.__name__ in ("jax", "tf"):
        raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
                        "compatible with neither JAX nor TF")

    # Empty weights default to uniform histograms.
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    dim_a = a.shape[0]
    dim_b = b.shape[0]

    # Gibbs kernel
    K = nx.exp(-M / reg)

    if warmstart is None:
        u = nx.full((dim_a,), 1. / dim_a, type_as=K)
        v = nx.full((dim_b,), 1. / dim_b, type_as=K)
    else:
        # warmstart holds dual potentials (log of the scaling vectors)
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
    # current plan and its marginal violations
    G = u[:, None] * K * v[None, :]

    viol = nx.sum(G, axis=1) - a
    viol_2 = nx.sum(G, axis=0) - b
    stopThr_val = 1
    if log:
        log = dict()
        log['u'] = u
        log['v'] = v

    for ii in range(numItermax):
        # greedily pick the row/column with the largest marginal violation
        i_1 = nx.argmax(nx.abs(viol))
        i_2 = nx.argmax(nx.abs(viol_2))
        m_viol_1 = nx.abs(viol[i_1])
        m_viol_2 = nx.abs(viol_2[i_2])
        stopThr_val = nx.maximum(m_viol_1, m_viol_2)

        if m_viol_1 > m_viol_2:
            # rescale row i_1 and update the violations incrementally
            old_u = u[i_1]
            new_u = a[i_1] / nx.dot(K[i_1, :], v)
            G[i_1, :] = new_u * K[i_1, :] * v

            viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
            viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
            u[i_1] = new_u
        else:
            # rescale column i_2 and update the violations incrementally
            old_v = v[i_2]
            new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
            G[:, i_2] = u * K[:, i_2] * new_v
            # aviol = (G@one_m - a)
            # aviol_2 = (G.T@one_n - b)
            viol += (-old_v + new_v) * K[:, i_2] * u
            viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
            v[i_2] = new_v

        if stopThr_val <= stopThr:
            break
    else:
        # for-else: only reached when numItermax was exhausted without break
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")

    if log:
        log["n_iter"] = ii
        log['u'] = u
        log['v'] = v

    if log:
        return G, log
    else:
        return G
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
                        warmstart=None, verbose=False, print_period=20,
                        log=False, warn=True, **kwargs):
    r"""
    Solve the entropic regularization OT problem with log stabilization

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
    but with the log stabilization
    proposed in :ref:`[10] <references-sinkhorn-stabilized>` an defined in
    :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1) .

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,)
        samples in the target domain
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    warmstart : table of vectors
        if given then starting values for alpha and beta log scalings
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        check the error (and print it if verbose) every `print_period`
        iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't convergence.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-stabilized:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # Empty weights default to uniform histograms.
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    # test if multiple target
    if len(b.shape) > 1:
        n_hists = b.shape[1]
        a = a[:, None]
    else:
        n_hists = 0

    # init data
    dim_a = len(a)
    dim_b = len(b)

    # `log` switches from a bool flag to the dict that accumulates the run info.
    if log:
        log = {'err': []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        # alpha/beta are the absorbed dual potentials (log-domain offsets)
        alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
    else:
        alpha, beta = warmstart

    if n_hists:
        u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
        v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
    else:
        u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
        u /= dim_a
        v /= dim_b

    def get_K(alpha, beta):
        """log space computation"""
        # kernel with the current dual offsets absorbed: exp((-M+alpha+beta)/reg)
        return nx.exp(-(M - alpha.reshape((dim_a, 1))
                        - beta.reshape((1, dim_b))) / reg)

    def get_Gamma(alpha, beta, u, v):
        """log space gamma computation"""
        # transport plan rebuilt from the absorbed potentials and scalings
        return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
                      / reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))

    K = get_K(alpha, beta)
    transp = K
    err = 1
    for ii in range(numItermax):
        # keep the previous scalings so we can roll back on numerical failure
        uprev = u
        vprev = v

        # sinkhorn update
        v = b / (nx.dot(K.T, u))
        u = a / (nx.dot(K, v))

        # remove numerical problems and store them in K
        # (absorb large scalings into alpha/beta, then reset u, v to uniform)
        if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
            if n_hists:
                alpha, beta = alpha + reg * \
                    nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
            else:
                alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
            if n_hists:
                u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
                v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
            else:
                u = nx.ones(dim_a, type_as=M) / dim_a
                v = nx.ones(dim_b, type_as=M) / dim_b
            K = get_K(alpha, beta)

        if ii % print_period == 0:
            # we can speed up the process by checking for the error only all
            # the `print_period`-th iterations
            if n_hists:
                # relative change of the scalings, averaged over u and v
                err_u = nx.max(nx.abs(u - uprev))
                err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
                err_v = nx.max(nx.abs(v - vprev))
                err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
                err = 0.5 * (err_u + err_v)
            else:
                # violation of the target marginal
                transp = get_Gamma(alpha, beta, u, v)
                err = nx.norm(nx.sum(transp, axis=0) - b)
            if log:
                log['err'].append(err)

            if verbose:
                if ii % (print_period * 20) == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

        if err <= stopThr:
            break

        if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %d' % ii)
            u = uprev
            v = vprev
            break
    else:
        # for-else: only reached when numItermax was exhausted without break
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")
    if log:
        if n_hists:
            alpha = alpha[:, None]
            beta = beta[:, None]
        # full dual potentials in log space (offsets plus current scalings)
        logu = alpha / reg + nx.log(u)
        logv = beta / reg + nx.log(v)
        log["n_iter"] = ii
        log['logu'] = logu
        log['logv'] = logv
        log['alpha'] = alpha + reg * nx.log(u)
        log['beta'] = beta + reg * nx.log(v)
        log['warmstart'] = (log['alpha'], log['beta'])
        if n_hists:
            res = nx.stack([
                nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
                for i in range(n_hists)
            ])
            return res, log
        else:
            return get_Gamma(alpha, beta, u, v), log
    else:
        if n_hists:
            res = nx.stack([
                nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
                for i in range(n_hists)
            ])
            return res
        else:
            return get_Gamma(alpha, beta, u, v)
# [docs]  (Sphinx "source link" artifact left over from an HTML docs scrape)
def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
                             numInnerItermax=100, tau=1e3, stopThr=1e-9,
                             warmstart=None, verbose=False, print_period=10,
                             log=False, warn=True, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem with log
    stabilization and epsilon scaling.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
    but with the log stabilization
    proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the epsilon scaling
    proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,)
        samples in the target domain
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    warmstart : tuple of vectors
        if given then starting values for alpha and beta log scalings
    numItermax : int, optional
        Max number of (outer) iterations
    numInnerItermax : int, optional
        Max number of iterations in the inner log stabilized sinkhorn
    epsilon0 : int, optional
        first epsilon regularization value (then exponential decrease to reg)
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        check the marginal errors (and print them if `verbose`) every
        `print_period` outer iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-epsilon-scaling:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)

    # fall back to uniform marginals when none are supplied
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    dim_a = len(a)
    dim_b = len(b)

    # run at least numItermin outer passes so the decayed epsilon has
    # (numerically, with 64 bits) reached the requested `reg`
    numItermin = 35
    numItermax = max(numItermin, numItermax)

    it = 0
    if log:
        log = {'err': []}

    # dual log-domain scalings, possibly warm-started from a previous run
    if warmstart is None:
        alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
    else:
        alpha, beta = warmstart

    def _epsilon(n):
        # exponential decrease from epsilon0 towards reg
        return (epsilon0 - reg) * np.exp(-n) + reg

    err = 1
    for it in range(numItermax):
        regi = _epsilon(it)

        # one stabilized-sinkhorn pass at the current epsilon, warm-started
        # with the dual potentials of the previous pass
        G, logi = sinkhorn_stabilized(
            a, b, M, regi, numItermax=numInnerItermax, stopThr=stopThr,
            warmstart=(alpha, beta), verbose=False, print_period=20,
            tau=tau, log=True)
        alpha, beta = logi['alpha'], logi['beta']

        if it % print_period == 0:
            # checking the marginal violations only every print_period
            # iterations speeds up the process
            err = nx.norm(nx.sum(G, axis=0) - b) ** 2 \
                + nx.norm(nx.sum(G, axis=1) - a) ** 2
            if log:
                log['err'].append(err)
            if verbose:
                if it % (print_period * 10) == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(it, err))

        if err <= stopThr and it > numItermin:
            break
    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations `numItermax` "
                          "or the regularization parameter `reg`.")
    if log:
        log['alpha'] = alpha
        log['beta'] = beta
        log['warmstart'] = (log['alpha'], log['beta'])
        log['niter'] = it
        return G, log
    return G