# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#         Titouan Vayer <titouan.vayer@irisa.fr>
#         Alexander Tong <alexander.tong@yale.edu>
#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
import warnings
import numpy as np
from ..utils import list_to_array
from ..backend import get_backend
[docs]
def sinkhorn(
    a,
    b,
    M,
    reg,
    method="sinkhorn",
    numItermax=1000,
    stopThr=1e-9,
    verbose=False,
    log=False,
    warn=True,
    warmstart=None,
    **kwargs,
):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn>`

    **Choosing a Sinkhorn solver**

    By default and when using a regularization parameter that is not too small
    the default sinkhorn solver should be enough. If you need to use a small
    regularization to get sharper OT matrices, you should use the
    :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
    errors. This last solver can be very slow in practice and might not even
    converge to a reasonable OT matrix in a finite time. This is why
    :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
    of the regularization (and using warm start) sometimes leads to better
    solutions. Note that the greedy version of the sinkhorn
    :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
    version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
    fast approximation of the Sinkhorn problem. For use of GPU and gradient
    computation with small number of iterations we strongly recommend the
    :py:func:`ot.bregman.sinkhorn_log` solver that does not need to check for
    numerical problems.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        method used for the solver either 'sinkhorn','sinkhorn_log',
        'greenkhorn', 'sinkhorn_stabilized' or 'sinkhorn_epsilon_scaling', see
        those function for specific parameters. The name is matched
        case-insensitively; any other value raises a ``ValueError``.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials (that is the logarithm of the u,v
        sinkhorn scaling vectors), given as a pair of arrays of shape
        (dim_a,) and (dim_b,) for a single target histogram. It is forwarded
        unchanged to the selected solver.
    **kwargs
        Additional solver-specific keyword arguments. They are forwarded to
        every solver except 'greenkhorn', which does not accept them.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
    ot.bregman.sinkhorn_epsilon_scaling: Sinkhorn with epsilon scaling
        :ref:`[9] <references-sinkhorn>` :ref:`[10] <references-sinkhorn>`
    """
    # Normalise once so the dispatch is case-insensitive; the error message
    # below deliberately reports the name exactly as the caller passed it.
    solver_name = method.lower()

    if solver_name == "sinkhorn":
        return sinkhorn_knopp(
            a,
            b,
            M,
            reg,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            warn=warn,
            warmstart=warmstart,
            **kwargs,
        )
    if solver_name == "sinkhorn_log":
        return sinkhorn_log(
            a,
            b,
            M,
            reg,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            warn=warn,
            warmstart=warmstart,
            **kwargs,
        )
    if solver_name == "greenkhorn":
        # greenkhorn has no **kwargs in its signature, so extra solver
        # options are intentionally not forwarded here.
        return greenkhorn(
            a,
            b,
            M,
            reg,
            numItermax=numItermax,
            stopThr=stopThr,
            verbose=verbose,
            log=log,
            warn=warn,
            warmstart=warmstart,
        )
    if solver_name == "sinkhorn_stabilized":
        return sinkhorn_stabilized(
            a,
            b,
            M,
            reg,
            numItermax=numItermax,
            stopThr=stopThr,
            warmstart=warmstart,
            verbose=verbose,
            log=log,
            warn=warn,
            **kwargs,
        )
    if solver_name == "sinkhorn_epsilon_scaling":
        return sinkhorn_epsilon_scaling(
            a,
            b,
            M,
            reg,
            numItermax=numItermax,
            stopThr=stopThr,
            warmstart=warmstart,
            verbose=verbose,
            log=log,
            warn=warn,
            **kwargs,
        )
    raise ValueError("Unknown method '%s'." % method)
[docs]
def sinkhorn2(
    a,
    b,
    M,
    reg,
    method="sinkhorn",
    numItermax=1000,
    stopThr=1e-9,
    verbose=False,
    log=False,
    warn=False,
    warmstart=None,
    **kwargs,
):
    r"""
    Solve the entropic regularization optimal transport problem and return the loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    and returns :math:`\langle \gamma^*, \mathbf{M} \rangle_F` (without
    the entropic contribution).

    .. note:: This function is backend-compatible and will work on arrays
        from all compatible backends.

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn2>`

    **Choosing a Sinkhorn solver**

    By default and when using a regularization parameter that is not too small
    the default sinkhorn solver should be enough. If you need to use a small
    regularization to get sharper OT matrices, you should use the
    :py:func:`ot.bregman.sinkhorn_stabilized` solver that will avoid numerical
    errors. This last solver can be very slow in practice and might not even
    converge to a reasonable OT matrix in a finite time. This is why
    :py:func:`ot.bregman.sinkhorn_epsilon_scaling` that relies on iterating the value
    of the regularization (and using warm start) sometimes leads to better
    solutions. Note that the greedy version of the sinkhorn
    :py:func:`ot.bregman.greenkhorn` can also lead to a speedup and the screening
    version of the sinkhorn :py:func:`ot.bregman.screenkhorn` aims at providing a
    fast approximation of the Sinkhorn problem. For use of GPU and gradient
    computation with small number of iterations we strongly recommend the
    :py:func:`ot.bregman.sinkhorn_log` solver that does not need to check for
    numerical problems.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or ndarray, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log)
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    method : str
        method used for the solver either 'sinkhorn','sinkhorn_log',
        'sinkhorn_stabilized', see those function for specific parameters.
        The name is matched case-insensitively; any other value raises a
        ``ValueError``.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
        Defaults to False for this function.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials (that is the logarithm of the u,v
        sinkhorn scaling vectors), given as a pair of arrays of shape
        (dim_a,) and (dim_b,) for a single target histogram. It is forwarded
        unchanged to the selected solver.
    **kwargs
        Additional solver-specific keyword arguments forwarded to the solver.

    Returns
    -------
    W : (n_hists) float/array-like
        Optimal transportation loss for the given parameters
    log : dict
        log dictionary return only if log==True in parameters

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn2(a, b, M, 1)
    0.26894142136999516

    .. _references-sinkhorn2:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    .. [21] Altschuler J., Weed J., Rigollet P. : Near-linear time approximation
        algorithms for optimal transport via Sinkhorn iteration,
        Advances in Neural Information Processing Systems (NIPS) 31, 2017
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April).
        Interpolating between optimal transport and MMD using Sinkhorn
        divergences. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    ot.bregman.sinkhorn_knopp : Classic Sinkhorn :ref:`[2] <references-sinkhorn2>`
    ot.bregman.greenkhorn : Greenkhorn :ref:`[21] <references-sinkhorn2>`
    ot.bregman.sinkhorn_stabilized: Stabilized sinkhorn
        :ref:`[9] <references-sinkhorn2>` :ref:`[10] <references-sinkhorn2>`
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # Solvers supported by sinkhorn2; they all share the keyword interface
    # used in the single call below.
    solvers = {
        "sinkhorn": sinkhorn_knopp,
        "sinkhorn_log": sinkhorn_log,
        "sinkhorn_stabilized": sinkhorn_stabilized,
    }
    solver = solvers.get(method.lower())
    if solver is None:
        raise ValueError("Unknown method '%s'." % method)

    res = solver(
        a,
        b,
        M,
        reg,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=log,
        warn=warn,
        warmstart=warmstart,
        **kwargs,
    )

    if len(b.shape) > 1:
        # Several target histograms: sinkhorn_knopp / sinkhorn_log already
        # return the per-histogram losses here, so forward the result as is.
        return res

    # Single target histogram: the solver returned the OT plan, reduce it to
    # the transport cost <gamma, M>_F (without the entropic term).
    if log:
        return nx.sum(M * res[0]), res[1]
    return nx.sum(M * res)
[docs]
def sinkhorn_knopp(
    a,
    b,
    M,
    reg,
    numItermax=1000,
    stopThr=1e-9,
    verbose=False,
    log=False,
    warn=True,
    warmstart=None,
    **kwargs,
):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp
    matrix scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-knopp>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain. An empty array is replaced by
        uniform weights.
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix
        (return OT loss + dual variables in log). An empty array is replaced
        by uniform weights.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the norm of the right-marginal violation (>0),
        checked every 10 iterations
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials (that is the logarithm of the u,v
        sinkhorn scaling vectors): a pair of arrays of shape (dim_a,) and
        (dim_b,), or (dim_a, n_hists) and (dim_b, n_hists) when
        :math:`\mathbf{b}` is a matrix.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters. When
        :math:`\mathbf{b}` is a matrix, the array of the n_hists transport
        losses is returned instead.
    log : dict
        log dictionary return only if log==True in parameters, with keys
        'err', 'niter', 'u' and 'v'

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-knopp:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # empty weight vectors default to uniform histograms
    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if log:
        log = {"err": []}

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
            v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
        else:
            u = nx.ones(dim_a, type_as=M) / dim_a
            v = nx.ones(dim_b, type_as=M) / dim_b
    else:
        # warmstart holds dual potentials, i.e. log(u) and log(v)
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    K = nx.exp(M / (-reg))

    # fold the 1/a scaling into K once, so the u-update is a single product
    Kp = (1 / a).reshape(-1, 1) * K

    # bound before the loop so log["niter"] is defined even if numItermax == 0
    ii = 0
    for ii in range(numItermax):
        uprev = u
        vprev = v
        KtransposeU = nx.dot(K.T, u)
        v = b / KtransposeU
        u = 1.0 / nx.dot(Kp, v)

        if (
            nx.any(KtransposeU == 0)
            or nx.any(nx.isnan(u))
            or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u))
            or nx.any(nx.isinf(v))
        ):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Warning: numerical errors at iteration %d" % ii)
            u = uprev
            v = vprev
            break
        if ii % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            if n_hists:
                tmp2 = nx.einsum("ik,ij,jk->jk", u, K, v)
            else:
                # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
                tmp2 = nx.einsum("i,ij,j->j", u, K, v)
            err = nx.norm(tmp2 - b)  # violation of marginal
            if log:
                log["err"].append(err)

            if err < stopThr:
                break
            if verbose:
                if ii % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))
    else:
        # loop exhausted without a break: no convergence
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["niter"] = ii
        log["u"] = u
        log["v"] = v

    if n_hists:  # return only loss
        res = nx.einsum("ik,ij,jk,ij->k", u, K, v, M)
    else:  # return OT matrix
        res = u.reshape((-1, 1)) * K * v.reshape((1, -1))

    if log:
        return res, log
    return res
[docs]
def sinkhorn_log(
    a,
    b,
    M,
    reg,
    numItermax=1000,
    stopThr=1e-9,
    verbose=False,
    log=False,
    warn=True,
    warmstart=None,
    **kwargs,
):
    r"""
    Solve the entropic regularization optimal transport problem in log space
    and return the OT matrix

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm  :ref:`[2] <references-sinkhorn-log>` with the
    implementation from :ref:`[34] <references-sinkhorn-log>`

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain. An empty array is replaced by
        uniform weights.
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples in the target domain, compute sinkhorn with multiple targets
        and fixed :math:`\mathbf{M}` if :math:`\mathbf{b}` is a matrix (return OT loss + dual variables in log).
        Each column is then solved independently. An empty array is replaced
        by uniform weights.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the norm of the right-marginal violation (>0),
        checked every 10 iterations
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge
        (also applied to every histogram when :math:`\mathbf{b}` is a matrix).
    warmstart: tuple of arrays, optional
        Initialization of dual potentials (that is the logarithm of the u,v
        sinkhorn scaling vectors): a pair of arrays of shape (dim_a,) and
        (dim_b,). When :math:`\mathbf{b}` is a matrix, ``warmstart[k]`` is
        used as the pair for column k.

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters. When
        :math:`\mathbf{b}` is a matrix, the array of the n_hists transport
        losses is returned instead.
    log : dict
        log dictionary return only if log==True in parameters. For a single
        histogram it holds 'err', 'niter', 'log_u', 'log_v', 'u' and 'v';
        for several histograms only 'log_u', 'log_v', 'u' and 'v' (stacked
        column-wise).

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-sinkhorn-log:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I.,
        Trouvé, A., & Peyré, G. (2019, April). Interpolating between
        optimal transport and MMD using Sinkhorn divergences. In The
        22nd International Conference on Artificial Intelligence and
        Statistics (pp. 2681-2690). PMLR.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)

    # empty weight vectors default to uniform histograms
    if len(a) == 0:
        a = nx.full((M.shape[0],), 1.0 / M.shape[0], type_as=M)
    if len(b) == 0:
        b = nx.full((M.shape[1],), 1.0 / M.shape[1], type_as=M)

    # init data
    dim_a = len(a)
    dim_b = b.shape[0]

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    if n_hists:
        # Several target histograms: solve each column separately (no 3D
        # tensors) and return the vector of transport losses <gamma_k, M>.
        if warmstart is None:
            # one empty warm start per column; this must also cover
            # n_hists == 1, otherwise ``warmstart[k]`` fails on None
            warmstart = [None] * n_hists

        lst_loss = []
        lst_u = []
        lst_v = []

        for k in range(n_hists):
            res = sinkhorn_log(
                a,
                b[:, k],
                M,
                reg,
                numItermax=numItermax,
                stopThr=stopThr,
                verbose=verbose,
                log=log,
                warn=warn,
                warmstart=warmstart[k],
                **kwargs,
            )

            if log:
                lst_loss.append(nx.sum(M * res[0]))
                lst_u.append(res[1]["log_u"])
                lst_v.append(res[1]["log_v"])
            else:
                lst_loss.append(nx.sum(M * res))
        res = nx.stack(lst_loss)
        if log:
            log = {
                "log_u": nx.stack(lst_u, 1),
                "log_v": nx.stack(lst_v, 1),
            }
            log["u"] = nx.exp(log["log_u"])
            log["v"] = nx.exp(log["log_v"])
            return res, log
        return res

    # Single target histogram: log-domain Sinkhorn iterations.
    if log:
        log = {"err": []}

    Mr = -M / reg

    # u and v are dual potentials (log-scalings), zero means uniform scaling
    if warmstart is None:
        u = nx.zeros(dim_a, type_as=M)
        v = nx.zeros(dim_b, type_as=M)
    else:
        u, v = warmstart

    def get_logT(u, v):
        # log of the current plan diag(e^u) K diag(e^v)
        return Mr + u[:, None] + v[None, :]

    loga = nx.log(a)
    logb = nx.log(b)

    # bound before the loop so log["niter"] is defined even if numItermax == 0
    ii = 0
    for ii in range(numItermax):
        v = logb - nx.logsumexp(Mr + u[:, None], 0)
        u = loga - nx.logsumexp(Mr + v[None, :], 1)

        if ii % 10 == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations

            # compute right marginal tmp2= (diag(u)Kdiag(v))^T1
            tmp2 = nx.sum(nx.exp(get_logT(u, v)), 0)
            err = nx.norm(tmp2 - b)  # violation of marginal
            if log:
                log["err"].append(err)

            if verbose:
                if ii % 200 == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))
            if err < stopThr:
                break
    else:
        # loop exhausted without a break: no convergence
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )

    if log:
        log["niter"] = ii
        log["log_u"] = u
        log["log_v"] = v
        log["u"] = nx.exp(u)
        log["v"] = nx.exp(v)
        return nx.exp(get_logT(u, v)), log
    return nx.exp(get_logT(u, v))
[docs]
def greenkhorn(
    a,
    b,
    M,
    reg,
    numItermax=10000,
    stopThr=1e-9,
    verbose=False,
    log=False,
    warn=True,
    warmstart=None,
):
    r"""
    Solve the entropic regularization optimal transport problem and return the OT matrix

    The algorithm used is based on the paper :ref:`[22] <references-greenkhorn>`
    which is a stochastic version of the Sinkhorn-Knopp
    algorithm :ref:`[2] <references-greenkhorn>`

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)

    At each iteration, the single row or column with the largest absolute
    marginal violation is rescaled.

    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain. An empty array is replaced by
        uniform weights.
    b : array-like, shape (dim_b,)
        samples weights in the target domain (a single histogram; entries are
        indexed individually by the updates). An empty array is replaced by
        uniform weights.
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold (>0) on the largest absolute violation of either
        marginal
    verbose : bool, optional
        Accepted for API compatibility with the other solvers; this solver
        does not currently print anything.
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials (that is the logarithm of the u,v
        sinkhorn scaling vectors): a pair of arrays of shape (dim_a,) and
        (dim_b,).

    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters, with keys
        'u', 'v', 'n_iter' and 'niter' (same value as 'n_iter', matching the
        key used by the other Sinkhorn solvers)

    Raises
    ------
    TypeError
        if the inputs are JAX or TensorFlow arrays, since the solver updates
        arrays in place.

    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.greenkhorn(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])

    .. _references-greenkhorn:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information
        Processing Systems (NIPS) 26, 2013
    .. [22] J. Altschuler, J.Weed, P. Rigollet : Near-linear time
        approximation algorithms for optimal transport via Sinkhorn
        iteration, Advances in Neural Information Processing
        Systems (NIPS) 31, 2017

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)

    nx = get_backend(M, a, b)
    # the updates below assign into array entries, which JAX/TF arrays
    # do not support
    if nx.__name__ in ("jax", "tf"):
        raise TypeError(
            "JAX or TF arrays have been received. Greenkhorn is not "
            "compatible with  neither JAX nor TF"
        )

    # empty weight vectors default to uniform histograms
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]

    dim_a = a.shape[0]
    dim_b = b.shape[0]

    K = nx.exp(-M / reg)

    if warmstart is None:
        u = nx.full((dim_a,), 1.0 / dim_a, type_as=K)
        v = nx.full((dim_b,), 1.0 / dim_b, type_as=K)
    else:
        # warmstart holds dual potentials, i.e. log(u) and log(v)
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    # current plan and its signed marginal violations (rows - a, columns - b)
    G = u[:, None] * K * v[None, :]
    viol = nx.sum(G, axis=1) - a
    viol_2 = nx.sum(G, axis=0) - b

    # bound before the loop so the log is valid even if numItermax == 0
    ii = 0
    for ii in range(numItermax):
        i_1 = nx.argmax(nx.abs(viol))
        i_2 = nx.argmax(nx.abs(viol_2))
        m_viol_1 = nx.abs(viol[i_1])
        m_viol_2 = nx.abs(viol_2[i_2])
        stopThr_val = nx.maximum(m_viol_1, m_viol_2)

        if m_viol_1 > m_viol_2:
            # rescale row i_1 so that its sum matches a[i_1], then update
            # the column violations incrementally
            old_u = u[i_1]
            new_u = a[i_1] / nx.dot(K[i_1, :], v)
            G[i_1, :] = new_u * K[i_1, :] * v

            viol[i_1] = nx.dot(new_u * K[i_1, :], v) - a[i_1]
            viol_2 += K[i_1, :].T * (new_u - old_u) * v
            u[i_1] = new_u
        else:
            # rescale column i_2 so that its sum matches b[i_2], then update
            # the row violations incrementally
            old_v = v[i_2]
            new_v = b[i_2] / nx.dot(K[:, i_2].T, u)
            G[:, i_2] = u * K[:, i_2] * new_v
            # aviol = (G@one_m - a)
            # aviol_2 = (G.T@one_n - b)
            viol += (-old_v + new_v) * K[:, i_2] * u
            viol_2[i_2] = new_v * nx.dot(K[:, i_2], u) - b[i_2]
            v[i_2] = new_v

        if stopThr_val <= stopThr:
            break
    else:
        # loop exhausted without a break: no convergence
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )

    if log:
        log = {"u": u, "v": v, "n_iter": ii, "niter": ii}
        return G, log
    return G
[docs]
def sinkhorn_stabilized(
    a,
    b,
    M,
    reg,
    numItermax=1000,
    tau=1e3,
    stopThr=1e-9,
    warmstart=None,
    verbose=False,
    print_period=20,
    log=False,
    warn=True,
    **kwargs,
):
    r"""
    Solve the entropic regularization OT problem with log stabilization
    The function solves the following optimization problem:
    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)
        s.t. \ \gamma \mathbf{1} &= \mathbf{a}
             \gamma^T \mathbf{1} &= \mathbf{b}
             \gamma &\geq 0
    where :
    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      weights (histograms, both sum to 1)
    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-stabilized>`
    but with the log stabilization
    proposed in :ref:`[10] <references-sinkhorn-stabilized>` and defined in
    :ref:`[9] <references-sinkhorn-stabilized>` (Algo 3.1) .
    Whenever a scaling vector :math:`\mathbf{u}` or :math:`\mathbf{v}` exceeds
    ``tau``, its logarithm is absorbed into the dual potentials ``alpha`` /
    ``beta`` and the scaling is rescaled accordingly, so that the current
    transport plan is left unchanged by the absorption.
    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,) or array-like, shape (dim_b, n_hists)
        samples weights in the target domain. If ``b`` is a matrix, the
        problem is solved for each of its columns with the same
        :math:`\mathbf{M}`
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    warmstart : tuple of vectors
        if given then starting values for alpha and beta log scalings
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        The error is evaluated (and printed when ``verbose``) every
        ``print_period`` iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters. When ``b``
        has shape (dim_b, n_hists), an array of shape (n_hists,) holding
        :math:`\langle \gamma_k, \mathbf{M} \rangle_F` for each target
        histogram is returned instead.
    log : dict
        log dictionary return only if log==True in parameters
    Examples
    --------
    >>> import ot
    >>> a=[.5,.5]
    >>> b=[.5,.5]
    >>> M=[[0.,1.],[1.,0.]]
    >>> ot.bregman.sinkhorn_stabilized(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])
    .. _references-sinkhorn-stabilized:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of
        Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms
        for Entropy Regularized Transport Problems.
        arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
    # a 2D b means several target histograms sharing the same cost M
    if len(b.shape) > 1:
        n_hists = b.shape[1]
        a = a[:, None]
    else:
        n_hists = 0
    dim_a = len(a)
    dim_b = len(b)
    if log:
        log = {"err": []}
    if warmstart is None:
        alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
    else:
        alpha, beta = warmstart
    if n_hists:
        u = nx.ones((dim_a, n_hists), type_as=M) / dim_a
        v = nx.ones((dim_b, n_hists), type_as=M) / dim_b
    else:
        u = nx.ones(dim_a, type_as=M) / dim_a
        v = nx.ones(dim_b, type_as=M) / dim_b

    def get_K(alpha, beta):
        """Kernel with the dual potentials alpha/beta absorbed (log space)."""
        return nx.exp(-(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg)

    def get_Gamma(alpha, beta, u, v):
        """Transport plan diag(u) K(alpha, beta) diag(v), computed in log space."""
        return nx.exp(
            -(M - alpha.reshape((dim_a, 1)) - beta.reshape((1, dim_b))) / reg
            + nx.log(u.reshape((dim_a, 1)))
            + nx.log(v.reshape((1, dim_b)))
        )

    K = get_K(alpha, beta)
    err = 1
    for ii in range(numItermax):
        uprev = u
        vprev = v
        # sinkhorn update
        v = b / (nx.dot(K.T, u))
        u = a / (nx.dot(K, v))
        # Absorb large scalings into the potentials. u and v are divided by
        # exactly what is absorbed so diag(u) K diag(v) is unchanged; otherwise
        # a plan returned (or an error measured) right after this step would
        # not match the current iterate.
        if nx.max(nx.abs(u)) > tau or nx.max(nx.abs(v)) > tau:
            if n_hists:
                # one shift per source row (shared by all histograms) and a
                # single scalar shift on the target side
                shift_u = nx.max(nx.log(u), axis=1)
                shift_v = nx.max(nx.log(v))
                alpha = alpha + reg * shift_u
                beta = beta + reg * shift_v
                u = u / nx.exp(shift_u)[:, None]
                v = v / nx.exp(shift_v)
            else:
                alpha, beta = alpha + reg * nx.log(u), beta + reg * nx.log(v)
                u = nx.ones(dim_a, type_as=M)
                v = nx.ones(dim_b, type_as=M)
            K = get_K(alpha, beta)
        if ii % print_period == 0:
            # the error is only evaluated every print_period iterations to
            # save computations
            if n_hists:
                err_u = nx.max(nx.abs(u - uprev))
                err_u /= max(nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.0)
                err_v = nx.max(nx.abs(v - vprev))
                err_v /= max(nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.0)
                err = 0.5 * (err_u + err_v)
            else:
                transp = get_Gamma(alpha, beta, u, v)
                err = nx.norm(nx.sum(transp, axis=0) - b)
            if log:
                log["err"].append(err)
            if verbose:
                if ii % (print_period * 20) == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))
        if err <= stopThr:
            break
        if nx.any(nx.isnan(u)) or nx.any(nx.isnan(v)):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn("Numerical errors at iteration %d" % ii)
            u = uprev
            v = vprev
            break
    else:
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        if n_hists:
            alpha = alpha[:, None]
            beta = beta[:, None]
        logu = alpha / reg + nx.log(u)
        logv = beta / reg + nx.log(v)
        log["n_iter"] = ii
        log["logu"] = logu
        log["logv"] = logv
        log["alpha"] = alpha + reg * nx.log(u)
        log["beta"] = beta + reg * nx.log(v)
        log["warmstart"] = (log["alpha"], log["beta"])
    if n_hists:
        # one transport cost <gamma_k, M> per target histogram
        res = nx.stack(
            [
                nx.sum(get_Gamma(alpha, beta, u[:, i], v[:, i]) * M)
                for i in range(n_hists)
            ]
        )
    else:
        res = get_Gamma(alpha, beta, u, v)
    if log:
        return res, log
    return res
[docs]
def sinkhorn_epsilon_scaling(
    a,
    b,
    M,
    reg,
    numItermax=100,
    epsilon0=1e4,
    numInnerItermax=100,
    tau=1e3,
    stopThr=1e-9,
    warmstart=None,
    verbose=False,
    print_period=10,
    log=False,
    warn=True,
    **kwargs,
):
    r"""
    Solve the entropic regularization optimal transport problem with log
    stabilization and epsilon scaling.
    The function solves the following optimization problem:
    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg}\cdot\Omega(\gamma)
        s.t. \ \gamma \mathbf{1} &= \mathbf{a}
             \gamma^T \mathbf{1} &= \mathbf{b}
             \gamma &\geq 0
    where :
    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term
      :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (histograms, both sum to 1)
    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix
    scaling algorithm as proposed in :ref:`[2] <references-sinkhorn-epsilon-scaling>`
    but with the log stabilization
    proposed in :ref:`[10] <references-sinkhorn-epsilon-scaling>` and the log scaling
    proposed in :ref:`[9] <references-sinkhorn-epsilon-scaling>` algorithm 3.2
    Each outer iteration ``i`` runs :py:func:`sinkhorn_stabilized` with the
    regularization :math:`(\epsilon_0 - \mathrm{reg}) e^{-i} + \mathrm{reg}`,
    warm-started from the potentials of the previous outer iteration.
    Parameters
    ----------
    a : array-like, shape (dim_a,)
        samples weights in the source domain
    b : array-like, shape (dim_b,)
        samples weights in the target domain
    M : array-like, shape (dim_a, dim_b)
        loss matrix
    reg : float
        Regularization term >0
    tau : float
        threshold for max value in :math:`\mathbf{u}` or :math:`\mathbf{v}`
        for log scaling
    warmstart : tuple of vectors
        if given then starting values for alpha and beta log scalings
    numItermax : int, optional
        Max number of outer iterations (at least 35 are always allowed)
    numInnerItermax : int, optional
        Max number of iterations in the inner log stabilized sinkhorn
    epsilon0 : float, optional
        first epsilon regularization value (then exponential decrease to reg)
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    print_period : int, optional
        The error is evaluated (and printed when ``verbose``) every
        ``print_period`` outer iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge. The
        intermediate inner solves never warn about non-convergence since
        they are expected to stop early.
    Returns
    -------
    gamma : array-like, shape (dim_a, dim_b)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters
    Examples
    --------
    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.bregman.sinkhorn_epsilon_scaling(a, b, M, 1)
    array([[0.36552929, 0.13447071],
           [0.13447071, 0.36552929]])
    .. _references-sinkhorn-epsilon-scaling:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.
    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT
    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)
    if len(a) == 0:
        a = nx.ones((M.shape[0],), type_as=M) / M.shape[0]
    if len(b) == 0:
        b = nx.ones((M.shape[1],), type_as=M) / M.shape[1]
    dim_a = len(a)
    dim_b = len(b)
    # exp(-35) is of the order of the relative precision of 64-bit floats, so
    # after numItermin outer iterations the schedule has (numerically)
    # reached reg; the loop is never allowed to stop before that.
    numItermin = 35
    numItermax = max(numItermin, numItermax)
    if log:
        log = {"err": []}
    if warmstart is None:
        alpha, beta = nx.zeros(dim_a, type_as=M), nx.zeros(dim_b, type_as=M)
    else:
        alpha, beta = warmstart

    def get_reg(n):
        """Exponentially decreasing regularization, from epsilon0 down to reg."""
        return (epsilon0 - reg) * np.exp(-n) + reg

    err = 1
    for ii in range(numItermax):
        regi = get_reg(ii)
        # Inner solves are capped at numInnerItermax and are expected not to
        # converge for the larger regularizations; convergence is reported by
        # this outer loop only, which also honors the caller's `warn`.
        G, logi = sinkhorn_stabilized(
            a,
            b,
            M,
            regi,
            numItermax=numInnerItermax,
            stopThr=stopThr,
            warmstart=(alpha, beta),
            verbose=False,
            print_period=20,
            tau=tau,
            log=True,
            warn=False,
        )
        alpha = logi["alpha"]
        beta = logi["beta"]
        if ii % print_period == 0:
            # the error on both marginals is only evaluated every
            # print_period outer iterations
            transp = G
            err = (
                nx.norm(nx.sum(transp, axis=0) - b) ** 2
                + nx.norm(nx.sum(transp, axis=1) - a) ** 2
            )
            if log:
                log["err"].append(err)
            if verbose:
                if ii % (print_period * 10) == 0:
                    print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
                print("{:5d}|{:8e}|".format(ii, err))
        if err <= stopThr and ii > numItermin:
            break
    else:
        if warn:
            warnings.warn(
                "Sinkhorn did not converge. You might want to "
                "increase the number of iterations `numItermax` "
                "or the regularization parameter `reg`."
            )
    if log:
        log["alpha"] = alpha
        log["beta"] = beta
        log["warmstart"] = (log["alpha"], log["beta"])
        # "niter" kept for backward compatibility; "n_iter" matches the
        # other Sinkhorn solvers of this module
        log["niter"] = ii
        log["n_iter"] = ii
        return G, log
    else:
        return G