# Source code for ot.bregman

# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized OT
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#         Kilian Fatras <kilian.fatras@irisa.fr>
#         Titouan Vayer <titouan.vayer@irisa.fr>
#         Hicham Janati <hicham.janati100@gmail.com>
#         Mokhtar Z. Alaya <mokhtarzahdi.alaya@gmail.com>
#         Alexander Tong <alexander.tong@yale.edu>
#         Ievgen Redko <ievgen.redko@univ-st-etienne.fr>
#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#

import warnings

import numpy as np
from scipy.optimize import fmin_l_bfgs_b

from ot.utils import unif, dist, list_to_array
from .backend import get_backend

def sinkhorn(a, b, M, reg, method='sinkhorn', numItermax=1000, stopThr=1e-9,
             verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`

    **Choosing a Sinkhorn solver**

    By default and when using a regularization parameter that is not too small
    the default sinkhorn solver should be enough. If you need to use a small
    regularization to get sharper OT matrices, you should use the
    :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
    errors. This last solver can be very slow in practice and might not even
    converge to a reasonable OT matrix in a finite time. This is why
    :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
    of the regularization (and using warm start) sometimes leads to better
    solutions. Note that the greedy version of the sinkhorn
    :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
    version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
    fast approximation of the Sinkhorn problem. For use of GPU and gradient
    computation with small number of iterations we strongly recommend the
    :py:func:`ot.bregman.sinkhorn_log` solver that will not need to check for
    numerical problems.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        method used for the solver either 'sinkhorn', 'sinkhorn_log',
        'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
        those functions for specific parameters. The name is matched
        case-insensitively.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart : tuple of arrays, optional
        Initialization of the dual potentials ``(log_u, log_v)``, of shapes
        (dim_a,) and (dim_b,). If provided, the dual potentials should be
        given (that is the logarithm of the u, v sinkhorn scaling vectors).
    **kwargs
        Extra solver-specific parameters forwarded to the selected solver
        (not forwarded when ``method='greenkhorn'``, whose signature takes no
        extra keyword arguments).

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`

    """
    # Normalize the solver name once; matching is case-insensitive.
    method_name = method.lower()

    if method_name == 'sinkhorn':
        return sinkhorn_knopp(a, b, M, reg, numItermax=numItermax,
                              stopThr=stopThr, verbose=verbose, log=log,
                              warn=warn, warmstart=warmstart,
                              **kwargs)
    elif method_name == 'sinkhorn_log':
        return sinkhorn_log(a, b, M, reg, numItermax=numItermax,
                            stopThr=stopThr, verbose=verbose, log=log,
                            warn=warn, warmstart=warmstart,
                            **kwargs)
    elif method_name == 'greenkhorn':
        # greenkhorn accepts no **kwargs, so extra parameters are not forwarded
        return greenkhorn(a, b, M, reg, numItermax=numItermax,
                          stopThr=stopThr, verbose=verbose, log=log,
                          warn=warn, warmstart=warmstart)
    elif method_name == 'sinkhorn_stabilized':
        return sinkhorn_stabilized(a, b, M, reg, numItermax=numItermax,
                                   stopThr=stopThr, warmstart=warmstart,
                                   verbose=verbose, log=log, warn=warn,
                                   **kwargs)
    elif method_name == 'sinkhorn_epsilon_scaling':
        return sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=numItermax,
                                        stopThr=stopThr, warmstart=warmstart,
                                        verbose=verbose, log=log, warn=warn,
                                        **kwargs)
    else:
        raise ValueError("Unknown method '%s'." % method)

def sinkhorn2(a, b, M, reg, method='sinkhorn', numItermax=1000,
              stopThr=1e-9, verbose=False, log=False, warn=False,
              warmstart=None, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
    the entropic contribution).

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`

    **Choosing a Sinkhorn solver**

    By default and when using a regularization parameter that is not too small
    the default sinkhorn solver should be enough. If you need to use a small
    regularization to get sharper OT matrices, you should use the
    :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
    errors. This last solver can be very slow in practice and might not even
    converge to a reasonable OT matrix in a finite time. This is why
    :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
    of the regularization (and using warm start) sometimes leads to better
    solutions. Note that the greedy version of the sinkhorn
    :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
    version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
    fast approximation of the Sinkhorn problem. For use of GPU and gradient
    computation with small number of iterations we strongly recommend the
    :py:func:`ot.bregman.sinkhorn_log` solver that will not need to check for
    numerical problems.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        method used for the solver either 'sinkhorn', 'sinkhorn_log',
        'sinkhorn_stabilized', see those functions for specific parameters.
        The name is matched case-insensitively.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart : tuple of arrays, optional
        Initialization of the dual potentials ``(log_u, log_v)``. If provided,
        the dual potentials should be given (that is the logarithm of the
        u, v sinkhorn scaling vectors).
    **kwargs
        Extra solver-specific parameters forwarded to the selected solver.

    Returns
    -------
    W : (n_hists) float/array-like
        Optimal transportation loss for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn2(a, b, M, 1)
    0.26894142136999516

    .. _references-sinkhorn2:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.

    .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
        algorithms for optimal transport via Sinkhorn iteration,
        Advances in Neural Information Processing Systems (NIPS) 31, 2017

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April).
        Interpolating between optimal transport and MMD using Sinkhorn
        divergences. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
    ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
    """

    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # Pick the solver once; the same arguments are forwarded whatever the
    # shape of b (previously this dispatch was duplicated in two branches).
    method_name = method.lower()
    if method_name == 'sinkhorn':
        solver = sinkhorn_knopp
    elif method_name == 'sinkhorn_log':
        solver = sinkhorn_log
    elif method_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized
    else:
        raise ValueError("Unknown method '%s'." % method)

    res = solver(a, b, M, reg, numItermax=numItermax, stopThr=stopThr,
                 verbose=verbose, log=log, warn=warn, warmstart=warmstart,
                 **kwargs)

    if len(b.shape) >= 2:
        # Several target histograms: the solver result is returned as is
        # (loss per histogram, plus the log dict when log=True).
        return res

    # Single target histogram: the solver returns the OT plan, and the loss
    # is <gamma, M>_F (the entropic term is not included).
    if log:
        return nx.sum(M * res[0]), res[1]
    return nx.sum(M * res)

def sinkhorn_knopp(a, b, M, reg, numItermax=1000, stopThr=1e-9,
                   verbose=False, log=False, warn=True, warmstart=None,
                   **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp
    matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain (uniform weights are used if
        empty)
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log). Uniform weights are used
        if empty.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the norm of the right-marginal violation (>0),
        checked every 10 iterations
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart : tuple of arrays, optional
        Initialization of the dual potentials ``(log_u, log_v)``, of shapes
        (dim_a,) and (dim_b,) (or (dim_a, n_hists) and (dim_b, n_hists) with
        multiple targets). If provided, the dual potentials should be given
        (that is the logarithm of the u, v sinkhorn scaling vectors).
    **kwargs
        Accepted for API compatibility with the other solvers; unused.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters. When ``b``
        has shape (dim_b, n_hists), the OT losses, shape (n_hists,), are
        returned instead.
    log : dict
        log dictionary return only if log==True in parameters, with keys
        'err', 'niter', 'u' and 'v'

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-knopp:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {'err': []}

    # initial scaling vectors: uniform, or exp of the given dual potentials
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
            v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
        else:
            u = nx.ones(dim_a, type_as=M) / dim_a
            v = nx.ones(dim_b, type_as=M) / dim_b
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    K = nx.exp(M / (-reg))

    # K with row i divided by a_i, so that 1 / (Kp v) == a / (K v)
    Kp = (1 / a).reshape(-1, 1) * K

    err = 1
    # Bound before the loop so that log['niter'] below is defined even when
    # numItermax <= 0 and the loop body never runs.
    ii = 0
    for ii in range(numItermax):
        uprev = u
        vprev = v
        KtransposeU = nx.dot(K.T, u)
        v = b / KtransposeU
        u = 1. / nx.dot(Kp, v)

        if (nx.any(KtransposeU == 0)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Warning: numerical errors at iteration %d' % ii)
            u = uprev
            v = vprev
            break
        if ii % 10 == 0:
            # the marginal error is only checked every 10th iteration to
            # speed up the loop
            if n_hists:
                tmp2 = nx.einsum('ik,ij,jk->jk', u, K, v)
            else:
                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.einsum('i,ij,j->j', u, K, v)
            err = nx.norm(tmp2 - b)  # violation of marginal
            if log:
                log['err'].append(err)

            if err < stopThr:
                break
            if verbose:
                if ii % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
    else:
        # loop exhausted without reaching stopThr (or numItermax <= 0)
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        log['u'] = u
        log['v'] = v

    if n_hists:  # return only loss
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix diag(u) K diag(v)

        if log:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1)), log
        else:
            return u.reshape((-1, 1)) * K * v.reshape((1, -1))

def sinkhorn_log(a, b, M, reg, numItermax=1000, stopThr=1e-9, verbose=False,
                 log=False, warn=True, warmstart=None, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem in log space
    and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm :ref:`[2] <references-sinkhorn-log>` with the
    implementation from :ref:`[34] <references-sinkhorn-log>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain (uniform weights are used if
        empty)
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log).
        Uniform weights are used if empty.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the norm of the right-marginal violation (>0),
        checked every 10 iterations
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart : tuple of arrays, optional
        Initialization of the dual potentials ``(log_u, log_v)``, of shapes
        (dim_a,) and (dim_b,). If provided, the dual potentials should be
        given (that is the logarithm of the u, v sinkhorn scaling vectors).
        With multiple targets, either a pair of arrays of shapes
        (dim_a, n_hists) and (dim_b, n_hists) (as stored in
        ``log['log_u']`` / ``log['log_v']``) or a sequence of n_hists
        ``(log_u, log_v)`` pairs.
    **kwargs
        Forwarded to the per-histogram calls with multiple targets; otherwise
        unused.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters. When ``b``
        has shape (dim_b, n_hists), the OT losses, shape (n_hists,), are
        returned instead.
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-log:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April). Interpolating between
        optimal transport and MMD using Sinkhorn divergences. In The
        22nd International Conference on Artificial Intelligence and
        Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if n_hists:
        # Multiple target histograms: solve each problem separately (to
        # avoid 3D tensors) and return one loss per histogram.
        if warmstart is None:
            # Must also hold for n_hists == 1 (the old `n_hists > 1` guard
            # left warmstart as None and indexing it raised a TypeError).
            warmstarts = [None] * n_hists
        elif (len(warmstart) == 2
              and all(hasattr(w, 'shape') and len(w.shape) == 2
                      for w in warmstart)
              and warmstart[0].shape[0] == dim_a
              and warmstart[1].shape[0] == dim_b
              and warmstart[0].shape[1] == n_hists
              and warmstart[1].shape[1] == n_hists):
            # stacked (log_u, log_v), as returned in the log: one column per
            # histogram
            warmstarts = [(warmstart[0][:, k], warmstart[1][:, k])
                          for k in range(n_hists)]
        else:
            # already one (log_u, log_v) pair per histogram
            warmstarts = warmstart

        lst_loss = []
        lst_u = []
        lst_v = []

        for k in range(n_hists):
            # warn is forwarded so that warn=False is honored per histogram
            res = sinkhorn_log(a, b[:, k], M, reg, numItermax=numItermax,
                               stopThr=stopThr, verbose=verbose, log=log,
                               warn=warn, warmstart=warmstarts[k], **kwargs)

            if log:
                lst_loss.append(nx.sum(M * res[0]))
                lst_u.append(res[1]['log_u'])
                lst_v.append(res[1]['log_v'])
            else:
                lst_loss.append(nx.sum(M * res))
        res = nx.stack(lst_loss)
        if log:
            log = {'log_u': nx.stack(lst_u, 1),
                   'log_v': nx.stack(lst_v, 1), }
            log['u'] = nx.exp(log['log_u'])
            log['v'] = nx.exp(log['log_v'])
            return res, log
        else:
            return res

    else:

        if log:
            log = {'err': []}

        Mr = - M / reg

        # dual potentials (log of the scaling vectors): zeros or warm start
        if warmstart is None:
            u = nx.zeros(dim_a, type_as=M)
            v = nx.zeros(dim_b, type_as=M)
        else:
            u, v = warmstart

        def get_logT(u, v):
            # log of the plan diag(exp(u)) exp(-M / reg) diag(exp(v))
            return Mr + u[:, None] + v[None, :]

        loga = nx.log(a)
        logb = nx.log(b)

        err = 1
        # Bound before the loop so that log['niter'] below is defined even
        # when numItermax <= 0 and the loop body never runs.
        ii = 0
        for ii in range(numItermax):

            v = logb - nx.logsumexp(Mr + u[:, None], 0)
            u = loga - nx.logsumexp(Mr + v[None, :], 1)

            if ii % 10 == 0:
                # the marginal error is only checked every 10th iteration
                # to speed up the loop

                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
                err = nx.norm(tmp2 - b)  # violation of marginal
                if log:
                    log['err'].append(err)

                if verbose:
                    if ii % 200 == 0:
                        print(
                            '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                    print('{:5d}|{:8e}|'.format(ii, err))
                if err < stopThr:
                    break
        else:
            # loop exhausted without reaching stopThr (or numItermax <= 0)
            if warn:
                warnings.warn("Sinkhorn did not converge. You might want to "
                              "increase the number of iterations numItermax "
                              "or the regularization parameter reg.")

        if log:
            log['niter'] = ii
            log['log_u'] = u
            log['log_v'] = v
            log['u'] = nx.exp(u)
            log['v'] = nx.exp(v)

            return nx.exp(get_logT(u, v)), log

        else:
            return nx.exp(get_logT(u, v))

def greenkhorn(a, b, M, reg, numItermax=10000, stopThr=1e-9, verbose=False,
               log=False, warn=True, warmstart=None):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
    which is a stochastic version of the Sinkhorn-Knopp
    algorithm :ref:`[2] <references-greenkhorn>`

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    At each iteration, the row or column whose marginal is most violated is
    rescaled so that this single marginal is matched exactly.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain (uniform weights are used if
        empty)
    b : array-like, shape (dim_b,)
        samples weights in the target domain (uniform weights are used if
        empty). Multiple target histograms are not supported by this solver.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the largest absolute marginal violation (>0)
    verbose : bool, optional
        Accepted for API compatibility with the other solvers; currently
        unused.
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart : tuple of arrays, optional
        Initialization of the dual potentials ``(log_u, log_v)``, of shapes
        (dim_a,) and (dim_b,). If provided, the dual potentials should be
        given (that is the logarithm of the u, v sinkhorn scaling vectors).

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters, with keys
        'u', 'v' and 'n_iter'

    Raises
    ------
    TypeError
        If JAX or TensorFlow arrays are given.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.greenkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-greenkhorn:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
        approximation algorithms for optimal transport via Sinkhorn
        iteration, Advances in Neural Information Processing
        Systems (NIPS) 31, 2017

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)
    # The loop below assigns into u, v, G, viol and viol_2 element-wise in
    # place, which is presumably why JAX and TF arrays are rejected here.
    if nx.__name__ in ("jax", "tf"):
        raise TypeError("JAX or TF arrays have been received. Greenkhorn is not "
                        "compatible with  neither JAX nor TF")

    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    dim_a = a.shape[0]
    dim_b = b.shape[0]

    K = nx.exp(-M / reg)

    if warmstart is None:
        u = nx.full((dim_a,), 1. / dim_a, type_as=K)
        v = nx.full((dim_b,), 1. / dim_b, type_as=K)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])
    G = u[:, None] * K * v[None, :]

    # signed violations of the row (source) and column (target) marginals
    viol = nx.sum(G, axis=1) - a
    viol_2 = nx.sum(G, axis=0) - b
    if log:
        log = dict()
        log['u'] = u
        log['v'] = v

    # Bound before the loop so that log["n_iter"] below is defined even when
    # numItermax <= 0 and the loop body never runs.
    ii = 0
    for ii in range(numItermax):
        # greedily select the most violated row and the most violated column
        i_1 = nx.argmax(nx.abs(viol))
        i_2 = nx.argmax(nx.abs(viol_2))
        m_viol_1 = nx.abs(viol[i_1])
        m_viol_2 = nx.abs(viol_2[i_2])
        stopThr_val = nx.maximum(m_viol_1, m_viol_2)

        if m_viol_1 > m_viol_2:
            # rescale row i_1 so that its sum matches a[i_1]
            old_u = u[i_1]
            new_u = a[i_1] / nx.dot(K[i_1, :], v)
            G[i_1, :] = new_u * K[i_1, :] * v

            viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
            # incremental update of every column sum affected by row i_1
            viol_2 += (K[i_1, :].T * (new_u - old_u) * v)
            u[i_1] = new_u
        else:
            # rescale column i_2 so that its sum matches b[i_2]
            old_v = v[i_2]
            new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
            G[:, i_2] = u * K[:, i_2] * new_v
            # aviol = (G@one_m - a)
            # aviol_2 = (G.T@one_n - b)
            viol += (-old_v + new_v) * K[:, i_2] * u
            viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
            v[i_2] = new_v

        if stopThr_val <= stopThr:
            break
    else:
        # loop exhausted without reaching stopThr (or numItermax <= 0)
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")

    if log:
        log["n_iter"] = ii
        log['u'] = u
        log['v'] = v

    if log:
        return G, log
    else:
        return G

[docs]def sinkhorn_stabilized(a, b, M, reg, numItermax=1000, tau=1e3, stopThr=1e-9,
warmstart=None, verbose=False, print_period=20,
log=False, warn=True, **kwargs):
r"""
Solve the entropic regularization OT problem with log stabilization

The function solves the following optimization problem:

.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg}\cdot\Omega(\gamma)

s.t. \ \gamma \mathbf{1} &= \mathbf{a}

\gamma^T \mathbf{1} &= \mathbf{b}

\gamma &\geq 0

where :

- :math:\mathbf{M} is the (dim_a, dim_b) metric cost matrix
- :math:\Omega is the entropic regularization term
:math:\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})
- :math:\mathbf{a} and :math:\mathbf{b} are source and target
weights (histograms, both sum to 1)

The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
scaling algorithm as proposed in :ref:[2] <references-sinkhorn-stabilized>
but with the log stabilization
proposed in :ref:[10] <references-sinkhorn-stabilized> an defined in
:ref:[9] <references-sinkhorn-stabilized> (Algo 3.1) .

Parameters
----------
a : array-like, shape (dim_a,)
samples weights in the source domain
b : array-like, shape (dim_b,)
samples in the target domain
M : array-like, shape (dim_a, dim_b)
loss matrix
reg : float
Regularization term >0
tau : float
threshold for max value in :math:\mathbf{u} or :math:\mathbf{v}
for log scaling
warmstart : table of vectors
if given then starting values for alpha and beta log scalings
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
if True, raises a warning if the algorithm doesn't convergence.

Returns
-------
gamma : array-like, shape (dim_a, dim_b)
Optimal transportation matrix for the given parameters
log : dict
log dictionary return only if log==True in parameters

Examples
--------

>>> import ot
>>> a=[.5,.5]
>>> b=[.5,.5]
>>> M=[[0.,1.],[1.,0.]]
>>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
array([[0.36552929, 0.13447071],
[0.13447071, 0.36552929]])

.. _references-sinkhorn-stabilized:
References
----------

.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013

.. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
for Entropy Regularized Transport Problems.
arXiv preprint arXiv:1610.06519.

.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.

--------
ot.lp.emd : Unregularized OT
ot.optim.cg : General regularized OT

"""

a, b, M = list_to_array(a, b, M)

nx = get_backend(M, a, b)

if len(a) == 0:
a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
if len(b) == 0:
b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

# test if multiple target
if len(b.shape) > 1:
n_hists = b.shape[1]
a = a[:, None]
else:
n_hists = 0

# init data
dim_a = len(a)
dim_b = len(b)

if log:
log = {'err': []}

# we assume that no distances are null except those of the diagonal of
# distances
if warmstart is None:
alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
else:
alpha, beta = warmstart

if n_hists:
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
u, v = nx.ones(dim_a, type_as=M), nx.ones(dim_b, type_as=M)
u /= dim_a
v /= dim_b

def get_K(alpha, beta):
"""log space computation"""
return nx.exp(-(M - alpha.reshape((dim_a, 1))
- beta.reshape((1, dim_b))) / reg)

def get_Gamma(alpha, beta, u, v):
"""log space gamma computation"""
return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b)))
/ reg + nx.log(u.reshape((dim_a, 1))) + nx.log(v.reshape((1, dim_b))))

K = get_K(alpha, beta)
transp = K
err = 1
for ii in range(numItermax):

uprev = u
vprev = v

# sinkhorn update
v = b / (nx.dot(K.T, u))
u = a / (nx.dot(K, v))

# remove numerical problems and store them in K
if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
if n_hists:
alpha, beta = alpha + reg * \
nx.max(nx.log(u), 1), beta + reg * nx.max(nx.log(v))
else:
alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
if n_hists:
u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
else:
u = nx.ones(dim_a, type_as=M) / dim_a
v = nx.ones(dim_b, type_as=M) / dim_b
K = get_K(alpha, beta)

if ii % print_period == 0:
# we can speed up the process by checking for the error only all
# the 10th iterations
if n_hists:
err_u = nx.max(nx.abs(u - uprev))
err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
err_v = nx.max(nx.abs(v - vprev))
err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
err = 0.5 * (err_u + err_v)
else:
transp = get_Gamma(alpha, beta, u, v)
err = nx.norm(nx.sum(transp, axis=0) - b)
if log:
log['err'].append(err)

if verbose:
if ii % (print_period * 20) == 0:
print(
'{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
print('{:5d}|{:8e}|'.format(ii, err))

if err <= stopThr:
break

if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
# we have reached the machine precision
# come back to previous solution and quit loop
warnings.warn('Numerical errors at iteration %d' % ii)
u = uprev
v = vprev
break
else:
if warn:
warnings.warn("Sinkhorn did not converge. You might want to "
"increase the number of iterations numItermax "
"or the regularization parameter reg.")
if log:
if n_hists:
alpha = alpha[:, None]
beta = beta[:, None]
logu = alpha / reg + nx.log(u)
logv = beta / reg + nx.log(v)
log["n_iter"] = ii
log['logu'] = logu
log['logv'] = logv
log['alpha'] = alpha + reg * nx.log(u)
log['beta'] = beta + reg * nx.log(v)
log['warmstart'] = (log['alpha'], log['beta'])
if n_hists:
res = nx.stack([
nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
for i in range(n_hists)
])
return res, log

else:
return get_Gamma(alpha, beta, u, v), log
else:
if n_hists:
res = nx.stack([
nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
for i in range(n_hists)
])
return res
else:
return get_Gamma(alpha, beta, u, v)

def sinkhorn_epsilon_scaling(a, b, M, reg, numItermax=100, epsilon0=1e4,
                             numInnerItermax=100, tau=1e3, stopThr=1e-9,
                             warmstart=None, verbose=False, print_period=10,
                             log=False, warn=True, **kwargs):
    r"""
    Solve the entropic regularization optimal transport problem with log
    stabilization and epsilon scaling.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights
      (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
    but with the log stabilization proposed in
    :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
    proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2.

    At outer iteration ``n`` the regularization is
    ``(epsilon0 - reg) * exp(-n) + reg``. Each outer iteration runs
    :py:func:`ot.bregman.sinkhorn_stabilized`, warm-started with the dual
    potentials (``alpha``, ``beta``) obtained at the previous iteration.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,)
        samples weights in the target domain
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of outer (epsilon scaling) iterations. Values smaller than
        35 are raised to 35, and the algorithm never stops before iteration 36.
    epsilon0 : float, optional
        first epsilon regularization value (then exponential decrease to reg)
    numInnerItermax : int, optional
        Max number of iterations in the inner log stabilized sinkhorn
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    stopThr : float, optional
        Stop threshold on error (>0), used both for the outer loop and for
        the inner stabilized Sinkhorn
    warmstart : tuple of vectors
        if given then starting values for alpha and beta log scalings
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        The marginal error is recomputed (and printed/logged) only every
        `print_period` outer iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    **kwargs : dict
        Accepted but not used by this function.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-epsilon-scaling:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # empty marginals default to uniform weights
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    # init data
    dim_a = len(a)
    dim_b = len(b)

    # exp(-35) ~ 6e-16 is of the order of the relative precision of 64-bit
    # floats, so after 35 outer iterations the regularization has numerically
    # reached its final value ``reg``.
    numItermin = 35
    numItermax = max(numItermin, numItermax)

    ii = 0
    if log:
        log = {'err': []}

    # start from null dual potentials unless a warm start is given
    if warmstart is None:
        alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
    else:
        alpha, beta = warmstart

    def get_reg(n):
        """Exponentially decreasing regularization, from epsilon0 to reg."""
        return (epsilon0 - reg) * np.exp(-n) + reg

    err = 1
    for ii in range(numItermax):

        regi = get_reg(ii)

        # inner stabilized Sinkhorn at the current regularization, warm-started
        # with the potentials of the previous (larger) regularization
        G, logi = sinkhorn_stabilized(a, b, M, regi,
                                      numItermax=numInnerItermax, stopThr=stopThr,
                                      warmstart=(alpha, beta), verbose=False,
                                      print_period=20, tau=tau, log=True)

        alpha = logi['alpha']
        beta = logi['beta']

        if ii % (print_period) == 0:
            # the marginal error is only recomputed every print_period
            # iterations to speed up the process
            transp = G
            err = nx.norm(nx.sum(transp, axis=0) - b) ** 2 + \
                nx.norm(nx.sum(transp, axis=1) - a) ** 2
            if log:
                log['err'].append(err)

            if verbose:
                if ii % (print_period * 10) == 0:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

        # only stop once the regularization has numerically reached reg
        if err <= stopThr and ii > numItermin:
            break
    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['alpha'] = alpha
        log['beta'] = beta
        log['warmstart'] = (log['alpha'], log['beta'])
        log['niter'] = ii
        return G, log
    else:
        return G

def geometricBar(weights, alldistribT):
    """return the weighted geometric mean of distributions

    Parameters
    ----------
    weights : array-like, shape (n_hists,)
        Weight of each distribution (the barycenters in this module pass
        weights on the simplex).
    alldistribT : array-like, shape (dim, n_hists)
        Distributions stored column-wise; the number of columns must equal
        ``len(weights)``.

    Returns
    -------
    array-like
        ``exp(log(alldistribT) @ weights)``, i.e. one value per row.
    """
    weights, alldistribT = list_to_array(weights, alldistribT)
    nx = get_backend(weights, alldistribT)
    assert len(weights) == alldistribT.shape[1]
    # weighted average of the logs, mapped back with exp
    log_distribs = nx.log(alldistribT)
    weighted_log = nx.dot(log_distribs, weights.T)
    return nx.exp(weighted_log)

def geometricMean(alldistribT):
    """return the geometric mean of distributions

    Parameters
    ----------
    alldistribT : array-like, shape (dim, n_hists)
        Distributions stored column-wise.

    Returns
    -------
    array-like
        ``exp(mean(log(alldistribT), axis=1))``: the unweighted geometric
        mean taken across columns, one value per row.
    """
    alldistribT = list_to_array(alldistribT)
    nx = get_backend(alldistribT)
    mean_of_logs = nx.mean(nx.log(alldistribT), axis=1)
    return nx.exp(mean_of_logs)

def projR(gamma, p):
    """return the KL projection on the row constraints

    Each row ``i`` of `gamma` is rescaled by ``p[i] / sum_j gamma[i, j]`` so
    that its sum becomes ``p[i]``. Row sums are floored at 1e-10 to avoid a
    division by zero.

    Parameters
    ----------
    gamma : array-like, shape (n, m)
        Transport plan to project.
    p : array-like, shape (n,)
        Target row marginal.

    Returns
    -------
    array-like, shape (n, m)
        The rescaled plan.
    """
    gamma, p = list_to_array(gamma, p)
    nx = get_backend(gamma, p)
    row_sums = nx.maximum(nx.sum(gamma, axis=1), 1e-10)
    # operate on the transpose so that p and row_sums broadcast over rows
    scaled_t = gamma.T * p / row_sums
    return scaled_t.T

def projC(gamma, q):
    """return the KL projection on the column constraints

    Each column ``j`` of `gamma` is rescaled by ``q[j] / sum_i gamma[i, j]``
    so that its sum becomes ``q[j]``. Column sums are floored at 1e-10 to
    avoid a division by zero.

    Parameters
    ----------
    gamma : array-like, shape (n, m)
        Transport plan to project.
    q : array-like, shape (m,)
        Target column marginal.

    Returns
    -------
    array-like, shape (n, m)
        The rescaled plan.
    """
    gamma, q = list_to_array(gamma, q)
    nx = get_backend(gamma, q)
    col_sums = nx.maximum(nx.sum(gamma, axis=0), 1e-10)
    return gamma * q / col_sums

def barycenter(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
               stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
      if `method` is `sinkhorn` or `sinkhorn_stabilized` or `sinkhorn_log`.
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
      :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
      the cost matrix for OT

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
    algorithm as proposed in :ref:`[3] <references-barycenter>`

    This function only selects the solver named by `method` (compared
    case-insensitively) and forwards all the other arguments to it.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    method : str (optional)
        method used for the solver either 'sinkhorn' or 'sinkhorn_stabilized'
        or 'sinkhorn_log'
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    **kwargs : dict
        Extra keyword arguments forwarded to the selected solver.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solvers.

    .. _references-barycenter:
    References
    ----------

    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
        Iterative Bregman projections for regularized transportation problems.
        SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

    """
    method_key = method.lower()
    if method_key == 'sinkhorn':
        solver = barycenter_sinkhorn
    elif method_key == 'sinkhorn_stabilized':
        solver = barycenter_stabilized
    elif method_key == 'sinkhorn_log':
        solver = _barycenter_sinkhorn_log
    else:
        raise ValueError("Unknown method '%s'." % method)

    # every solver receives exactly the same arguments
    return solver(A, M, reg, weights=weights, numItermax=numItermax,
                  stopThr=stopThr, verbose=verbose, log=log, warn=warn,
                  **kwargs)

def barycenter_sinkhorn(A, M, reg, weights=None, numItermax=1000,
                        stopThr=1e-4, verbose=False, log=False, warn=True):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
      (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
      :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
      the cost matrix for OT

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[3] <references-barycenter-sinkhorn>`.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    .. _references-barycenter-sinkhorn:
    References
    ----------

    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
        Iterative Bregman projections for regularized transportation problems.
        SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

    """

    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    if weights is None:
        weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    K = nx.exp(-M / reg)

    err = 1

    # initialization: K applied to A divided by the column sums of K
    UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)

    u = (geometricMean(UKv) / UKv.T).T

    for ii in range(numItermax):

        UKv = u * nx.dot(K.T, A / nx.dot(K, u))
        u = (u.T * geometricBar(weights, UKv)).T / UKv

        # the error (spread of UKv across histograms) is only evaluated every
        # 10 iterations, at ii = 1, 11, 21, ...
        if ii % 10 == 1:
            err = nx.sum(nx.std(UKv, axis=1))

            # log and verbose print
            if log:
                log['err'].append(err)

            if err < stopThr:
                break
            if verbose:
                # the header offset must match the error-check offset
                # (ii % 10 == 1), otherwise it can never be printed
                if ii % 200 == 1:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        return geometricBar(weights, UKv), log
    else:
        return geometricBar(weights, UKv)

def free_support_sinkhorn_barycenter(measures_locations, measures_weights, X_init, reg, b=None, weights=None,
                                     numItermax=100, numInnerItermax=1000, stopThr=1e-7, verbose=False, log=None,
                                     **kwargs):
    r"""
    Solves the free support (locations of the barycenters are optimized, not the weights) regularized Wasserstein
    barycenter problem (i.e. the weighted Frechet mean for the 2-Sinkhorn divergence), formally:

    .. math::
        \min_\mathbf{X} \quad \sum_{i=1}^N w_i W_{reg}^2(\mathbf{b}, \mathbf{X}, \mathbf{a}_i, \mathbf{X}_i)

    where :

    - :math:`w \in (0, 1)^{N}` are the barycenter weights and sum to one
    - `measures_weights` denotes the :math:`\mathbf{a}_i \in \mathbb{R}^{k_i}`: empirical measures weights
      (on simplex)
    - `measures_locations` denotes the :math:`\mathbf{X}_i \in \mathbb{R}^{k_i, d}`: empirical measures atoms
      locations
    - :math:`\mathbf{b} \in \mathbb{R}^{k}` is the desired weights vector of the barycenter

    This problem is considered in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).
    There are three differences with that algorithm:

    - we do not optimize over the weights
    - we do not do line search for the locations updates, we use i.e. :math:`\theta = 1` in
      :ref:`[20] <references-free-support-barycenter>` (Algorithm 2). This can be seen as a discrete
      implementation of the fixed-point algorithm of
      :ref:`[43] <references-free-support-barycenter>` proposed in the continuous setting.
    - at each iteration, instead of solving an exact OT problem, we use the Sinkhorn algorithm for calculating the
      transport plan in :ref:`[20] <references-free-support-barycenter>` (Algorithm 2).

    Parameters
    ----------
    measures_locations : list of N (k_i,d) array-like
        The discrete support of a measure supported on :math:`k_i` locations of a `d`-dimensional space
        (:math:`k_i` can be different for each element of the list)
    measures_weights : list of N (k_i,) array-like
        Arrays where each array has :math:`k_i` non-negatives values summing to one
        representing the weights of each discrete input measure
    X_init : (k,d) array-like
        Initialization of the support locations (on `k` atoms) of the barycenter
    reg : float
        Regularization term >0
    b : (k,) array-like, optional
        Weights of the barycenter (non-negatives, sum to 1). They are kept
        fixed. Defaults to uniform weights.
    weights : (N,) array-like, optional
        Coefficients :math:`w_i` of each input measure in the barycenter
        (non-negatives, sum to 1). Defaults to uniform coefficients.
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations when calculating the transport plans with Sinkhorn
    stopThr : float, optional
        Stop threshold on the squared displacement of the support between
        two iterations (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Extra keyword arguments forwarded to :py:func:`ot.bregman.sinkhorn`.

    Returns
    -------
    X : (k,d) array-like
        Support locations (on `k` atoms) of the barycenter
    log : dict
        Only returned if `log` is True. Its key ``'displacement_square_norms'``
        lists the squared displacement of the support at each iteration.

    See Also
    --------
    ot.bregman.sinkhorn : Entropic regularized OT solver
    ot.lp.free_support_barycenter : Barycenter solver based on Linear Programming

    .. _references-free-support-barycenter:
    References
    ----------
    .. [20] Cuturi, Marco, and Arnaud Doucet. "Fast computation of Wasserstein barycenters."
        International Conference on Machine Learning. 2014.

    .. [43] Álvarez-Esteban, Pedro C., et al. "A fixed-point approach to barycenters in Wasserstein space."
        Journal of Mathematical Analysis and Applications 441.2 (2016): 744-762.

    """
    nx = get_backend(*measures_locations, *measures_weights, X_init)

    iter_count = 0

    N = len(measures_locations)
    k = X_init.shape[0]
    d = X_init.shape[1]
    if b is None:
        b = nx.ones((k,), type_as=X_init) / k
    if weights is None:
        weights = nx.ones((N,), type_as=X_init) / N

    X = X_init

    log_dict = {}
    displacement_square_norms = []

    # guarantees that at least one iteration is run
    displacement_square_norm = stopThr + 1.

    while (displacement_square_norm > stopThr and iter_count < numItermax):

        T_sum = nx.zeros((k, d), type_as=X_init)

        for (measure_locations_i, measure_weights_i, weight_i) in zip(measures_locations, measures_weights, weights):
            M_i = dist(X, measure_locations_i)
            T_i = sinkhorn(b, measure_weights_i, M_i, reg=reg,
                           numItermax=numInnerItermax, **kwargs)
            # weighted barycentric projection of the atoms of measure i
            # through the plan T_i (fixed-point update with theta = 1)
            T_sum = T_sum + weight_i * 1. / \
                b[:, None] * nx.dot(T_i, measure_locations_i)

        displacement_square_norm = nx.sum((T_sum - X) ** 2)
        if log:
            displacement_square_norms.append(displacement_square_norm)

        X = T_sum

        if verbose:
            # apply the format: passing the values as extra print() arguments
            # would print the raw template followed by the values
            print('iteration %d, displacement_square_norm=%f\n'
                  % (iter_count, displacement_square_norm))

        iter_count += 1

    if log:
        log_dict['displacement_square_norms'] = displacement_square_norms
        return X, log_dict
    else:
        return X

def _barycenter_sinkhorn_log(A, M, reg, weights=None, numItermax=1000,
                             stopThr=1e-4, verbose=False, log=False, warn=True):
    r"""Compute the entropic wasserstein barycenter in log-domain

    Log-domain counterpart of :py:func:`barycenter_sinkhorn`; all the
    scalings are handled through `logsumexp`. Only the numpy and torch
    backends are supported, because the iterations update arrays in place.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (in columns)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,), optional
        Weights of each histogram (uniform if None)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    NotImplementedError
        If the arrays belong to the jax or tf backend.
    """

    A, M = list_to_array(A, M)
    dim, n_hists = A.shape

    nx = get_backend(A, M)

    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and tf. Use numpy or torch arrays instead."
        )

    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    # M now holds the log of the Gibbs kernel
    M = - M / reg
    # small offset avoids log(0) on empty bins
    logA = nx.log(A + 1e-15)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros(dim, type_as=A)
        for k in range(n_hists):
            f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
            log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
            log_bar = log_bar + weights[k] * log_KU[:, k]

        if ii % 10 == 1:
            # backend std/sum, as in barycenter_sinkhorn
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=1))

            # log and verbose print
            if log:
                log['err'].append(err)

            if err < stopThr:
                break
            if verbose:
                # header offset matches the error-check offset (ii % 10 == 1)
                if ii % 200 == 1:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

        G = log_bar[:, None] - log_KU

    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)

def barycenter_stabilized(A, M, reg, tau=1e10, weights=None, numItermax=1000,
                          stopThr=1e-4, verbose=False, log=False, warn=True):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` with stabilization.

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
      :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
      the cost matrix for OT

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
    algorithm as proposed in :ref:`[3] <references-barycenter-stabilized>`

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    .. _references-barycenter-stabilized:
    References
    ----------

    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G. (2015).
        Iterative Bregman projections for regularized transportation problems.
        SIAM Journal on Scientific Computing, 37(2), A1111-A1138.

    """

    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones((n_hists,), type_as=M) / n_hists
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    u = nx.ones((dim, n_hists), type_as=M) / dim
    v = nx.ones((dim, n_hists), type_as=M) / dim

    K = nx.exp(-M / reg)

    err = 1.
    # log-domain quantities absorbed into the kernel K
    alpha = nx.zeros((dim,), type_as=M)
    beta = nx.zeros((dim,), type_as=M)
    q = nx.ones((dim,), type_as=M) / dim
    for ii in range(numItermax):
        qprev = q
        Kv = nx.dot(K, v)
        u = A / Kv
        Ktu = nx.dot(K.T, u)
        q = geometricBar(weights, Ktu)
        Q = q[:, None]
        v = Q / Ktu
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # fold the row-wise maxima of u and v into alpha/beta, rebuild
            # the kernel from them and restart v from ones
            absorbing = True
            alpha += reg * nx.log(nx.max(u, 1))
            beta += reg * nx.log(nx.max(v, 1))
            K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
            v = nx.ones(tuple(v.shape), type_as=v)
        Kv = nx.dot(K, v)
        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % ii)
            q = qprev
            break
        if (ii % 10 == 0 and not absorbing) or ii == 0:
            # the error is only checked every 10 iterations (and never right
            # after an absorption) to speed up the process
            err = nx.max(nx.abs(u * Kv - A))
            if log:
                log['err'].append(err)
            if err < stopThr:
                break
            if verbose:
                if ii % 50 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

    else:
        if warn:
            # sentences are separated by explicit spaces (they used to be
            # glued together as "converge.Try ... regOr ...")
            warnings.warn("Stabilized Sinkhorn did not converge. "
                          "Try a larger entropy reg "
                          "or a larger absorption threshold tau.")
    if log:
        log['niter'] = ii
        log['logu'] = nx.log(u + 1e-16)
        log['logv'] = nx.log(v + 1e-16)
        return q, log
    else:
        return q

def barycenter_debiased(A, M, reg, weights=None, method="sinkhorn", numItermax=10000,
                        stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs):
    r"""Compute the debiased Sinkhorn barycenter of distributions A

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`S_{reg}(\cdot,\cdot)` is the debiased Sinkhorn divergence
      (see :py:func:`ot.bregman.empirical_sinkhorn_divergence`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix
      :math:`\mathbf{A}`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and
      the cost matrix for OT

    The algorithm used for solving the problem is the debiased Sinkhorn
    algorithm as proposed in :ref:`[37] <references-barycenter-debiased>`

    This function only selects the solver named by `method` (compared
    case-insensitively) and forwards all the other arguments to it.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of size `dim`
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    method : str (optional)
        method used for the solver either 'sinkhorn' or 'sinkhorn_log'
    weights : array-like, shape (n_hists,)
        Weights of each histogram :math:`\mathbf{a}_i` on the simplex (barycentric coordinates)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    **kwargs : dict
        Extra keyword arguments forwarded to the selected solver.

    Returns
    -------
    a : (dim,) array-like
        Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solvers.

    .. _references-barycenter-debiased:
    References
    ----------
    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
        Conference on Machine Learning, PMLR 119:4692-4701, 2020
    """
    method_key = method.lower()
    if method_key == 'sinkhorn':
        solver = _barycenter_debiased
    elif method_key == 'sinkhorn_log':
        solver = _barycenter_debiased_log
    else:
        raise ValueError("Unknown method '%s'." % method)

    return solver(A, M, reg, weights=weights, numItermax=numItermax,
                  stopThr=stopThr, verbose=verbose, log=log, warn=warn,
                  **kwargs)

def _barycenter_debiased(A, M, reg, weights=None, numItermax=1000,
                         stopThr=1e-4, verbose=False, log=False, warn=True):
    r"""Compute the debiased sinkhorn barycenter of distributions A.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (in columns)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,), optional
        Weights of each histogram (uniform if None)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative change of the barycenter (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : (dim,) array-like
        Debiased Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters
    """

    A, M = list_to_array(A, M)

    nx = get_backend(A, M)

    if weights is None:
        weights = nx.ones((A.shape[1],), type_as=A) / A.shape[1]
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    K = nx.exp(-M / reg)

    err = 1

    UKv = nx.dot(K, (A.T / nx.sum(K, axis=0)).T)

    u = (geometricMean(UKv) / UKv.T).T
    # c is the debiasing scaling, updated by a fixed point at each iteration
    c = nx.ones(A.shape[0], type_as=A)
    bar = nx.ones(A.shape[0], type_as=A)

    for ii in range(numItermax):
        bold = bar
        UKv = nx.dot(K, A / nx.dot(K, u))
        bar = c * geometricBar(weights, UKv)
        u = bar[:, None] / UKv
        c = (c * bar / nx.dot(K, c)) ** 0.5

        if ii % 10 == 9:
            # relative sup-norm change of the barycenter, computed with the
            # backend functions so that every backend is supported
            err = nx.max(nx.abs(bar - bold)) / max(nx.max(bar), 1.)

            # log and verbose print
            if log:
                log['err'].append(err)

            # debiased Sinkhorn does not converge monotonically
            # guarantee a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
            if verbose:
                # header offset matches the error-check offset (ii % 10 == 9)
                if ii % 200 == 9:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        return bar, log
    else:
        return bar

def _barycenter_debiased_log(A, M, reg, weights=None, numItermax=1000,
                             stopThr=1e-4, verbose=False, log=False,
                             warn=True):
    r"""Compute the debiased sinkhorn barycenter in log domain.

    Log-domain counterpart of the debiased Sinkhorn barycenter; only the
    numpy and torch backends are supported, because the iterations update
    arrays in place.

    Parameters
    ----------
    A : array-like, shape (dim, n_hists)
        `n_hists` training distributions of size `dim` (in columns)
    M : array-like, shape (dim, dim)
        loss matrix for OT
    reg : float
        Regularization term > 0
    weights : array-like, shape (n_hists,), optional
        Weights of each histogram (uniform if None)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : (dim,) array-like
        Debiased Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    NotImplementedError
        If the arrays belong to the jax or tf backend.
    """

    A, M = list_to_array(A, M)
    dim, n_hists = A.shape

    nx = get_backend(A, M)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )

    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    # M now holds the log of the Gibbs kernel
    M = - M / reg
    logA = nx.log(A + 1e-15)
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    # log of the debiasing scaling
    c = nx.zeros(dim, type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros(dim, type_as=A)
        for k in range(n_hists):
            f = logA[:, k] - nx.logsumexp(M + G[None, :, k], axis=1)
            log_KU[:, k] = nx.logsumexp(M + f[:, None], axis=0)
            log_bar += weights[k] * log_KU[:, k]
        log_bar += c
        if ii % 10 == 1:
            # backend std/sum, as in barycenter_sinkhorn
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=1))

            # log and verbose print
            if log:
                log['err'].append(err)

            if err < stopThr and ii > 20:
                break
            if verbose:
                # header offset matches the error-check offset (ii % 10 == 1)
                if ii % 200 == 1:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

        G = log_bar[:, None] - log_KU
        # a few damped fixed-point steps on the debiasing potential
        for _ in range(10):
            c = 0.5 * (c + log_bar - nx.logsumexp(M + c[:, None], axis=0))

    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)

def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
                               stopThr=1e-4, verbose=False, log=False,
                               warn=True, **kwargs):
    r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}`
    where :math:`\mathbf{A}` is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two dimensions
      of matrix :math:`\mathbf{A}`
    - `reg` is the regularization strength scalar value

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm
    as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`

    This function only selects the solver named by `method` (compared
    case-insensitively) and forwards all the other arguments to it.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        `n` distributions (2D images) of size `width` x `height`
    reg : float
        Regularization term >0
    weights : array-like, shape (n_hists,)
        Weights of each image on the simplex (barycentric coordinates)
    method : string, optional
        method used for the solver either 'sinkhorn' or 'sinkhorn_log'
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issue
        (passed through `**kwargs`)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    **kwargs : dict
        Extra keyword arguments forwarded to the selected solver.

    Returns
    -------
    a : array-like, shape (width, height)
        2D Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solvers.

    .. _references-convolutional-barycenter-2d:
    References
    ----------

    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
        A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
        Efficient optimal transportation on geometric domains. ACM Transactions
        on Graphics (TOG), 34(4), 66

    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
        International Conference on Machine Learning, PMLR 119:4692-4701, 2020
    """
    method_key = method.lower()
    if method_key == 'sinkhorn':
        solver = _convolutional_barycenter2d
    elif method_key == 'sinkhorn_log':
        solver = _convolutional_barycenter2d_log
    else:
        raise ValueError("Unknown method '%s'." % method)

    return solver(A, reg, weights=weights, numItermax=numItermax,
                  stopThr=stopThr, verbose=verbose, log=log, warn=warn,
                  **kwargs)

def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
                                stopThr=1e-9, stabThr=1e-30, verbose=False,
                                log=False, warn=True):
    r"""Compute the entropic regularized wasserstein barycenter of distributions A
    where A is a collection of 2D images.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        `n_hists` distributions (2D images) of size `width` x `height`
    reg : float
        Regularization term >0 (width of the Gaussian blur kernels)
    weights : array-like, shape (n_hists,), optional
        Weights of each image (uniform if None)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    stabThr : float, optional
        Offset added inside the log of the convolved scalings to avoid log(0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : array-like, shape (width, height)
        2D Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters; also holds
        the last scaling ``'U'``
    """

    A = list_to_array(A)

    nx = get_backend(A)

    if weights is None:
        weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
    else:
        assert (len(weights) == A.shape[0])

    if log:
        log = {'err': []}

    bar = nx.ones(A.shape[1:], type_as=A)
    bar /= nx.sum(bar)
    U = nx.ones(A.shape, type_as=A)
    V = nx.ones(A.shape, type_as=A)
    err = 1

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, A.shape[1])
    [Y, X] = nx.meshgrid(t, t)
    K1 = nx.exp(-(X - Y) ** 2 / reg)

    t = nx.linspace(0, 1, A.shape[2])
    [Y, X] = nx.meshgrid(t, t)
    K2 = nx.exp(-(X - Y) ** 2 / reg)

    def convol_imgs(imgs):
        """Blur every image of the stack along both spatial axes."""
        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
        kxy = nx.einsum("...ij,klj->kli", K2, kx)
        return kxy

    KU = convol_imgs(U)
    for ii in range(numItermax):
        V = bar[None] / KU
        KV = convol_imgs(V)
        U = A / KV
        KU = convol_imgs(U)
        # weighted geometric mean of the convolved scalings
        bar = nx.exp(
            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
        )
        if ii % 10 == 9:
            err = nx.sum(nx.std(V * KU, axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # header offset matches the error-check offset (ii % 10 == 9)
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            if err < stopThr:
                break

    else:
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        log['U'] = U
        return bar, log
    else:
        return bar

def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
                                    stopThr=1e-4, stabThr=1e-30, verbose=False,
                                    log=False, warn=True):
    r"""Compute the entropic regularized wasserstein barycenter of distributions A
    where A is a collection of 2D images in log-domain.

    Log-domain counterpart of the convolutional Sinkhorn solver: the separable
    Gaussian blur is applied with ``logsumexp`` to log-scalings instead of
    multiplying the scalings themselves (presumably for numerical stability at
    small ``reg``).

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Input images.
    reg : float
        Regularization term >0.
    weights : array-like, shape (n_hists,), optional
        Barycentric weights. Uniform weights are used if None.
    numItermax : int, optional
        Max number of iterations.
    stopThr : float, optional
        Stop threshold on the error, evaluated every 10 iterations.
    stabThr : float, optional
        Small constant added to ``A`` before taking its logarithm.
    verbose : bool, optional
        Print the error each time it is evaluated.
    log : bool, optional
        Also return a log dictionary if True.
    warn : bool, optional
        If True, warn when the algorithm does not converge.

    Returns
    -------
    bar : array-like, shape (width, height)
        2D Wasserstein barycenter.
    log : dict
        Only if ``log`` is True: ``'err'`` (evaluated errors) and ``'niter'``
        (last iteration index).

    Raises
    ------
    NotImplementedError
        For the JAX and TensorFlow backends.
    """

    A = list_to_array(A)

    nx = get_backend(A)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )

    n_hists, width, height = A.shape

    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == n_hists)

    if log:
        log = {'err': []}

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width)
    [Y, X] = nx.meshgrid(t, t)
    M1 = - (X - Y) ** 2 / reg

    t = nx.linspace(0, 1, height)
    [Y, X] = nx.meshgrid(t, t)
    M2 = - (X - Y) ** 2 / reg

    def convol_img(log_img):
        # log-domain blur of a single (width, height) image: first axis with
        # M1, then second axis with M2
        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
        return log_img

    logA = nx.log(A + stabThr)
    log_KU = nx.zeros(logA.shape, type_as=A)
    G = nx.zeros(logA.shape, type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros((width, height), type_as=A)
        for k in range(n_hists):
            f = logA[k] - convol_img(G[k])
            log_KU[k] = convol_img(f)
            log_bar = log_bar + weights[k] * log_KU[k]

        if ii % 10 == 9:
            # use the backend reductions (as the debiased log solver does)
            # rather than array methods whose defaults differ per backend
            # (torch's Tensor.std is unbiased by default, numpy's is not)
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # ii % 10 == 9 here, so `ii % 200 == 0` could never hold:
                # print the header on the first report of every 200 iterations
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            if err < stopThr:
                break
        G = log_bar[None, :, :] - log_KU

    else:
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)

def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
                                        numItermax=10000, stopThr=1e-3,
                                        verbose=False, log=False, warn=True,
                                        **kwargs):
    r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}`
    where :math:`\mathbf{A}` is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i w_i S_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.barycenter_debiased`)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
      dimensions of matrix :math:`\mathbf{A}`
    - :math:`w_i` are the barycentric weights (`weights`)
    - `reg` is the regularization strength scalar value

    The algorithm used for solving the problem is the debiased Sinkhorn scaling
    algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>`

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        n_hists distributions (2D images) of size width x height
    reg : float
        Regularization term >0
    weights : array-like, shape (n_hists,)
        Weights of each image on the simplex (barycentric coordinates)
    method : string, optional
        method used for the solver either 'sinkhorn' or 'sinkhorn_log'
        (case-insensitive)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issue
        (not an explicit argument: forwarded through ``**kwargs`` to the
        selected solver)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    a : array-like, shape (width, height)
        2D Wasserstein barycenter
    log : dict
        log dictionary returned only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is neither 'sinkhorn' nor 'sinkhorn_log'.

    .. _references-convolutional-barycenter2d-debiased:
    References
    ----------

    .. [37] Janati, H., Cuturi, M., Gramfort, A. Debiased Sinkhorn barycenters.
        Proceedings of the 37th International Conference on Machine Learning,
        PMLR 119:4692-4701, 2020
    """

    if method.lower() == 'sinkhorn':
        return _convolutional_barycenter2d_debiased(A, reg, weights=weights,
                                                    numItermax=numItermax,
                                                    stopThr=stopThr, verbose=verbose,
                                                    log=log, warn=warn,
                                                    **kwargs)
    elif method.lower() == 'sinkhorn_log':
        return _convolutional_barycenter2d_debiased_log(A, reg, weights=weights,
                                                        numItermax=numItermax,
                                                        stopThr=stopThr, verbose=verbose,
                                                        log=log, warn=warn,
                                                        **kwargs)
    else:
        raise ValueError("Unknown method '%s'." % method)

def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
                                         stopThr=1e-3, stabThr=1e-15, verbose=False,
                                         log=False, warn=True):
    r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.

    Same scaling iterations as the plain convolutional solver, with an extra
    debiasing image ``c`` that multiplies the barycenter and is refined by a
    fixed-point loop at every iteration.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Input images.
    reg : float
        Regularization term >0.
    weights : array-like, shape (n_hists,), optional
        Barycentric weights. Uniform weights are used if None.
    numItermax : int, optional
        Max number of iterations.
    stopThr : float, optional
        Stop threshold on the error, evaluated every 10 iterations.
    stabThr : float, optional
        Small constant added before taking the logarithm of the blurred
        scalings.
    verbose : bool, optional
        Print the error each time it is evaluated.
    log : bool, optional
        Also return a log dictionary if True.
    warn : bool, optional
        If True, warn when the algorithm does not converge.

    Returns
    -------
    bar : array-like, shape (width, height)
        Debiased 2D Wasserstein barycenter.
    log : dict
        Only if ``log`` is True: ``'err'``, ``'niter'`` and ``'U'``.
    """

    A = list_to_array(A)
    n_hists, width, height = A.shape

    nx = get_backend(A)

    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == n_hists)

    if log:
        log = {'err': []}

    bar = nx.ones((width, height), type_as=A)
    bar /= width * height
    U = nx.ones(A.shape, type_as=A)
    c = nx.ones(A.shape[1:], type_as=A)
    err = 1

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width)
    [Y, X] = nx.meshgrid(t, t)
    K1 = nx.exp(-(X - Y) ** 2 / reg)

    t = nx.linspace(0, 1, height)
    [Y, X] = nx.meshgrid(t, t)
    K2 = nx.exp(-(X - Y) ** 2 / reg)

    def convol_imgs(imgs):
        # blur every image of the stack along its first axis (K1), then
        # along its second axis (K2)
        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
        kxy = nx.einsum("...ij,klj->kli", K2, kx)
        return kxy

    KU = convol_imgs(U)
    for ii in range(numItermax):
        V = bar[None] / KU
        KV = convol_imgs(V)
        U = A / KV
        KU = convol_imgs(U)
        bar = c * nx.exp(
            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
        )

        # fixed-point refinement of the debiasing image c
        for _ in range(10):
            c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5

        if ii % 10 == 9:
            err = nx.sum(nx.std(V * KU, axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # ii % 10 == 9 here, so `ii % 200 == 0` could never hold:
                # print the header on the first report of every 200 iterations
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

            # debiased Sinkhorn does not converge monotonically
            # guarantee a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
    else:
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        log['U'] = U
        return bar, log
    else:
        return bar

def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
                                             stopThr=1e-3, stabThr=1e-30, verbose=False,
                                             log=False, warn=True):
    r"""Compute the debiased barycenter of 2D images in log-domain.

    Log-domain counterpart of the debiased convolutional solver: scalings and
    the debiasing image ``c`` are kept as logarithms and blurred with
    ``logsumexp``.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Input images.
    reg : float
        Regularization term >0.
    weights : array-like, shape (n_hists,), optional
        Barycentric weights. Uniform weights are used if None.
    numItermax : int, optional
        Max number of iterations.
    stopThr : float, optional
        Stop threshold on the error, evaluated every 10 iterations.
    stabThr : float, optional
        Small constant added to ``A`` before taking its logarithm.
    verbose : bool, optional
        Print the error each time it is evaluated.
    log : bool, optional
        Also return a log dictionary if True.
    warn : bool, optional
        If True, warn when the algorithm does not converge.

    Returns
    -------
    bar : array-like, shape (width, height)
        Debiased 2D Wasserstein barycenter.
    log : dict
        Only if ``log`` is True: ``'err'`` and ``'niter'``.

    Raises
    ------
    NotImplementedError
        For the JAX and TensorFlow backends.
    """

    A = list_to_array(A)
    n_hists, width, height = A.shape
    nx = get_backend(A)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )
    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == A.shape[0])

    if log:
        log = {'err': []}

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width)
    [Y, X] = nx.meshgrid(t, t)
    M1 = - (X - Y) ** 2 / reg

    t = nx.linspace(0, 1, height)
    [Y, X] = nx.meshgrid(t, t)
    M2 = - (X - Y) ** 2 / reg

    def convol_img(log_img):
        # log-domain blur of a single (width, height) image: first axis with
        # M1, then second axis with M2
        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
        return log_img

    logA = nx.log(A + stabThr)
    c = nx.zeros((width, height), type_as=A)  # log of the debiasing image
    log_KU = nx.zeros(logA.shape, type_as=A)
    G = nx.zeros(logA.shape, type_as=A)
    err = 1
    for ii in range(numItermax):
        log_bar = nx.zeros((width, height), type_as=A)
        for k in range(n_hists):
            f = logA[k] - convol_img(G[k])
            log_KU[k] = convol_img(f)
            log_bar = log_bar + weights[k] * log_KU[k]
        log_bar += c
        # fixed-point refinement of the (log) debiasing image
        for _ in range(10):
            c = 0.5 * (c + log_bar - convol_img(c))

        if ii % 10 == 9:
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # ii % 10 == 9 here, so `ii % 200 == 0` could never hold:
                # print the header on the first report of every 200 iterations
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            # debiased Sinkhorn does not converge monotonically:
            # guarantee a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
        G = log_bar[None, :, :] - log_KU

    else:
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)

def unmix(a, D, M, M0, h0, reg, reg0, alpha, numItermax=1000,
          stopThr=1e-3, verbose=False, log=False, warn=True):
    r"""
    Compute the unmixing of an observation with a given dictionary using Wasserstein distance

    The function solves the following optimization problem:

    .. math::

        (1 - \alpha)  W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
        \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})

    where :

    - :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
      with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
      its expected shape is `(dim_a, n_atoms)`
    - :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
    - :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
    - :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
      cost matrix (`dim_a`, `dim_a`) for OT data fitting
    - :math:`\mathrm{reg}_0` and :math:`\mathbf{M_0}` are respectively the regularization
      term and the cost matrix (`n_atoms`, `dim_prior`) of the prior term
    - :math:`\alpha` weights data fitting and regularization

    The optimization problem is solved following the algorithm described
    in :ref:`[4] <references-unmix>`

    Parameters
    ----------
    a : array-like, shape (dim_a)
        observed distribution (histogram, sums to 1)
    D : array-like, shape (dim_a, n_atoms)
        dictionary matrix
    M : array-like, shape (dim_a, dim_a)
        loss matrix
    M0 : array-like, shape (n_atoms, dim_prior)
        loss matrix
    h0 : array-like, shape (dim_prior,)
        prior on the estimated unmixing h. The first convergence check
        compares it with a vector of size n_atoms, so in practice dim_prior
        is expected to equal n_atoms.
    reg : float
        Regularization term >0 (Wasserstein data fitting)
    reg0 : float
        Regularization term >0 (Wasserstein reg with h0)
    alpha : float
        How much should we trust the prior ([0,1])
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    h : array-like, shape (n_atoms,)
        estimated unmixing
    log : dict
        log dictionary returned only if log==True in parameters

    .. _references-unmix:
    References
    ----------

    .. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
        Supervised planetary unmixing with optimal transport, Workshop
        on Hyperspectral Image and Signal Processing :
        Evolution in Remote Sensing (WHISPERS), 2016.
    """

    a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)

    nx = get_backend(a, D, M, M0, h0)

    # M = M/np.median(M)
    K = nx.exp(-M / reg)

    # M0 = M0/np.median(M0)
    K0 = nx.exp(-M0 / reg0)
    old = h0

    err = 1
    # log = {'niter':0, 'all_err':[]}
    if log:
        log = {'err': []}

    for ii in range(numItermax):
        K = projC(K, a)
        K0 = projC(K0, h0)
        new = nx.sum(K0, axis=1)
        # we recombine the current selection from the dictionary
        inv_new = nx.dot(D, new)
        other = nx.sum(K, axis=1)
        # geometric interpolation
        delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
        K = projR(K, delta)
        # rescale the rows of K0 (same as diag(D.T @ (delta / inv_new)) @ K0)
        K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
        err = nx.norm(nx.sum(K0, axis=1) - old)
        old = new
        if log:
            log['err'].append(err)

        if verbose:
            if ii % 200 == 0:
                print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(ii, err))
        if err < stopThr:
            break
    else:
        if warn:
            warnings.warn("Unmixing algorithm did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        return nx.sum(K0, axis=1), log
    else:
        return nx.sum(K0, axis=1)

def jcpot_barycenter(Xs, Ys, Xt, reg, metric='sqeuclidean', numItermax=100,
                     stopThr=1e-6, verbose=False, log=False, warn=True, **kwargs):
    r'''Joint OT and proportion estimation for multi-source target shift as
    proposed in :ref:`[27] <references-jcpot-barycenter>`

    The function solves the following optimization problem:

    .. math::

        \mathbf{h} = \mathop{\arg \min}_{\mathbf{h}} \quad \sum_{k=1}^{K} \lambda_k
                    W_{reg}((\mathbf{D}_2^{(k)} \mathbf{h})^T, \mathbf{a})

        s.t. \ \forall k, \mathbf{D}_1^{(k)} \gamma_k \mathbf{1}_n= \mathbf{h}

    where :

    - :math:`\lambda_k` is the weight of `k`-th source domain
    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
      (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{D}_2^{(k)}` is a matrix of weights related to `k`-th source domain
      defined as in [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape
      is :math:`(n_k, C)` where :math:`n_k` is the number of elements in the `k`-th source
      domain and `C` is the number of classes
    - :math:`\mathbf{h}` is a vector of estimated proportions in the target domain of size `C`
    - :math:`\mathbf{a}` is a uniform vector of weights in the target domain of size `n`
    - :math:`\mathbf{D}_1^{(k)}` is a matrix of class assignments defined as in
      [p. 5, :ref:`27 <references-jcpot-barycenter>`], its expected shape is :math:`(C, n_k)`

    The problem consists in solving a Wasserstein barycenter problem to estimate
    the proportions :math:`\mathbf{h}` in the target domain.

    The algorithm used for solving the problem is the Iterative Bregman projections algorithm
    with two sets of marginal constraints related to the unknown vector
    :math:`\mathbf{h}` and uniform target distribution.

    Parameters
    ----------
    Xs : list of K array-like(nsk,d)
        features of all source domains' samples
    Ys : list of K array-like(nsk,)
        labels of all source domains' samples (not modified by this function)
    Xt : array-like (nt,d)
        samples in the target domain
    reg : float
        Regularization term > 0
    metric : string, optional (default="sqeuclidean")
        The ground metric for the Wasserstein problem
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the (absolute) change of the barycenter between two
        iterations (>0)
    verbose : bool, optional (default=False)
        Controls the verbosity of the optimization algorithm
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.

    Returns
    -------
    h : (C,) array-like
        proportion estimation in the target domain
    log : dict
        log dictionary returned only if log==True in parameters

    .. _references-jcpot-barycenter:
    References
    ----------

    .. [27] Ievgen Redko, Nicolas Courty, Rémi Flamary, Devis Tuia
       "Optimal transport for multi-source domain adaptation under target shift",
       International Conference on Artificial Intelligence and Statistics (AISTATS), 2019.
    '''

    # convert each domain separately: unpacking a one-element list into
    # list_to_array returns the bare array instead of a list of arrays,
    # which broke the single-source-domain case
    Xs = [list_to_array(X) for X in Xs]
    Ys = [list_to_array(Y) for Y in Ys]
    Xt = list_to_array(Xt)

    nx = get_backend(*Xs, *Ys, Xt)

    nbclasses = len(nx.unique(Ys[0]))
    nbdomains = len(Xs)

    # log dictionary
    if log:
        log = {'niter': 0, 'err': [], 'M': [], 'D1': [], 'D2': [], 'gamma': []}

    K = []
    M = []
    D1 = []
    D2 = []

    # For each source domain, build cost matrices M, Gibbs kernels K and corresponding matrices D_1 and D_2
    for d in range(nbdomains):
        nsk = Xs[d].shape[0]  # get number of elements for this domain
        classes = nx.unique(Ys[d])  # get the classes present in this domain

        # format classes to start from 0 for convenience; build a new array
        # rather than subtracting in place so the caller's labels are kept
        if nx.min(classes) != 0:
            Ys[d] = Ys[d] - nx.min(classes)
            classes = nx.unique(Ys[d])

        # build the corresponding D_1 and D_2 matrices, both (nbclasses, nsk)
        Dtmp1 = np.zeros((nbclasses, nsk))
        Dtmp2 = np.zeros((nbclasses, nsk))

        for c in classes:
            nbelemperclass = float(nx.sum(Ys[d] == c))
            if nbelemperclass != 0:
                Dtmp1[int(c), nx.to_numpy(Ys[d] == c)] = 1.
                Dtmp2[int(c), nx.to_numpy(Ys[d] == c)] = 1. / (nbelemperclass)
        D1.append(nx.from_numpy(Dtmp1, type_as=Xs[0]))
        D2.append(nx.from_numpy(Dtmp2, type_as=Xs[0]))

        # build the cost matrix and the Gibbs kernel
        Mtmp = dist(Xs[d], Xt, metric=metric)
        M.append(Mtmp)

        Ktmp = nx.exp(-Mtmp / reg)
        K.append(Ktmp)

    # uniform target distribution
    a = nx.from_numpy(unif(Xt.shape[0]), type_as=Xs[0])

    err = 1
    old_bary = nx.ones((nbclasses,), type_as=Xs[0])

    for ii in range(numItermax):

        bary = nx.zeros((nbclasses,), type_as=Xs[0])

        # update coupling matrices for marginal constraints w.r.t. uniform target distribution
        for d in range(nbdomains):
            K[d] = projC(K[d], a)
            other = nx.sum(K[d], axis=1)
            bary += nx.log(nx.dot(D1[d], other)) / nbdomains

        # geometric mean over domains of the class proportions
        bary = nx.exp(bary)

        # update coupling matrices for marginal constraints w.r.t. unknown proportions based on [Prop 4., 27]
        for d in range(nbdomains):
            new = nx.dot(D2[d].T, bary)
            K[d] = projR(K[d], new)

        err = nx.norm(bary - old_bary)

        old_bary = bary

        if log:
            log['err'].append(err)

        # report before the stopping test so the last iteration is printed
        if verbose:
            if ii % 200 == 0:
                print('{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(ii, err))

        if err < stopThr:
            break
    else:
        if warn:
            warnings.warn("Algorithm did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    bary = bary / nx.sum(bary)

    if log:
        log['niter'] = ii
        log['M'] = M
        log['D1'] = D1
        log['D2'] = D2
        log['gamma'] = K
        return bary, log
    else:
        return bary

def empirical_sinkhorn(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                       numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100, verbose=False,
                       log=False, warn=True, warmstart=None, **kwargs):
    r'''
    Solve the entropic regularization optimal transport problem and return the
    OT matrix from empirical data

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    reg : float
        Regularization term >0
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain (uniform if None)
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain (uniform if None)
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    isLazy: boolean, optional
        If True, then only calculate the cost matrix by block and return
        the dual potentials only (to save memory). If False, calculate full
        cost matrix and return outputs of sinkhorn function.
    batchSize: int or tuple of 2 int, optional
        Size of the batches used to compute the sinkhorn update without memory overhead.
        When a tuple is provided it sets the size of the left/right batches.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, shapes (n_samples_a,) and (n_samples_b,), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors)

    Returns
    -------
    gamma : array-like, shape (n_samples_a, n_samples_b)
        Regularized optimal transportation matrix for the given parameters
        (when ``isLazy=False``)
    f, g : array-like, shapes (n_samples_a,) and (n_samples_b,)
        Dual potentials, returned instead of ``gamma`` when ``isLazy=True``
    log : dict
        log dictionary returned only if log==True in parameters

    Raises
    ------
    ValueError
        If ``isLazy`` is True and ``batchSize`` is neither an int nor a tuple
        of two ints.

    Examples
    --------

    >>> n_samples_a = 2
    >>> n_samples_b = 2
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
    >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
    >>> empirical_sinkhorn(X_s, X_t, reg=reg, verbose=False)  # doctest: +NORMALIZE_WHITESPACE
    array([[4.99977301e-01,  2.26989344e-05],
           [2.26989344e-05,  4.99977301e-01]])

    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
    '''

    X_s, X_t = list_to_array(X_s, X_t)

    nx = get_backend(X_s, X_t)

    ns, nt = X_s.shape[0], X_t.shape[0]
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)

    if isLazy:
        if log:
            dict_log = {"err": []}

        log_a, log_b = nx.log(a), nx.log(b)
        if warmstart is None:
            f, g = nx.zeros((ns,), type_as=a), nx.zeros((nt,), type_as=a)
        else:
            f, g = warmstart

        if isinstance(batchSize, int):
            bs, bt = batchSize, batchSize
        elif isinstance(batchSize, tuple) and len(batchSize) == 2:
            bs, bt = batchSize[0], batchSize[1]
        else:
            raise ValueError(
                "Batch size must be in integer or a tuple of two integers")

        range_s, range_t = range(0, ns, bs), range(0, nt, bt)

        # cost blocks are recomputed (with numpy `dist`) at every pass so the
        # full (ns, nt) cost matrix is never stored
        X_s_np = nx.to_numpy(X_s)
        X_t_np = nx.to_numpy(X_t)

        for i_ot in range(numIterMax):

            # log-domain update of f, one block of source rows at a time
            lse_f_cols = []
            for i in range_s:
                M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
                M = nx.from_numpy(M, type_as=a)
                lse_f_cols.append(
                    nx.logsumexp(g[None, :] - M / reg, axis=1)
                )
            lse_f = nx.concatenate(lse_f_cols, axis=0)
            f = log_a - lse_f

            # log-domain update of g, one block of target columns at a time
            lse_g_cols = []
            for j in range_t:
                M = dist(X_s_np, X_t_np[j:j + bt, :], metric=metric)
                M = nx.from_numpy(M, type_as=a)
                lse_g_cols.append(
                    nx.logsumexp(f[:, None] - M / reg, axis=0)
                )
            lse_g = nx.concatenate(lse_g_cols, axis=0)
            g = log_b - lse_g

            # every 10 iterations: L1 error on the source marginal
            if (i_ot + 1) % 10 == 0:
                m1_cols = []
                for i in range_s:
                    M = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
                    M = nx.from_numpy(M, type_as=a)
                    m1_cols.append(
                        nx.sum(nx.exp(f[i:i + bs, None] +
                                      g[None, :] - M / reg), axis=1)
                    )
                m1 = nx.concatenate(m1_cols, axis=0)
                err = nx.sum(nx.abs(m1 - a))
                if log:
                    dict_log["err"].append(err)

                if verbose and (i_ot + 1) % 100 == 0:
                    print("Error in marginal at iteration {} = {}".format(
                        i_ot + 1, err))

                if err <= stopThr:
                    break
        else:
            if warn:
                warnings.warn("Sinkhorn did not converge. You might want to "
                              "increase the number of iterations numItermax "
                              "or the regularization parameter reg.")
        if log:
            dict_log["u"] = f
            dict_log["v"] = g
            return (f, g, dict_log)
        else:
            return (f, g)

    else:
        M = dist(X_s, X_t, metric=metric)
        # forward `warn` so that warn=False also silences the dense solver
        if log:
            pi, log = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
                               verbose=verbose, log=True, warn=warn,
                               warmstart=warmstart, **kwargs)
            return pi, log
        else:
            pi = sinkhorn(a, b, M, reg, numItermax=numIterMax, stopThr=stopThr,
                          verbose=verbose, log=False, warn=warn,
                          warmstart=warmstart, **kwargs)
            return pi

def empirical_sinkhorn2(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                        numIterMax=10000, stopThr=1e-9, isLazy=False, batchSize=100,
                        verbose=False, log=False, warn=True, warmstart=None, **kwargs):
    r'''
    Solve the entropic regularization optimal transport problem from empirical
    data and return the OT loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`n_samples_a`, `n_samples_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
    the entropic contribution).

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    reg : float
        Regularization term >0
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain (uniform if None)
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain (uniform if None)
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    isLazy: boolean, optional
        If True, then only calculate the cost matrix by block and return
        the dual potentials only (to save memory). If False, calculate
        full cost matrix and return outputs of sinkhorn function.
        In lazy mode a single target histogram ``b`` is expected.
    batchSize: int or tuple of 2 int, optional
        Size of the batches used to compute the sinkhorn update without memory overhead.
        When a tuple is provided it sets the size of the left/right batches.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, shapes (n_samples_a,) and (n_samples_b,), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors)

    Returns
    -------
    W : (n_hists) array-like or float
        Optimal transportation loss for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters

    Examples
    --------

    >>> n_samples_a = 2
    >>> n_samples_b = 2
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
    >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
    >>> b = np.full((n_samples_b, 3), 1/n_samples_b)
    >>> empirical_sinkhorn2(X_s, X_t, b=b, reg=reg, verbose=False)
    array([4.53978687e-05, 4.53978687e-05, 4.53978687e-05])

    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling
        Algorithms for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    '''

    X_s, X_t = list_to_array(X_s, X_t)

    nx = get_backend(X_s, X_t)

    ns, nt = X_s.shape[0], X_t.shape[0]
    if a is None:
        a = nx.from_numpy(unif(ns), type_as=X_s)
    if b is None:
        b = nx.from_numpy(unif(nt), type_as=X_s)

    if isLazy:
        # lazy solver: only the dual potentials are available
        if log:
            f, g, dict_log = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
                                                numIterMax=numIterMax,
                                                stopThr=stopThr,
                                                isLazy=isLazy,
                                                batchSize=batchSize,
                                                verbose=verbose, log=log,
                                                warn=warn,
                                                warmstart=warmstart)
        else:
            f, g = empirical_sinkhorn(X_s, X_t, reg, a, b, metric,
                                      numIterMax=numIterMax,
                                      stopThr=stopThr,
                                      isLazy=isLazy, batchSize=batchSize,
                                      verbose=verbose, log=log,
                                      warn=warn,
                                      warmstart=warmstart)

        bs = batchSize if isinstance(batchSize, int) else batchSize[0]
        range_s = range(0, ns, bs)

        loss = 0

        X_s_np = nx.to_numpy(X_s)
        X_t_np = nx.to_numpy(X_t)

        # accumulate <gamma, M> block by block, rebuilding each block of the
        # plan from the potentials instead of storing the full matrix
        for i in range_s:
            M_block = dist(X_s_np[i:i + bs, :], X_t_np, metric=metric)
            M_block = nx.from_numpy(M_block, type_as=a)
            pi_block = nx.exp(f[i:i + bs, None] + g[None, :] - M_block / reg)
            loss += nx.sum(M_block * pi_block)

        if log:
            return loss, dict_log
        else:
            return loss

    else:
        M = dist(X_s, X_t, metric=metric)

        if log:
            sinkhorn_loss, log = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
                                           stopThr=stopThr, verbose=verbose, log=log,
                                           warn=warn, warmstart=warmstart, **kwargs)
            return sinkhorn_loss, log
        else:
            sinkhorn_loss = sinkhorn2(a, b, M, reg, numItermax=numIterMax,
                                      stopThr=stopThr, verbose=verbose, log=log,
                                      warn=warn, warmstart=warmstart, **kwargs)
            return sinkhorn_loss

def empirical_sinkhorn_divergence(X_s, X_t, reg, a=None, b=None, metric='sqeuclidean',
                                  numIterMax=10000, stopThr=1e-9, verbose=False,
                                  log=False, warn=True, warmstart=None, **kwargs):
    r'''
    Compute the sinkhorn divergence loss from empirical data

    The function solves the following optimization problems and return the
    sinkhorn divergence :math:`S`:

    .. math::
        W &= \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma)

        W_a &= \min_{\gamma_a} \quad \langle \gamma_a, \mathbf{M_a} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma_a)

        W_b &= \min_{\gamma_b} \quad \langle \gamma_b, \mathbf{M_b} \rangle_F +
        \mathrm{reg} \cdot\Omega(\gamma_b)

        S &= W - \frac{W_a + W_b}{2}

    .. math::
        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

             \gamma_a \mathbf{1} &= \mathbf{a}

             \gamma_a^T \mathbf{1} &= \mathbf{a}

             \gamma_a &\geq 0

             \gamma_b \mathbf{1} &= \mathbf{b}

             \gamma_b^T \mathbf{1} &= \mathbf{b}

             \gamma_b &\geq 0

    where :

    - :math:`\mathbf{M}` (resp. :math:`\mathbf{M_a}`, :math:`\mathbf{M_b}`)
      is the (`n_samples_a`, `n_samples_b`) metric cost matrix
      (resp (`n_samples_a, n_samples_a`) and (`n_samples_b`, `n_samples_b`))
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F -(\langle \gamma^*_a, \mathbf{M_a} \rangle_F + \langle
    \gamma^*_b , \mathbf{M_b} \rangle_F)/2`, clamped below at 0.

    .. note:: The current implementation does not account for the entropic contributions and thus differs from the
     Sinkhorn divergence as introduced in the literature. The possibility to account for the entropic contributions
     will be provided in a future release.

    Parameters
    ----------
    X_s : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_t : array-like, shape (n_samples_b, dim)
        samples in the target domain
    reg : float
        Regularization term >0
    a : array-like, shape (n_samples_a,)
        samples weights in the source domain
    b : array-like, shape (n_samples_b,)
        samples weights in the target domain
    numIterMax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, shapes (n_samples_a,) and (n_samples_b,), optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors). The source
        (resp. target) potential is also reused on both sides of the
        source/source (resp. target/target) problem.

    Returns
    -------
    W : (1,) array-like
        Optimal transportation symmetrized loss for the given parameters
    log : dict
        log dictionary returned only if log==True in parameters, holding the
        three losses and the three sub-logs

    Examples
    --------
    >>> n_samples_a = 2
    >>> n_samples_b = 4
    >>> reg = 0.1
    >>> X_s = np.reshape(np.arange(n_samples_a, dtype=np.float64), (n_samples_a, 1))
    >>> X_t = np.reshape(np.arange(0, n_samples_b, dtype=np.float64), (n_samples_b, 1))
    >>> empirical_sinkhorn_divergence(X_s, X_t, reg)  # doctest: +ELLIPSIS
    1.499887176049052

    References
    ----------
    .. [23] Aude Genevay, Gabriel Peyré, Marco Cuturi, Learning Generative
        Models with Sinkhorn Divergences,  Proceedings of the Twenty-First
        International Conference on Artificial Intelligence and Statistics,
        (AISTATS) 21, 2018
    '''
    X_s, X_t = list_to_array(X_s, X_t)

    nx = get_backend(X_s, X_t)
    if warmstart is None:
        warmstart_a, warmstart_b = None, None
    else:
        u, v = warmstart
        warmstart_a = (u, u)
        warmstart_b = (v, v)

    def _ot_loss(X1, X2, w1, w2, ws):
        # one empirical_sinkhorn2 call sharing every solver option; returns
        # (loss, log) when `log` is set and the bare loss otherwise
        return empirical_sinkhorn2(X1, X2, reg, w1, w2, metric=metric,
                                   numIterMax=numIterMax, stopThr=stopThr,
                                   verbose=verbose, log=log, warn=warn,
                                   warmstart=ws, **kwargs)

    res_ab = _ot_loss(X_s, X_t, a, b, warmstart)
    res_a = _ot_loss(X_s, X_s, a, a, warmstart_a)
    res_b = _ot_loss(X_t, X_t, b, b, warmstart_b)

    if log:
        sinkhorn_loss_ab, log_ab = res_ab
        sinkhorn_loss_a, log_a = res_a
        sinkhorn_loss_b, log_b = res_b
    else:
        sinkhorn_loss_ab, sinkhorn_loss_a, sinkhorn_loss_b = res_ab, res_a, res_b

    sinkhorn_div = sinkhorn_loss_ab - 0.5 * \
        (sinkhorn_loss_a + sinkhorn_loss_b)

    if log:
        log = {}
        log['sinkhorn_loss_ab'] = sinkhorn_loss_ab
        log['sinkhorn_loss_a'] = sinkhorn_loss_a
        log['sinkhorn_loss_b'] = sinkhorn_loss_b
        log['log_sinkhorn_ab'] = log_ab
        log['log_sinkhorn_a'] = log_a
        log['log_sinkhorn_b'] = log_b

        return nx.maximum(0, sinkhorn_div), log

    # the divergence is clamped below at 0
    return nx.maximum(0, sinkhorn_div)

def screenkhorn(a, b, M, reg, ns_budget=None, nt_budget=None, uniform=False,
                restricted=True, maxiter=10000, maxfun=10000, pgtol=1e-09,
                verbose=False, log=False):
    r"""
    Screening Sinkhorn Algorithm for Regularized Optimal Transport

    The function solves an approximate dual of Sinkhorn divergence :ref:`[2]
    <references-screenkhorn>` which is written as the following optimization problem:

    .. math::

        (\mathbf{u}, \mathbf{v}) = \mathop{\arg \min}_{\mathbf{u}, \mathbf{v}} \quad
        \mathbf{1}_{ns}^T \mathbf{B}(\mathbf{u}, \mathbf{v}) \mathbf{1}_{nt} -
        \langle \kappa \mathbf{u}, \mathbf{a} \rangle -
        \langle \frac{1}{\kappa} \mathbf{v}, \mathbf{b} \rangle

    where:

    .. math::

        \mathbf{B}(\mathbf{u}, \mathbf{v}) = \mathrm{diag}(e^\mathbf{u}) \mathbf{K} \mathrm{diag}(e^\mathbf{v}) \text{, with } \mathbf{K} = e^{-\mathbf{M} / \mathrm{reg}} \text{ and}

    .. math::

        s.t. \ e^{u_i} &\geq \epsilon / \kappa, \forall i \in \{1, \ldots, ns\}

        e^{v_j} &\geq \epsilon \kappa, \forall j \in \{1, \ldots, nt\}

    The parameters `kappa` and `epsilon` are determined w.r.t the couple number
    budget of points (`ns_budget`, `nt_budget`), see Equation (5)
    in :ref:`[26] <references-screenkhorn>`

    Parameters
    ----------
    a: array-like, shape=(ns,)
        samples weights in the source domain
    b: array-like, shape=(nt,)
        samples weights in the target domain
    M: array-like, shape=(ns, nt)
        Cost matrix
    reg: float
        Level of the entropy regularisation
    ns_budget: int, default=None
        Number budget of points to be kept in the source domain.
        If it is None then 50% of the source sample points (at least one)
        will be kept. When points are screened, the budget is replaced
        internally by the actual size of the selected active set.
    nt_budget: int, default=None
        Number budget of points to be kept in the target domain.
        If it is None then 50% of the target sample points (at least one)
        will be kept. When points are screened, the budget is replaced
        internally by the actual size of the selected active set.
    uniform: bool, default=False
        If True, the source and target distribution are supposed to be uniform,
        i.e., :math:`a_i = 1 / ns` and :math:`b_j = 1 / nt`
    restricted : bool, default=True
        If True, the starting point of the L-BFGS-B solver is additionally
        refined by 5 iterations of a restricted Sinkhorn algorithm (5 such
        iterations are always applied right before calling L-BFGS-B).
    maxiter: int, default=10000
        Maximum number of iterations in LBFGS solver
    maxfun: int, default=10000
        Maximum number of function evaluations in LBFGS solver
    pgtol: float, default=1e-09
        Stopping tolerance of the LBFGS solver: iterations stop when the
        largest component of the projected gradient is below `pgtol`
    verbose: bool, default=False
        If True, display information about the cardinals of the active sets
        and the parameters kappa and epsilon (only when points are screened)
    log: bool, default=False
        If True, also return a log dictionary

    To gain more efficiency, :py:func:`ot.bregman.screenkhorn` uses the "Bottleneck"
    package (https://pypi.org/project/Bottleneck/) in the screening pre-processing step.

    If Bottleneck isn't installed, a warning is issued and numpy is used instead:

    "Bottleneck module is not installed. Install it from https://pypi.org/project/Bottleneck/ for better performance."

    Returns
    -------
    gamma : array-like, shape=(ns, nt)
        Screened optimal transportation matrix for the given parameters,
        normalized so that its entries sum to 1
    log : dict
        Log dictionary returned only if log==True in parameters, with keys
        'u' and 'v' (full-size scaling vectors) and 'Isel' and 'Jsel'
        (boolean masks of the selected active sets)

    .. _references-screenkhorn:
    References
    ----------

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal Transport,
        Advances in Neural Information Processing Systems (NIPS) 26, 2013

    .. [26] Alaya M. Z., Bérar M., Gasso G., Rakotomamonjy A. (2019).
        Screening Sinkhorn Algorithm for Regularized Optimal Transport,
        Advances in Neural Information Processing Systems 32 (NeurIPS), 2019

    """
    # check if bottleneck module exists
    try:
        import bottleneck
    except ImportError:
        warnings.warn(
            "Bottleneck module is not installed. Install it from"
            " https://pypi.org/project/Bottleneck/ for better performance.")
        # numpy exposes the same ``partition(array, kth)`` function, only slower
        bottleneck = np

    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)
    if nx.__name__ in ("jax", "tf"):
        raise TypeError("JAX or TF arrays have been received but screenkhorn is not "
                        "compatible with neither JAX nor TF.")

    ns, nt = M.shape

    # by default, we keep only 50% of the sample data points; at least one point
    # must be kept, otherwise the budget-based indexing below is meaningless
    if ns_budget is None:
        ns_budget = max(1, int(np.floor(0.5 * ns)))
    if nt_budget is None:
        nt_budget = max(1, int(np.floor(0.5 * nt)))

    # calculate the Gibbs kernel
    K = nx.exp(-M / reg)

    def projection(u, epsilon):
        # clip from below so that the lower box constraint u >= epsilon holds
        u = nx.maximum(u, epsilon)
        return u

    # ----------------------------------------------------------------------------------------------------------------#
    #                                          Step 1: Screening pre-processing                                       #
    # ----------------------------------------------------------------------------------------------------------------#

    if ns_budget == ns and nt_budget == nt:
        # full number of budget points (ns, nt) = (ns_budget, nt_budget)
        Isel = nx.from_numpy(np.ones(ns, dtype=bool))
        Jsel = nx.from_numpy(np.ones(nt, dtype=bool))
        epsilon = 0.0
        kappa = 1.0

        cst_u = 0.
        cst_v = 0.

        bounds_u = [(0.0, np.inf)] * ns
        bounds_v = [(0.0, np.inf)] * nt

        a_I = a
        b_J = b
        K_IJ = K
        K_IJc = []
        K_IcJ = []

        # nothing is screened out, so the constant terms vanish. vec_eps_IJc
        # is paired with u (length ns) and vec_eps_IcJ with v (length nt).
        vec_eps_IJc = nx.zeros((ns,), type_as=M)
        vec_eps_IcJ = nx.zeros((nt,), type_as=M)

    else:
        # sum of rows and columns of K
        K_sum_cols = nx.sum(K, axis=1)
        K_sum_rows = nx.sum(K, axis=0)

        if uniform:
            if ns / ns_budget < 4:
                aK_sort = nx.sort(K_sum_cols)
                epsilon_u_square = a[0] / aK_sort[ns_budget - 1]
            else:
                # partial sort is enough to get the (ns_budget)-th value
                aK_sort = nx.from_numpy(
                    bottleneck.partition(nx.to_numpy(
                        K_sum_cols), ns_budget - 1)[ns_budget - 1],
                    type_as=M
                )
                epsilon_u_square = a[0] / aK_sort

            if nt / nt_budget < 4:
                bK_sort = nx.sort(K_sum_rows)
                epsilon_v_square = b[0] / bK_sort[nt_budget - 1]
            else:
                bK_sort = nx.from_numpy(
                    bottleneck.partition(nx.to_numpy(
                        K_sum_rows), nt_budget - 1)[nt_budget - 1],
                    type_as=M
                )
                epsilon_v_square = b[0] / bK_sort
        else:
            aK = a / K_sum_cols
            bK = b / K_sum_rows

            aK_sort = nx.flip(nx.sort(aK), axis=0)
            epsilon_u_square = aK_sort[ns_budget - 1]

            bK_sort = nx.flip(nx.sort(bK), axis=0)
            epsilon_v_square = bK_sort[nt_budget - 1]

        # active sets I and J (see Lemma 1 in [26])
        Isel = a >= epsilon_u_square * K_sum_cols
        Jsel = b >= epsilon_v_square * K_sum_rows

        # ties may select more/fewer points than requested: move the threshold
        # between two consecutive sorted values and use the actual set size
        if nx.sum(Isel) != ns_budget:
            if uniform:
                aK = a / K_sum_cols
                aK_sort = nx.flip(nx.sort(aK), axis=0)
            epsilon_u_square = nx.mean(aK_sort[ns_budget - 1:ns_budget + 1])
            Isel = a >= epsilon_u_square * K_sum_cols
            ns_budget = int(nx.sum(Isel))

        if nx.sum(Jsel) != nt_budget:
            if uniform:
                bK = b / K_sum_rows
                bK_sort = nx.flip(nx.sort(bK), axis=0)
            epsilon_v_square = nx.mean(bK_sort[nt_budget - 1:nt_budget + 1])
            Jsel = b >= epsilon_v_square * K_sum_rows
            nt_budget = int(nx.sum(Jsel))

        epsilon = (epsilon_u_square * epsilon_v_square) ** (1 / 4)
        kappa = (epsilon_v_square / epsilon_u_square) ** (1 / 2)

        if verbose:
            print("epsilon = %s\n" % epsilon)
            print("kappa = %s\n" % kappa)
            print('Cardinality of selected points: |Isel| = %s \t |Jsel| = %s \n'
                  % (nx.sum(Isel), nx.sum(Jsel)))

        # Ic, Jc: complementary of the active sets I and J
        Ic = ~Isel
        Jc = ~Jsel

        K_IJ = K[np.ix_(Isel, Jsel)]
        K_IcJ = K[np.ix_(Ic, Jsel)]
        K_IJc = K[np.ix_(Isel, Jc)]

        # guard against a zero kernel entry, which would make the bounds infinite
        K_min = nx.min(K_IJ)
        if K_min == 0:
            K_min = float(np.finfo(float).tiny)

        # a_I, b_J, a_Ic, b_Jc
        a_I = a[Isel]
        b_J = b[Jsel]
        if not uniform:
            a_I_min = nx.min(a_I)
            a_I_max = nx.max(a_I)
            b_J_max = nx.max(b_J)
            b_J_min = nx.min(b_J)
        else:
            a_I_min = a_I[0]
            a_I_max = a_I[0]
            b_J_max = b_J[0]
            b_J_min = b_J[0]

        # box constraints in L-BFGS-B (see Proposition 1 in [26])
        bounds_u = [(max(a_I_min / ((nt - nt_budget) * epsilon + nt_budget * (b_J_max / (
            ns * epsilon * kappa * K_min))), epsilon / kappa), a_I_max / (nt * epsilon * K_min))] * ns_budget

        bounds_v = [(
            max(b_J_min / ((ns - ns_budget) * epsilon + ns_budget * (kappa * a_I_max / (nt * epsilon * K_min))),
                epsilon * kappa), b_J_max / (ns * epsilon * K_min))] * nt_budget

        # pre-calculated constants for the objective
        vec_eps_IJc = epsilon * kappa * nx.sum(
            K_IJc * nx.ones((nt - nt_budget,), type_as=M)[None, :],
            axis=1
        )
        vec_eps_IcJ = (epsilon / kappa) * nx.sum(
            nx.ones((ns - ns_budget,), type_as=M)[:, None] * K_IcJ,
            axis=0
        )

        # pre-calculated constants for Restricted Sinkhorn (see Algorithm 1 in
        # supplementary of [26]). They are always needed, because
        # ``restricted_sinkhorn`` is called before L-BFGS-B whatever the value
        # of ``restricted``, and even if the adjusted budgets equal (ns, nt).
        cst_u = kappa * epsilon * nx.sum(K_IJc, axis=1)
        cst_v = epsilon * nx.sum(K_IcJ, axis=0) / kappa

    def restricted_sinkhorn(usc, vsc, max_iter=5):
        """
        Restricted Sinkhorn Algorithm used as a warm-start point for L-BFGS-B.

        Alternates ``max_iter`` scaling updates of v then u on the active sets,
        then projects both vectors onto their lower box constraints.
        """
        for _ in range(max_iter):
            K_IJ_v = nx.dot(K_IJ.T, usc) + cst_v
            vsc = b_J / (kappa * K_IJ_v)
            KIJ_u = nx.dot(K_IJ, vsc) + cst_u
            usc = (kappa * a_I) / KIJ_u

        usc = projection(usc, epsilon / kappa)
        vsc = projection(vsc, epsilon * kappa)

        return usc, vsc

    # initialisation
    u0 = nx.full((ns_budget,), 1. / ns_budget + epsilon / kappa, type_as=M)
    v0 = nx.full((nt_budget,), 1. / nt_budget + epsilon * kappa, type_as=M)

    # optional warm start: 5 restricted Sinkhorn iterations
    if restricted:
        u0, v0 = restricted_sinkhorn(u0, v0)

    def screened_obj(usc, vsc):
        # screened dual objective Psi_(kappa,epsilon)(u, v)
        part_IJ = (
            nx.dot(nx.dot(usc, K_IJ), vsc)
            - kappa * nx.dot(a_I, nx.log(usc))
            - (1. / kappa) * nx.dot(b_J, nx.log(vsc))
        )
        part_IJc = nx.dot(usc, vec_eps_IJc)
        part_IcJ = nx.dot(vec_eps_IcJ, vsc)
        psi_epsilon = part_IJ + part_IJc + part_IcJ
        return psi_epsilon

    def screened_grad(usc, vsc):
        # gradients of Psi_(kappa,epsilon) w.r.t u and v
        grad_u = nx.dot(K_IJ, vsc) + vec_eps_IJc - kappa * a_I / usc
        grad_v = nx.dot(K_IJ.T, usc) + vec_eps_IcJ - (1. / kappa) * b_J / vsc
        return grad_u, grad_v

    def bfgspost(theta):
        # split the stacked L-BFGS-B variable into (u, v)
        u = theta[:ns_budget]
        v = theta[ns_budget:]
        # objective
        f = screened_obj(u, v)
        # gradient
        g_u, g_v = screened_grad(u, v)
        g = nx.concatenate([g_u, g_v], axis=0)
        return nx.to_numpy(f), nx.to_numpy(g)

    # ----------------------------------------------------------------------------------------------------------------#
    #                                           Step 2: L-BFGS-B solver                                              #
    # ----------------------------------------------------------------------------------------------------------------#

    u0, v0 = restricted_sinkhorn(u0, v0)
    # scipy works on numpy arrays only
    theta0 = nx.to_numpy(nx.concatenate([u0, v0], axis=0))

    bounds = bounds_u + bounds_v  # constraint bounds

    def obj(theta):
        return bfgspost(nx.from_numpy(theta, type_as=M))

    theta, _, _ = fmin_l_bfgs_b(func=obj,
                                x0=theta0,
                                bounds=bounds,
                                maxfun=maxfun,
                                pgtol=pgtol,
                                maxiter=maxiter)
    theta = nx.from_numpy(theta, type_as=M)

    usc = theta[:ns_budget]
    vsc = theta[ns_budget:]

    # screened-out coordinates are fixed at their lower bounds
    usc_full = nx.full((ns,), epsilon / kappa, type_as=M)
    vsc_full = nx.full((nt,), epsilon * kappa, type_as=M)
    usc_full[Isel] = usc
    vsc_full[Jsel] = vsc

    gamma = usc_full[:, None] * K * vsc_full[None, :]
    gamma = gamma / nx.sum(gamma)

    if log:
        log = {'u': usc_full, 'v': vsc_full, 'Isel': Isel, 'Jsel': Jsel}
        return gamma, log
    return gamma