# Source code for ot.solvers

# -*- coding: utf-8 -*-
"""
General OT solvers with unified API
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#

from .utils import OTResult, dist
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2,
semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein2,
entropic_semirelaxed_gromov_wasserstein2)
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn
from .optim import cg

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']

def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
          unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
          potentials_init=None, tol=None, verbose=False, grad='autodiff'):
r"""Solve the discrete optimal transport problem and return :any:`OTResult` object

The function solves the following general optimal transport problem

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
\lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
\lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

The regularization is selected with ``reg`` (:math:`\lambda_r`) and ``reg_type``. By
default ``reg=None`` and there is no regularization. The unbalanced marginal
penalization can be selected with ``unbalanced`` (:math:`\lambda_u`) and
``unbalanced_type``. By default ``unbalanced=None`` and the function
solves the exact optimal transport problem (respecting the marginals).
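
For instance, a minimal sketch of the three regimes (exact, regularized and
unbalanced), assuming ``ot`` is imported and ``M``, ``a``, ``b`` are arrays from a
supported backend:

.. code-block:: python

res = ot.solve(M) # exact OT with uniform weights
res = ot.solve(M, a, b, reg=1.0) # entropic regularization (Sinkhorn)
res = ot.solve(M, a, b, unbalanced=1.0) # relaxed (unbalanced) marginals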

Parameters
----------
M : array_like, shape (dim_a, dim_b)
Loss matrix
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Samples weights in the target domain (default is uniform)
reg : float, optional
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R`, either "KL", "L2" or "entropy",
by default "KL". A tuple of functions can be provided for the general
solver (see :any:`cg`). This is only used when ``reg!=None``.
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U`, either "KL", "L2" or
"TV", by default "KL".
method : str, optional
Method for solving the problem when multiple algorithms are available,
default None for automatic selection.
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Maximum number of iterations, by default None (default values in each solver)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
potentials_init : (array_like(dim_a,), array_like(dim_b,)), optional
Initialization of the OT dual potentials for iterative methods, by default None
tol : float, optional
Tolerance for solution precision, by default None (default values in each solver)
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either 'autodiff' or 'envelope', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (plan, value, value_linear) but with important memory cost.
'envelope' provides gradients only for value; the other outputs are
detached. This is useful for memory saving when only the value is needed.

Returns
-------
res : OTResult()
Result of the optimization problem. The information can be obtained as follows:

- res.plan : OT plan :math:`\mathbf{T}`
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan

See :any:`OTResult` for more information.
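
For example, a short access sketch (assuming the default exact solver):

.. code-block:: python

res = ot.solve(M, a, b)
T = res.plan # OT plan
u, v = res.potentials # dual potentials
loss = res.value # optimal objective value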

Notes
-----

The following methods are available for solving the OT problems:

- **Classical exact OT problem [1]** (default parameters) :

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve(M, a, b)

- **Entropic regularized OT [2]** (when reg!=None):

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

# default is "KL" regularization (reg_type="KL")
res = ot.solve(M, a, b, reg=1.0)
# or for original Sinkhorn paper formulation [2]
res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')

# Use envelope theorem differentiation for memory saving
res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors
res.value.backward() # only the value is differentiable

Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the value and plan. This can be changed with the
``grad`` parameter. The 'envelope' mode computes the gradients only
for the value, and the other outputs are detached. This is useful for
memory saving when only the gradient of the value is needed.
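
As a complementary sketch with the PyTorch backend (this assumes ``M``, ``a``,
``b`` are torch tensors with ``requires_grad=True``):

.. code-block:: python

res = ot.solve(M, a, b, reg=1.0) # default grad='autodiff'
res.plan.sum().backward() # the plan itself is differentiable wrt M, a, b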

- **Quadratic regularized OT [17]** (when reg!=None and reg_type="L2"):

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve(M,a,b,reg=1.0,reg_type='L2')
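
``reg_type`` can also be a tuple ``(f, df)`` of a regularizer and its gradient,
passed to the generic conditional gradient solver (see :any:`cg`). A minimal
sketch with a hypothetical squared-Frobenius regularizer:

.. code-block:: python

f = lambda G: 0.5 * (G**2).sum() # hypothetical regularizer R(T)
df = lambda G: G # its gradient
res = ot.solve(M, a, b, reg=1.0, reg_type=(f, df))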

- **Unbalanced OT [41]** (when unbalanced!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

can be solved with the following code:

.. code-block:: python

# default is "KL"
res = ot.solve(M,a,b,unbalanced=1.0)
res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2')
# TV = partial OT
res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV')

- **Regularized unbalanced OT [34]** (when unbalanced!=None and reg!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

can be solved with the following code:

.. code-block:: python

# default is "KL" for both
res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
# quadratic unbalanced OT with KL regularization
res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2')

.. _references-solve:
References
----------

.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
(2011, December).  Displacement interpolation using Lagrangian mass
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
158). ACM.

.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
of Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013

.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.

.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
Optimal Transport. Proceedings of the Twenty-First International
Conference on Artificial Intelligence and Statistics (AISTATS).

.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
A., & Peyré, G. (2019, April). Interpolating between optimal transport
and MMD using Sinkhorn divergences. In The 22nd International Conference
on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

.. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
linear regression. NeurIPS.

"""

# detect backend
arr = [M]
if a is not None:
arr.append(a)
if b is not None:
arr.append(b)
nx = get_backend(*arr)

# create uniform weights if not given
if a is None:
a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
if b is None:
b = nx.ones(M.shape[1], type_as=M) / M.shape[1]

# default values for solutions
potentials = None
value = None
value_linear = None
plan = None
status = None

if reg is None or reg == 0:  # exact OT

if unbalanced is None:  # Exact balanced OT

# default values for EMD solver
if max_iter is None:
max_iter = 1000000

value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)

value = value_linear
potentials = (log['u'], log['v'])
plan = log['G']
status = log["warning"] if log["warning"] is not None else 'Converged'

elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT

# default values for exact unbalanced OT
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-12

plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
div=unbalanced_type.lower(), numItermax=max_iter,
stopThr=tol, log=True,
verbose=verbose, G0=plan_init)

value_linear = log['cost']

if unbalanced_type.lower() == 'kl':
value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
else:
err_a = nx.sum(plan, 1) - a
err_b = nx.sum(plan, 0) - b
value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)

elif unbalanced_type.lower() == 'tv':

if max_iter is None:
max_iter = 1000000

plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True, numItermax=max_iter)

value_linear = nx.sum(M * plan)
err_a = nx.sum(plan, 1) - a
err_b = nx.sum(plan, 0) - b
value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) +
nx.sum(nx.abs(err_b))))

else:
raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

else:  # regularized OT

if unbalanced is None:  # Balanced regularized OT

if isinstance(reg_type, tuple):  # general solver

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1], numItermax=max_iter, stopThr=tol, log=True, verbose=verbose, G0=plan_init)

value_linear = nx.sum(M * plan)
value = log['loss'][-1]
potentials = (log['u'], log['v'])

elif reg_type.lower() in ['entropy', 'kl']:

if grad == 'envelope':  # if envelope then detach the input
M0, a0, b0 = M, a, b
M, a, b = nx.detach(M, a, b)

# default values for sinkhorn
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
stopThr=tol, log=True,
verbose=verbose)

value_linear = nx.sum(M * plan)

if reg_type.lower() == 'entropy':
value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
else:
value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])

potentials = (log['log_u'], log['log_v'])

if grad == 'envelope':  # set gradients at convergence via the envelope theorem
value = nx.set_gradients(value, (M0, a0, b0),
(plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))

elif reg_type.lower() == 'l2':

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol, log=True, verbose=verbose)

value_linear = nx.sum(M * plan)
value = value_linear + reg * nx.sum(plan**2)
potentials = (log['alpha'], log['beta'])

else:
raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

else:  # unbalanced AND regularized OT

if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, numItermax=max_iter, stopThr=tol, verbose=verbose, log=True)

value_linear = nx.sum(M * plan)

value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))

potentials = (log['logu'], log['logv'])

elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']:

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-12
if isinstance(reg_type, str):
reg_type = reg_type.lower()

plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced, reg_div=reg_type, regm_div=unbalanced_type.lower(), numItermax=max_iter, stopThr=tol, verbose=verbose, log=True, G0=plan_init)

value_linear = nx.sum(M * plan)

value = log['loss']

else:
raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

res = OTResult(potentials=potentials, value=value,
value_linear=value_linear, plan=plan, status=status, backend=nx)

return res

def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
                 alpha=0.5, reg=None,
                 reg_type="entropy", unbalanced=None, unbalanced_type='KL',
                 n_threads=1, method=None, max_iter=None, plan_init=None, tol=None,
                 verbose=False):
r"""Solve the discrete (Fused) Gromov-Wasserstein problem and return :any:`OTResult` object

The function solves the following optimization problem:

.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

The regularization is selected with ``reg`` (:math:`\lambda_r`) and
``reg_type``. By default ``reg=None`` and there is no regularization. The
unbalanced marginal penalization can be selected with ``unbalanced``
(:math:`\lambda_u`) and ``unbalanced_type``. By default ``unbalanced=None``
and the function solves the exact optimal transport problem (respecting the
marginals).

Parameters
----------
Ca : array_like, shape (dim_a, dim_a)
Cost matrix in the source domain
Cb : array_like, shape (dim_b, dim_b)
Cost matrix in the target domain
M : array_like, shape (dim_a, dim_b), optional
Linear cost matrix for Fused Gromov-Wasserstein (default is None).
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Samples weights in the target domain (default is uniform)
loss : str, optional
Type of loss function, either "L2" or "KL", by default "L2"
symmetric : bool, optional
Use the symmetric version of the Gromov-Wasserstein problem; by default None
tests whether the matrices are symmetric, set True/False to skip the test.
reg : float, optional
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R`, by default "entropy" (only used when
``reg!=None``)
alpha : float, optional
Weight between the quadratic term (alpha*Gromov) and the linear term
((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem. Not used for the
Gromov problem (when M is not provided). By default ``alpha=0.5``; the pure
Gromov problem is solved when ``M`` is None or ``alpha=1``.
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT). For ``unbalanced_type='partial'`` it is the total mass to
transport; the "KL" and "L2" penalizations are not implemented yet.
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U`, either "KL", "semirelaxed"
or "partial", by default "KL" (note that "KL" is not implemented yet).
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
method : str, optional
Method for solving the problem when multiple algorithms are available,
default None for automatic selection.
max_iter : int, optional
Maximum number of iterations, by default None (default values in each
solver)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
tol : float, optional
Tolerance for solution precision, by default None (default values in
each solver)
verbose : bool, optional
Print information in the solver, by default False

Returns
-------
res : OTResult()
Result of the optimization problem. The information can be obtained as follows:

- res.plan : OT plan :math:`\mathbf{T}`
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan
- res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan

See :any:`OTResult` for more information.
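
A short access sketch (assuming the default exact GW solver):

.. code-block:: python

res = ot.solve_gromov(Ca, Cb)
T = res.plan # GW plan
value = res.value # GW objective value
quad = res.value_quad # quadratic part of the loss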

Notes
-----
The following methods are available for solving the Gromov-Wasserstein
problem:

- **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters):

.. math::
\min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve_gromov(Ca, Cb) # uniform weights
res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights
res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss

plan = res.plan # GW plan
value = res.value # GW value

- **Fused Gromov-Wasserstein (FGW) problem [24]** (when M!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default)
res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha

plan = res.plan # FGW plan
loss_linear_term = res.value_linear # Wasserstein part of the loss
loss = res.value # FGW value

- **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when  reg!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T})

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default)
res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy

plan = res.plan # FGW plan
loss_linear_term = res.value_linear # Wasserstein part of the loss
loss = res.value # FGW value (including regularization)

- **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when  unbalanced_type='semirelaxed'):

.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T} \geq 0

can be solved with the following code:

.. code-block:: python

res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed') # semirelaxed GW
res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed', reg=1) # entropic semirelaxed GW
res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='semirelaxed', alpha=0.1) # semirelaxed FGW

plan = res.plan # FGW plan
right_marginal = res.marginal_b # right marginal of the plan

- **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when  unbalanced_type='partial'):

.. math::
\min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
\alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a}

\mathbf{T}^T \mathbf{1} \leq \mathbf{b}

\mathbf{T} \geq 0

\mathbf{1}^T\mathbf{T}\mathbf{1} = m

can be solved with the following code:

.. code-block:: python

res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8

.. _references-solve-gromov:
References
----------

.. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric
approach to object matching. Foundations of computational mathematics,
11(4), 417-487.

.. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
Gromov-Wasserstein averaging of kernel and distance matrices
International Conference on Machine Learning (ICML).

.. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
(2019). Optimal Transport for structured data with application on graphs
Proceedings of the 36th International Conference on Machine Learning
(ICML).

.. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer,
Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and
applications on graphs. International Conference on Learning
Representations (ICLR), 2022.

.. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport
with Applications on Positive-Unlabeled Learning, Advances in Neural
Information Processing Systems (NeurIPS), 2020.

"""

# detect backend
nx = get_backend(Ca, Cb, M, a, b)

# create uniform weights if not given
if a is None:
a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0]
if b is None:
b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1]

# default values for solutions
potentials = None
value = None
value_linear = None
value_quad = None
plan = None
status = None
log = None

loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'}

if loss.lower() not in loss_dict.keys():
raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss)))
loss_fun = loss_dict[loss.lower()]

if reg is None or reg == 0:  # exact OT

if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Exact balanced OT

if M is None or alpha == 1:  # Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

if alpha == 1:  # set to 0 for FGW with alpha=1
value_linear = 0
plan = log['T']
potentials = (log['u'], log['v'])

elif alpha == 0:  # Wasserstein problem

# default values for EMD solver
if max_iter is None:
max_iter = 1000000

value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True, return_matrix=True, numThreads=n_threads)

value = value_linear
potentials = (log['u'], log['v'])
plan = log['G']
status = log["warning"] if log["warning"] is not None else 'Converged'

else:  # Fused Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
plan = log['T']
potentials = (log['u'], log['v'])

elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed  OT

if M is None or alpha == 1:  # Semi relaxed Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

if alpha == 1:  # set to 0 for FGW with alpha=1
value_linear = 0
plan = log['T']
# potentials = (log['u'], log['v']) TODO

else:  # Semi relaxed Fused Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 10000
if tol is None:
tol = 1e-9

value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
plan = log['T']
# potentials = (log['u'], log['v']) TODO

elif unbalanced_type.lower() in ['partial']:  # Partial OT

if M is None:  # Partial Gromov-Wasserstein problem

if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (ValueError('Partial GW mass given in unbalanced is too large'))
if loss.lower() != 'l2':
raise (NotImplementedError('Partial GW only implemented with L2 loss'))
if symmetric is not None:
raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-7

value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose)

plan = log['T']
# potentials = (log['u'], log['v']) TODO

else:  # partial FGW

raise (NotImplementedError('Partial FGW not implemented yet'))

elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT

raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type)))

else:
raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

else:  # regularized OT

if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Balanced regularized OT

if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if method is None:
method = 'PGD'

value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

plan = log['T']
value_linear = 0
value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
# potentials = (log['log_u'], log['log_v'])  #TODO

elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0:  # Entropic Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
stopThr=tol, log=True,
verbose=verbose)

value_linear = nx.sum(M * plan)
value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
potentials = (log['log_u'], log['log_v'])

elif reg_type.lower() in ['entropy'] and M is not None:  # Entropic Fused Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if method is None:
method = 'PGD'

value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, solver=method, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
plan = log['T']
# potentials = (log['u'], log['v'])
value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

else:
raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed  OT

if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Semi-relaxed Gromov-Wasserstein problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

plan = log['T']
value_linear = 0
value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))

else:  # Entropic Semi-relaxed FGW problem

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9

value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True, symmetric=symmetric, max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

value_linear = log['lin_loss']
plan = log['T']
value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

elif unbalanced_type.lower() in ['partial']:  # Partial OT

if M is None:  # Partial Gromov-Wasserstein problem

if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
raise (ValueError('Partial GW mass given in unbalanced is too large'))
if loss.lower() != 'l2':
raise (NotImplementedError('Partial GW only implemented with L2 loss'))
if symmetric is not None:
raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

# default values for solver
if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-7

value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg, m=unbalanced, log=True, numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose)

plan = log['T']
# potentials = (log['u'], log['v']) TODO

else:  # partial FGW

raise (NotImplementedError('Partial entropic FGW not implemented yet'))

else:  # unbalanced AND regularized OT

raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

res = OTResult(potentials=potentials, value=value, value_linear=value_linear,
value_quad=value_quad, plan=plan, status=status, backend=nx)
return res

def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
                 unbalanced=None,
                 unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1,
                 max_iter=None, plan_init=None, rank=100, scaling=0.95,
                 potentials_init=None, X_init=None, tol=None, verbose=False, grad='autodiff'):
r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

The function solves the following general optimal transport problem

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
\lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
\lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

where the cost matrix :math:`\mathbf{M}` is computed from the samples in the
source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where
:math:`d` is a metric (by default the squared Euclidean distance).

The regularization is selected with ``reg`` (:math:`\lambda_r`) and ``reg_type``. By
default ``reg=None`` and there is no regularization. The unbalanced marginal
penalization can be selected with ``unbalanced`` (:math:`\lambda_u`) and
``unbalanced_type``. By default ``unbalanced=None`` and the function
solves the exact optimal transport problem (respecting the marginals).
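
For the dense (non-lazy) path this is equivalent to building the cost matrix
explicitly and calling :any:`ot.solve`, as in the following sketch:

.. code-block:: python

M = ot.dist(xa, xb, metric='sqeuclidean') # cost matrix between samples
res_dense = ot.solve(M, a, b) # same problem ...
res_sample = ot.solve_sample(xa, xb, a, b) # ... solved directly from samples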

Parameters
----------
X_a : array-like, shape (n_samples_a, dim)
samples in the source domain
X_b : array-like, shape (n_samples_b, dim)
samples in the target domain
a : array-like, shape (dim_a,), optional
Samples weights in the source domain (default is uniform)
b : array-like, shape (dim_b,), optional
Samples weights in the target domain (default is uniform)
reg : float, optional
Regularization weight :math:`\lambda_r`, by default None (no reg., exact
OT)
reg_type : str, optional
Type of regularization :math:`R`, either "KL", "L2" or "entropy", by default "KL"
unbalanced : float, optional
Unbalanced penalization weight :math:`\lambda_u`, by default None
(balanced OT)
unbalanced_type : str, optional
Type of unbalanced penalization function :math:`U`, either "KL", "L2" or "TV", by default "KL"
lazy : bool, optional
Return a lazy OT plan in the :any:`OTResult` object to reduce memory cost
when True, by default False
batch_size : int, optional
Batch size for lazy solver, by default None (default values in each
solver)
method : str, optional
Method for solving the problem, this can be used to select the solver
for unbalanced problems (see :any:`ot.solve`), or to select a specific
large scale solver.
n_threads : int, optional
Number of OMP threads for exact OT solver, by default 1
max_iter : int, optional
Maximum number of iterations, by default None (default values in each solver)
plan_init : array_like, shape (dim_a, dim_b), optional
Initialization of the OT plan for iterative methods, by default None
rank : int, optional
Rank of the OT matrix for lazy solvers (method='factored' or method='lowrank'), by default 100
scaling : float, optional
Scaling factor for the epsilon-scaling lazy solvers (method='geomloss'), by default 0.95
potentials_init : (array_like(dim_a,), array_like(dim_b,)), optional
Initialization of the OT dual potentials for iterative methods, by default None
tol : float, optional
Tolerance for solution precision, by default None (default values in each solver)
verbose : bool, optional
Print information in the solver, by default False
grad : str, optional
Type of gradient computation, either 'autodiff' or 'envelope', used only for the
Sinkhorn solver. By default 'autodiff' provides gradients wrt all
outputs (plan, value, value_linear) but with important memory cost.
'envelope' provides gradients only for value; the other outputs are
detached. This is useful for memory saving when only the value is needed.

Returns
-------

res : OTResult()
Result of the optimization problem. The information can be obtained as follows:

- res.plan : OT plan :math:`\mathbf{T}`
- res.potentials : OT dual potentials
- res.value : Optimal value of the optimization problem
- res.value_linear : Linear OT loss with the optimal OT plan
- res.lazy_plan : Lazy OT plan (when lazy=True or lazy method)

See :any:`OTResult` for more information.
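
A short sketch of the lazy plan usage (assuming an entropic lazy run):

.. code-block:: python

res = ot.solve_sample(xa, xb, reg=1.0, lazy=True)
lazy_plan = res.lazy_plan # no dense (n, m) matrix is stored
plan = lazy_plan[:] # materialize the full plan only if needed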

Notes
-----

The following methods are available for solving the OT problems:

- **Classical exact OT problem [1]** (default parameters) :

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0,  M_{i,j} = d(x_i,y_j)

can be solved with the following code:

.. code-block:: python

res = ot.solve_sample(xa, xb, a, b)

# for uniform weights
res = ot.solve_sample(xa, xb)

- **Entropic regularized OT [2]** (when reg!=None):

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0,  M_{i,j} = d(x_i,y_j)

can be solved with the following code:

.. code-block:: python

# default is "KL" regularization (reg_type="KL")
res = ot.solve_sample(xa, xb, a, b, reg=1.0)
# or for original Sinkhorn paper formulation [2]
res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy')

# lazy solver of memory complexity O(n)
res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100)
# lazy OT plan
lazy_plan = res.lazy_plan

# Use envelope theorem differentiation for memory saving
res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope')
res.value.backward() # only the value is differentiable

Note that by default the Sinkhorn solver uses automatic differentiation to
compute the gradients of the value and plan. This can be changed with the
``grad`` parameter. The 'envelope' mode computes the gradients only
for the value, and the other outputs are detached. This is useful for
memory saving when only the gradient of the value is needed.

We also have a very efficient solver with compiled CPU/CUDA code using
geomloss/PyKeOps that can be used with the following code:

.. code-block:: python

# automatic solver
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss')

# force O(n) memory efficient solver
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online')

# force pre-computed cost matrix
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized')

# use multiscale solver
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale')

# One can play with speed (small scaling factor) and precision (scaling close to 1)
res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5)

- **Quadratic regularized OT [17]** (when reg!=None and reg_type="L2"):

.. math::
\min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

\mathbf{T}^T \mathbf{1} = \mathbf{b}

\mathbf{T} \geq 0,  M_{i,j} = d(x_i,y_j)

can be solved with the following code:

.. code-block:: python

res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2')

- **Unbalanced OT [41]** (when unbalanced!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

with  M_{i,j} = d(x_i,y_j)

can be solved with the following code:

.. code-block:: python

# default is "KL"
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0)
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2')
# TV = partial OT
res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV')

- **Regularized unbalanced OT [34]** (when unbalanced!=None and reg!=None):

.. math::
\min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

with  M_{i,j} = d(x_i,y_j)

can be solved with the following code:

.. code-block:: python

# default is "KL" for both
res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0)
# quadratic unbalanced OT with KL regularization
res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2')
res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2',
unbalanced=1.0, unbalanced_type='L2')

- **Factored OT [40]** (when method='factored'):

This method solves the following OT problem [40]_

.. math::
\mathop{\arg \min}_\mu \quad W_2^2(\mu_a, \mu) + W_2^2(\mu, \mu_b)

where :math:`\mu` is a uniformly weighted discrete distribution supported on
``rank`` points, :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated
to the samples in the source and target domains, and :math:`W_2` is the
Wasserstein distance. This problem is solved using exact OT solvers for
``reg=None`` and the Sinkhorn solver for ``reg!=None``. The solution provides
two transport plans that can be used to recover a low rank OT plan between
the two distributions.

.. code-block:: python

res = ot.solve_sample(xa, xb, method='factored', rank=10)

# recover the lazy low rank plan
factored_solution_lazy = res.lazy_plan

# recover the full low rank plan
factored_solution = factored_solution_lazy[:]

- **Gaussian Bures-Wasserstein [2]** (when method='gaussian'):

This method computes the Gaussian Bures-Wasserstein distance between two
Gaussian distributions estimated from the empirical distributions

.. math::
\mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}

where :

.. math::
\mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

The covariances and means are estimated from the data.

.. code-block:: python

res = ot.solve_sample(xa, xb, method='gaussian')

# recover the squared Gaussian Bures-Wasserstein distance
BW_dist = res.value

- **Wasserstein 1d [1]** (when method='1D'):

This method computes the Wasserstein distance between two 1d distributions
estimated from the empirical distributions. For multivariate data the
distances are computed independently for each dimension.

.. code-block:: python

res = ot.solve_sample(xa, xb, method='1D')

# recover the squared Wasserstein distances
W_dists = res.value

.. _references-solve-sample:
References
----------

.. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
(2011, December).  Displacement interpolation using Lagrangian mass
transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
158). ACM.

.. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
of Optimal Transport, Advances in Neural Information Processing
Systems (NIPS) 26, 2013

.. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
Scaling algorithms for unbalanced transport problems.
arXiv preprint arXiv:1607.05816.

.. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
Optimal Transport. Proceedings of the Twenty-First International
Conference on Artificial Intelligence and Statistics (AISTATS).

.. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
A., & Peyré, G. (2019, April). Interpolating between optimal transport
and MMD using Sinkhorn divergences. In The 22nd International Conference
on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

.. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
G., & Weed, J. (2019, April). Statistical optimal transport via factored
couplings. In The 22nd International Conference on Artificial
Intelligence and Statistics (pp. 2454-2465). PMLR.

.. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
Unbalanced optimal transport through non-negative penalized
linear regression. NeurIPS.

.. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
Low-rank Sinkhorn Factorization. In International Conference on
Machine Learning.

"""

if method is not None and method.lower() in lst_method_lazy:
lazy0 = lazy
lazy = True

if not lazy:  # default non lazy solver calls ot.solve

# compute cost matrix M and use solve function
M = dist(X_a, X_b, metric)

res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose, grad)

return res

else:

# Detect backend
nx = get_backend(X_a, X_b, a, b)

# default values for solutions
potentials = None
value = None
value_linear = None
plan = None
lazy_plan = None
status = None
log = None

method = method.lower() if method is not None else ''

if method == '1d':  # Wasserstein 1d (parallel on all dimensions)
if metric == 'sqeuclidean':
p = 2
elif metric in ['euclidean', 'cityblock']:
p = 1
else:
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

value = wasserstein_1d(X_a, X_b, a, b, p=p)
value_linear = value

elif method == 'gaussian':  # Gaussian Bures-Wasserstein

if not metric.lower() in ['sqeuclidean']:
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

if reg is None:
reg = 1e-6

value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True)
value = value**2  # return the value (squared bures distance)
value_linear = value  # return the value

elif method == 'factored':  # Factored OT

if not metric.lower() in ['sqeuclidean']:
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

if max_iter is None:
max_iter = 100
if tol is None:
tol = 1e-7
if reg is None:
reg = 0

Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose)
log['X'] = X

value_linear = log['costa'] + log['costb']
value = value_linear  # TODO add reg term
lazy_plan = log['lazy_plan']
if not lazy0:  # store plan if not lazy
plan = lazy_plan[:]

elif method == "lowrank":

if not metric.lower() in ['sqeuclidean']:
raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

if max_iter is None:
max_iter = 2000
if tol is None:
tol = 1e-7
if reg is None:
reg = 0

Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
value = log['value']
value_linear = log['value_linear']
lazy_plan = log['lazy_plan']
if not lazy0:  # store plan if not lazy
plan = lazy_plan[:]

elif method.startswith('geomloss'):  # Geomloss solver for entropic OT

split_method = method.split('_')
if len(split_method) == 2:
backend = split_method[1]
else:
if lazy0 is None:
backend = 'auto'
elif lazy0:
backend = 'online'
else:
backend = 'tensorized'

value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend)

lazy_plan = log['lazy_plan']
if not lazy0:  # store plan if not lazy
plan = lazy_plan[:]

# return scaled potentials (to be consistent with other solvers)
potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2))

elif reg is None or reg == 0:  # exact OT

if unbalanced is None:  # balanced EMD solver not available for lazy
raise (NotImplementedError('Exact OT solver with lazy=True not implemented'))

else:
raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type)))

else:
if unbalanced is None:

if max_iter is None:
max_iter = 1000
if tol is None:
tol = 1e-9
if batch_size is None:
batch_size = 100

value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol,
isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
# compute potentials
potentials = (log["u"], log["v"])
lazy_plan = log['lazy_plan']

else:
raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan,
value_linear=value_linear, plan=plan, status=status, backend=nx, log=log)
return res