# -*- coding: utf-8 -*-
"""
General OT solvers with unified API
"""
# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License
from .utils import OTResult, dist
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .smooth import smooth_ot_dual
from .gromov import (
gromov_wasserstein2,
fused_gromov_wasserstein2,
entropic_gromov_wasserstein2,
entropic_fused_gromov_wasserstein2,
semirelaxed_gromov_wasserstein2,
semirelaxed_fused_gromov_wasserstein2,
entropic_semirelaxed_fused_gromov_wasserstein2,
entropic_semirelaxed_gromov_wasserstein2,
partial_gromov_wasserstein2,
entropic_partial_gromov_wasserstein2,
)
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn
from .optim import cg
# Method names grouped under the "lazy" label. The consumer of this list is not
# visible in this part of the file (presumably ``solve_sample`` -- confirm).
# The GeomLoss entries are the bare "geomloss" name plus one name per variant.
lst_method_lazy = [
    "1d",
    "gaussian",
    "lowrank",
    "factored",
    "geomloss",
    *(
        "geomloss_" + variant
        for variant in ("auto", "tensorized", "online", "multiscale")
    ),
]
[docs]
def solve(
    M,
    a=None,
    b=None,
    reg=None,
    c=None,
    reg_type="KL",
    unbalanced=None,
    unbalanced_type="KL",
    method=None,
    n_threads=1,
    max_iter=None,
    plan_init=None,
    potentials_init=None,
    tol=None,
    verbose=False,
    grad="autodiff",
):
    r"""Solve the discrete optimal transport problem and return :any:`OTResult` object

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`(\lambda_1, \lambda_2)`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    M : array_like, shape (dim_a, dim_b)
        Loss matrix
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT). ``reg=None`` is treated as ``reg=0``.
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
        In this function `c` is only forwarded to the unbalanced solvers.
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy",
        by default "KL". A tuple of functions ``(f, df)`` can be provided for
        the general solver (see :any:`cg`). This is only used when ``reg!=None``.
        String values are matched case-insensitively.
    unbalanced : float or indexable object of length 1 or 2, optional
        Marginal relaxation term.
        If it is a scalar or an indexable object of length 1,
        then the same relaxation is applied to both marginal relaxations.
        The balanced OT can be recovered using :math:`unbalanced=float("inf")`.
        For semi-relaxed case, use either
        :math:`unbalanced=(float("inf"), scalar)` or
        :math:`unbalanced=(scalar, float("inf"))`.
        If unbalanced is an array,
        it must have the same backend as input arrays `(a, b, M)`.
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2",
        "TV", by default "KL". Matched case-insensitively.
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection. In this function it is only
        forwarded to the KL-regularized KL-unbalanced Sinkhorn solver.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each solvers)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None.
        In this function it is forwarded (as ``G0``) to the general (``cg``),
        MM unbalanced and L-BFGS-B unbalanced solvers only.
    potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by
        default None. In this function it is forwarded (as ``warmstart``) to the
        KL-regularized KL-unbalanced Sinkhorn solver only.
    tol : float, optional
        Tolerance for solution precision, by default None (default values in each solvers)
    verbose : bool, optional
        Print information in the solver, by default False
    grad : str, optional
        Type of gradient computation, either 'autodiff' or 'envelope', used only for
        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
        outputs (`plan, value, value_linear`) but with important memory cost.
        'envelope' provides gradients only for `value` and other outputs are
        detached. This is useful for memory saving when only the value is needed.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials (when the selected solver provides them)
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.status : only set by the exact balanced solver

        See :any:`OTResult` for more information.

    Raises
    ------
    NotImplementedError
        If the combination of `reg_type` and `unbalanced_type` is not supported.

    Notes
    -----

    The following methods are available for solving the OT problems:

    - **Classical exact OT problem [1]** (default parameters) :

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M, a, b)

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` regularization (``reg_type="KL"``)
        res = ot.solve(M, a, b, reg=1.0)
        # or for original Sinkhorn paper formulation [2]
        res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')

        # Use envelope theorem differentiation for memory saving
        res = ot.solve(M, a, b, reg=1.0, grad='envelope') # M, a, b are torch tensors
        res.value.backward() # only the value is differentiable

    Note that by default the Sinkhorn solver uses automatic differentiation to
    compute the gradients of the values and plan. This can be changed with the
    `grad` parameter. The `envelope` mode computes the gradients only
    for the value and the other outputs are detached. This is useful for
    memory saving when only the gradient of value is needed.

    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M,a,b,reg=1.0,reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} +
        \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"``
        res = ot.solve(M,a,b,unbalanced=1.0)
        # quadratic unbalanced OT
        res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve(M,a,b,unbalanced=1.0,unbalanced_type='TV')

    - **Regularized unbalanced OT [34]** (when ``unbalanced!=None`` and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_1 U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_2 U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` for both
        res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve(M,a,b,reg=1.0,unbalanced=1.0,unbalanced_type='L2')
        # both quadratic
        res = ot.solve(M,a,b,reg=1.0, reg_type='L2',unbalanced=1.0,unbalanced_type='L2')


    .. _references-solve:
    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
        (2011, December). Displacement interpolation using Lagrangian mass
        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
        158). ACM.

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.

    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport. Proceedings of the Twenty-First International
        Conference on Artificial Intelligence and Statistics (AISTATS).

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    """

    # detect backend
    nx = get_backend(M, a, b, c)

    # create uniform weights if not given
    if a is None:
        a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
    if b is None:
        b = nx.ones(M.shape[1], type_as=M) / M.shape[1]
    if c is None:
        c = a[:, None] * b[None, :]

    if reg is None:
        reg = 0

    # default values for solutions
    potentials = None
    value = None
    value_linear = None
    plan = None
    status = None

    if reg == 0:  # exact OT
        if unbalanced is None:  # Exact balanced OT
            # default values for EMD solver
            if max_iter is None:
                max_iter = 1000000

            value_linear, log = emd2(
                a,
                b,
                M,
                numItermax=max_iter,
                log=True,
                return_matrix=True,
                numThreads=n_threads,
            )

            value = value_linear
            potentials = (log["u"], log["v"])
            plan = log["G"]
            status = log["warning"] if log["warning"] is not None else "Converged"

        elif unbalanced_type.lower() in ["kl", "l2"]:  # unbalanced exact OT
            # default values for exact unbalanced OT
            if max_iter is None:
                max_iter = 1000
            if tol is None:
                tol = 1e-12

            plan, log = mm_unbalanced(
                a,
                b,
                M,
                reg_m=unbalanced,
                c=c,
                reg=reg,
                # forward the normalized name: the branch test above is
                # case-insensitive, so the solver must see the same spelling
                div=unbalanced_type.lower(),
                numItermax=max_iter,
                stopThr=tol,
                log=True,
                verbose=verbose,
                G0=plan_init,
            )

            value_linear = log["cost"]
            value = log["total_cost"]

        elif unbalanced_type.lower() == "tv":
            if max_iter is None:
                max_iter = 1000
            if tol is None:
                tol = 1e-12
            # reg_type may be a tuple of functions, only strings are normalized
            if isinstance(reg_type, str):
                reg_type = reg_type.lower()

            plan, log = lbfgsb_unbalanced(
                a,
                b,
                M,
                reg=reg,
                reg_m=unbalanced,
                c=c,
                reg_div=reg_type,
                regm_div=unbalanced_type.lower(),
                numItermax=max_iter,
                stopThr=tol,
                verbose=verbose,
                log=True,
                G0=plan_init,
            )

            value_linear = log["cost"]
            value = log["total_cost"]

        else:
            raise NotImplementedError(
                'Unknown unbalanced_type="{}"'.format(unbalanced_type)
            )

    else:  # regularized OT
        if unbalanced is None:  # Balanced regularized OT
            if isinstance(reg_type, tuple):  # general solver
                f, df = reg_type

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = cg(
                    a,
                    b,
                    M,
                    reg=reg,
                    f=f,
                    df=df,
                    numItermax=max_iter,
                    stopThr=tol,
                    log=True,
                    verbose=verbose,
                    G0=plan_init,
                )

                value_linear = nx.sum(M * plan)
                value = log["loss"][-1]
                potentials = (log["u"], log["v"])

            elif reg_type.lower() in ["entropy", "kl"]:
                if grad == "envelope":  # if envelope then detach the input
                    # keep the attached inputs to set their gradients afterwards
                    M0, a0, b0 = M, a, b
                    M, a, b = nx.detach(M, a, b)

                # default values for sinkhorn
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(
                    a,
                    b,
                    M,
                    reg=reg,
                    numItermax=max_iter,
                    stopThr=tol,
                    log=True,
                    verbose=verbose,
                )

                value_linear = nx.sum(M * plan)

                if reg_type.lower() == "entropy":
                    # small constant avoids log(0) on empty plan entries
                    value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                else:
                    value = value_linear + reg * nx.kl_div(
                        plan, a[:, None] * b[None, :]
                    )

                potentials = (log["log_u"], log["log_v"])

                if grad == "envelope":  # set the gradient at convergence
                    value = nx.set_gradients(
                        value,
                        (M0, a0, b0),
                        (
                            plan,
                            reg * (potentials[0] - potentials[0].mean()),
                            reg * (potentials[1] - potentials[1].mean()),
                        ),
                    )

            elif reg_type.lower() == "l2":
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = smooth_ot_dual(
                    a,
                    b,
                    M,
                    reg=reg,
                    numItermax=max_iter,
                    stopThr=tol,
                    log=True,
                    verbose=verbose,
                )

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan**2)
                potentials = (log["alpha"], log["beta"])

            else:
                raise NotImplementedError(
                    'Not implemented reg_type="{}"'.format(reg_type)
                )

        else:  # unbalanced AND regularized OT
            if (
                not isinstance(reg_type, tuple)
                and reg_type.lower() in ["kl"]
                and unbalanced_type.lower() == "kl"
            ):
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_knopp_unbalanced(
                    a,
                    b,
                    M,
                    reg=reg,
                    reg_m=unbalanced,
                    method=method,
                    reg_type=reg_type.lower(),
                    c=c,
                    warmstart=potentials_init,
                    numItermax=max_iter,
                    stopThr=tol,
                    verbose=verbose,
                    log=True,
                )

                value_linear = log["cost"]
                value = log["total_cost"]
                potentials = (log["logu"], log["logv"])

            elif (
                isinstance(reg_type, tuple)
                or reg_type.lower() in ["kl", "l2", "entropy"]
            ) and unbalanced_type.lower() in ["kl", "l2", "tv"]:
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-12
                # reg_type may be a tuple of functions, only strings are normalized
                if isinstance(reg_type, str):
                    reg_type = reg_type.lower()

                plan, log = lbfgsb_unbalanced(
                    a,
                    b,
                    M,
                    reg=reg,
                    reg_m=unbalanced,
                    c=c,
                    reg_div=reg_type,
                    regm_div=unbalanced_type.lower(),
                    numItermax=max_iter,
                    stopThr=tol,
                    verbose=verbose,
                    log=True,
                    G0=plan_init,
                )

                value_linear = log["cost"]
                value = log["total_cost"]

            else:
                raise NotImplementedError(
                    'Not implemented reg_type="{}" and unbalanced_type="{}"'.format(
                        reg_type, unbalanced_type
                    )
                )

    res = OTResult(
        potentials=potentials,
        value=value,
        value_linear=value_linear,
        plan=plan,
        status=status,
        backend=nx,
    )

    return res
[docs]
def solve_gromov(
    Ca,
    Cb,
    M=None,
    a=None,
    b=None,
    loss="L2",
    symmetric=None,
    alpha=0.5,
    reg=None,
    reg_type="entropy",
    unbalanced=None,
    unbalanced_type="KL",
    n_threads=1,
    method=None,
    max_iter=None,
    plan_init=None,
    tol=None,
    verbose=False,
):
    r"""Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object

    The function solves the following optimization problem:

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and
    `reg_type`. By default ``reg=None`` and there is no regularization. The
    unbalanced marginal penalization can be selected with `unbalanced`
    (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None``
    and the function solves the exact optimal transport problem (respecting the
    marginals).

    Parameters
    ----------
    Ca : array_like, shape (dim_a, dim_a)
        Cost matrix in the source domain
    Cb : array_like, shape (dim_b, dim_b)
        Cost matrix in the target domain
    M : array_like, shape (dim_a, dim_b), optional
        Linear cost matrix for Fused Gromov-Wasserstein (default is None).
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    loss : str, optional
        Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"``
    symmetric : bool, optional
        Use symmetric version of the Gromov-Wasserstein problem, by default None
        tests whether the matrices are symmetric or True/False to avoid the test.
    alpha : float, optional
        Weight the quadratic term (alpha*Gromov) and the linear term
        ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem, by default
        ``alpha=0.5``. Not used for the Gromov problem (when M is not
        provided). With M provided, ``alpha=1`` solves the pure Gromov problem
        and, for balanced problems, ``alpha=0`` solves the pure Wasserstein
        problem.
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R`, by default "entropy" (only used when
        ``reg!=None``)
    unbalanced : float, optional
        Unbalanced parameter, by default None. With
        ``unbalanced_type='partial'`` it is the transported mass :math:`m`.
        It is not needed for ``unbalanced_type='semirelaxed'``.
    unbalanced_type : str, optional
        Type of unbalanced problem, either "KL", "L2", "semirelaxed" or
        "partial", by default "KL". Only "semirelaxed" and "partial" are
        implemented: "KL" and "L2" raise ``NotImplementedError`` when
        `unbalanced` is given. "partial" is only used when `unbalanced` is
        given.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection (``"PGD"`` for the balanced
        entropic (F)GW solvers, the only ones receiving it here).
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each
        solvers)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in
        each solvers)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Raises
    ------
    NotImplementedError
        If `loss`, `reg_type` or `unbalanced_type` selects a problem that is
        not implemented (e.g. partial FGW, KL/L2 unbalanced GW).
    ValueError
        If the partial mass `unbalanced` exceeds the total mass of `a` or `b`.

    Notes
    -----
    The following methods are available for solving the Gromov-Wasserstein
    problem:

    - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters):

    .. math::
        \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb) # uniform weights
        res = ot.solve_gromov(Ca, Cb, a=a, b=b) # given weights
        res = ot.solve_gromov(Ca, Cb, loss='KL') # KL loss

        plan = res.plan # GW plan
        value = res.value # GW value

    - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, M) # uniform weights, alpha=0.5 (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1) # given weights and alpha

        plan = res.plan # FGW plan
        loss_linear_term = res.value_linear # Wasserstein part of the loss
        loss_quad_term = res.value_quad # Gromov part of the loss
        loss = res.value # FGW value

    - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} + \lambda_r R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, reg=1.0) # GW entropy regularization (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1) # FGW with entropy

        plan = res.plan # FGW plan
        loss_linear_term = res.value_linear # Wasserstein part of the loss
        loss_quad_term = res.value_quad # Gromov part of the loss
        loss = res.value # FGW value (including regularization)

    - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced_type='semirelaxed'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed') # semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed', reg=1) # entropic semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='semirelaxed', alpha=0.1) # semirelaxed FGW

        plan = res.plan # FGW plan
        right_marginal = res.marginal_b # right marginal of the plan

    - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced_type='partial'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a}

             \mathbf{T}^T \mathbf{1} \leq \mathbf{b}

             \mathbf{T} \geq 0

             \mathbf{1}^T\mathbf{T}\mathbf{1} = m

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8) # partial GW with m=0.8


    .. _references-solve-gromov:
    References
    ----------

    .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric
        approach to object matching. Foundations of computational mathematics,
        11(4), 417-487.

    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
        Gromov-Wasserstein averaging of kernel and distance matrices
        International Conference on Machine Learning (ICML).

    .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
        (2019). Optimal Transport for structured data with application on graphs
        Proceedings of the 36th International Conference on Machine Learning
        (ICML).

    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer,
        Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and
        applications on graphs. International Conference on Learning
        Representations (ICLR), 2022.

    .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport
        with Applications on Positive-Unlabeled Learning, Advances in Neural
        Information Processing Systems (NeurIPS), 2020.

    """

    # detect backend
    nx = get_backend(Ca, Cb, M, a, b)

    # create uniform weights if not given
    if a is None:
        a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0]
    if b is None:
        b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1]

    # default values for solutions
    potentials = None
    value = None
    value_linear = None
    value_quad = None
    plan = None
    status = None
    log = None

    # map the public loss names to the names used by the GW solvers
    loss_dict = {"l2": "square_loss", "kl": "kl_loss"}

    if loss.lower() not in loss_dict.keys():
        raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss)))
    loss_fun = loss_dict[loss.lower()]

    if reg is None or reg == 0:  # exact OT
        if unbalanced is None and unbalanced_type.lower() not in [
            "semirelaxed"
        ]:  # Exact balanced OT
            if M is None or alpha == 1:  # Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    b,
                    loss_fun=loss_fun,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log["T"]
                potentials = (log["u"], log["v"])

            elif alpha == 0:  # Wasserstein problem
                # default values for EMD solver
                if max_iter is None:
                    max_iter = 1000000

                value_linear, log = emd2(
                    a,
                    b,
                    M,
                    numItermax=max_iter,
                    log=True,
                    return_matrix=True,
                    numThreads=n_threads,
                )

                value = value_linear
                potentials = (log["u"], log["v"])
                plan = log["G"]
                status = log["warning"] if log["warning"] is not None else "Converged"
                value_quad = 0

            else:  # Fused Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = fused_gromov_wasserstein2(
                    M,
                    Ca,
                    Cb,
                    a,
                    b,
                    loss_fun=loss_fun,
                    alpha=alpha,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                value_linear = log["lin_loss"]
                value_quad = log["quad_loss"]
                plan = log["T"]
                potentials = (log["u"], log["v"])

        elif unbalanced_type.lower() in ["semirelaxed"]:  # Semi-relaxed OT
            if M is None or alpha == 1:  # Semi relaxed Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    loss_fun=loss_fun,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log["T"]
                # potentials = (log['u'], log['v']) TODO

            else:  # Semi relaxed Fused Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_fused_gromov_wasserstein2(
                    M,
                    Ca,
                    Cb,
                    a,
                    loss_fun=loss_fun,
                    alpha=alpha,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                value_linear = log["lin_loss"]
                value_quad = log["quad_loss"]
                plan = log["T"]
                # potentials = (log['u'], log['v']) TODO

        elif unbalanced_type.lower() in ["partial"]:  # Partial OT
            if M is None:  # Partial Gromov-Wasserstein problem
                # `unbalanced` is the mass to transport
                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError("Partial GW mass given in reg is too large"))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value, log = partial_gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    b,
                    m=unbalanced,
                    loss_fun=loss_fun,
                    log=True,
                    numItermax=max_iter,
                    G0=plan_init,
                    tol=tol,
                    symmetric=symmetric,
                    verbose=verbose,
                )

                value_quad = value
                plan = log["T"]
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW
                raise (NotImplementedError("Partial FGW not implemented yet"))

        elif unbalanced_type.lower() in ["kl", "l2"]:  # unbalanced exact OT
            raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type)))

        else:
            raise (
                NotImplementedError(
                    'Unknown unbalanced_type="{}"'.format(unbalanced_type)
                )
            )

    else:  # regularized OT
        if unbalanced is None and unbalanced_type.lower() not in [
            "semirelaxed"
        ]:  # Balanced regularized OT
            if reg_type.lower() in ["entropy"] and (
                M is None or alpha == 1
            ):  # Entropic Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = "PGD"

                value_quad, log = entropic_gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    b,
                    epsilon=reg,
                    loss_fun=loss_fun,
                    log=True,
                    symmetric=symmetric,
                    solver=method,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                plan = log["T"]
                value_linear = 0
                # small constant avoids log(0) on empty plan entries
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
                # potentials = (log['log_u'], log['log_v']) #TODO

            elif (
                reg_type.lower() in ["entropy"] and M is not None and alpha == 0
            ):  # Entropic Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(
                    a,
                    b,
                    M,
                    reg=reg,
                    numItermax=max_iter,
                    stopThr=tol,
                    log=True,
                    verbose=verbose,
                )

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                potentials = (log["log_u"], log["log_v"])

            elif (
                reg_type.lower() in ["entropy"] and M is not None
            ):  # Entropic Fused Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = "PGD"

                value_noreg, log = entropic_fused_gromov_wasserstein2(
                    M,
                    Ca,
                    Cb,
                    a,
                    b,
                    # the solver must use the requested weight, the same one
                    # used for the entropy term added to `value` below
                    epsilon=reg,
                    loss_fun=loss_fun,
                    alpha=alpha,
                    log=True,
                    symmetric=symmetric,
                    solver=method,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol_rel=tol,
                    tol_abs=tol,
                    verbose=verbose,
                )

                value_linear = log["lin_loss"]
                value_quad = log["quad_loss"]
                plan = log["T"]
                # potentials = (log['u'], log['v'])
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

            else:
                raise (
                    NotImplementedError(
                        'Not implemented reg_type="{}"'.format(reg_type)
                    )
                )

        elif unbalanced_type.lower() in ["semirelaxed"]:  # Semi-relaxed OT
            if reg_type.lower() in ["entropy"] and (
                M is None or alpha == 1
            ):  # Entropic Semi-relaxed Gromov-Wasserstein problem
                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_quad, log = entropic_semirelaxed_gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    epsilon=reg,
                    loss_fun=loss_fun,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol=tol,
                    verbose=verbose,
                )

                plan = log["T"]
                value_linear = 0
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))

            else:  # Entropic Semi-relaxed FGW problem
                # NOTE(review): this branch is also reached when reg_type is
                # not "entropy" (reg_type is not checked here).
                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(
                    M,
                    Ca,
                    Cb,
                    a,
                    # the solver must use the requested weight, the same one
                    # used for the entropy term added to `value` below
                    epsilon=reg,
                    loss_fun=loss_fun,
                    alpha=alpha,
                    log=True,
                    symmetric=symmetric,
                    max_iter=max_iter,
                    G0=plan_init,
                    tol=tol,
                    verbose=verbose,
                )

                value_linear = log["lin_loss"]
                value_quad = log["quad_loss"]
                plan = log["T"]
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

        elif unbalanced_type.lower() in ["partial"]:  # Partial OT
            if M is None:  # Partial Gromov-Wasserstein problem
                # `unbalanced` is the mass to transport
                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError("Partial GW mass given in reg is too large"))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value_quad, log = entropic_partial_gromov_wasserstein2(
                    Ca,
                    Cb,
                    a,
                    b,
                    reg=reg,
                    loss_fun=loss_fun,
                    m=unbalanced,
                    log=True,
                    numItermax=max_iter,
                    G0=plan_init,
                    tol=tol,
                    symmetric=symmetric,
                    verbose=verbose,
                )

                plan = log["T"]
                # The solver output is kept as the quadratic (GW) term; the
                # total value adds the entropic term, as in the other entropic
                # branches of this function.
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW
                raise (NotImplementedError("Partial entropic FGW not implemented yet"))

        else:  # unbalanced AND regularized OT
            raise (
                NotImplementedError(
                    'Not implemented reg_type="{}" and unbalanced_type="{}"'.format(
                        reg_type, unbalanced_type
                    )
                )
            )

    res = OTResult(
        potentials=potentials,
        value=value,
        value_linear=value_linear,
        value_quad=value_quad,
        plan=plan,
        status=status,
        backend=nx,
        log=log,
    )

    return res
[docs]
def solve_sample(
    X_a,
    X_b,
    a=None,
    b=None,
    metric="sqeuclidean",
    reg=None,
    c=None,
    reg_type="KL",
    unbalanced=None,
    unbalanced_type="KL",
    lazy=False,
    batch_size=None,
    method=None,
    n_threads=1,
    max_iter=None,
    plan_init=None,
    rank=100,
    scaling=0.95,
    potentials_init=None,
    X_init=None,
    tol=None,
    verbose=False,
    grad="autodiff",
):
    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    where the cost matrix :math:`\mathbf{M}` is computed from the samples in the
    source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where
    :math:`d` is a metric (by default the squared Euclidean distance).

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    X_a : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_b : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    metric : str, optional
        Metric :math:`d` used to compute the ground cost between samples, by
        default "sqeuclidean". The lazy methods 'gaussian', 'factored' and
        'lowrank' only support "sqeuclidean" and method '1d' only supports
        "sqeuclidean", "euclidean" and "cityblock" (case insensitive).
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    c : array-like (dim_a, dim_b), optional (default=None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        If :math:`\texttt{reg_type}='entropy'`, then :math:`\mathbf{c} = 1_{dim_a} 1_{dim_b}^T`.
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float or indexable object of length 1 or 2
        Marginal relaxation term.
        If it is a scalar or an indexable object of length 1,
        then the same relaxation is applied to both marginal relaxations.
        The balanced OT can be recovered using ``unbalanced=float("inf")``.
        For semi-relaxed case, use either
        ``unbalanced=(float("inf"), scalar)`` or
        ``unbalanced=(scalar, float("inf"))``.
        If unbalanced is an array,
        it must have the same backend as the input arrays.
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    lazy : bool, optional
        Return an :any:`OTResult` object holding a lazy OT plan
        (``res.lazy_plan``) to reduce memory cost when True, by default False
    batch_size : int, optional
        Batch size for lazy solver, by default None (default values in each
        solvers)
    method : str, optional
        Method for solving the problem, this can be used to select the solver
        for unbalanced problems (see :any:`ot.solve`), or to select a specific
        large scale solver.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iteration, by default None (default values in each solvers)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    rank : int, optional
        Rank of the OT matrix for lazy solvers (method='factored' or
        method='lowrank'), by default 100
    scaling : float, optional
        Scaling factor for the epsilon scaling lazy solvers (method='geomloss'), by default 0.95
    potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    X_init : optional
        Currently not used by this function, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in each solvers)
    verbose : bool, optional
        Print information in the solver, by default False
    grad : str, optional
        Type of gradient computation, either 'autodiff' or 'envelope', used only for
        Sinkhorn solver. By default 'autodiff' provides gradients wrt all
        outputs (`plan, value, value_linear`) but with important memory cost.
        'envelope' provides gradients only for `value` and other outputs are
        detached. This is useful for memory saving when only the value is needed.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method)

        See :any:`OTResult` for more information.

    Raises
    ------
    NotImplementedError
        If the requested combination of `method`, `metric`, `reg` and
        `unbalanced` is not supported by the lazy solvers.

    Notes
    -----

    The following methods are available for solving the OT problems:

    - **Classical exact OT problem [1]** (default parameters) :

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_sample(xa, xb, a, b)

        # for uniform weights
        res = ot.solve_sample(xa, xb)

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` regularization (``reg_type="KL"``)
        res = ot.solve_sample(xa, xb, a, b, reg=1.0)
        # or for original Sinkhorn paper formulation [2]
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy')

        # lazy solver of memory complexity O(n)
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100)
        # lazy OT plan
        lazy_plan = res.lazy_plan

        # Use envelope theorem differentiation for memory saving
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, grad='envelope')
        res.value.backward() # only the value is differentiable

    Note that by default the Sinkhorn solver uses automatic differentiation to
    compute the gradients of the values and plan. This can be changed with the
    `grad` parameter. The `envelope` mode computes the gradients only
    for the value and the other outputs are detached. This is useful for
    memory saving when only the gradient of value is needed.

    We also have a very efficient solver with compiled CPU/CUDA code using
    geomloss/PyKeOps that can be used with the following code:

    .. code-block:: python

        # automatic solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss')

        # force O(n) memory efficient solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online')

        # force pre-computed cost matrix
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized')

        # use multiscale solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale')

        # One can play with speed (small scaling factor) and precision (scaling close to 1)
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5)

    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

        \text{with} \ M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"``
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0)
        # quadratic unbalanced OT
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0,unbalanced_type='TV')

    - **Regularized unbalanced OT [34]** (when ``unbalanced!=None`` and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

        \text{with} \ M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` for both
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0,unbalanced_type='L2')
        # both quadratic
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2',
        unbalanced=1.0, unbalanced_type='L2')

    - **Factored OT [40]** (when ``method='factored'``):

    This method solves the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad  W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :math:`\mu` is a uniformly weighted empirical distribution,
    :math:`\mu_a` and :math:`\mu_b` are the empirical measures associated
    to the samples in the source and target domains, and :math:`W_2` is the
    Wasserstein distance. This problem is solved using exact OT solvers for
    `reg=None` and the Sinkhorn solver for `reg!=None`. The solution provides
    two transport plans that can be used to recover a low rank OT plan between
    the two distributions.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='factored', rank=10)

        # recover the lazy low rank plan
        factored_solution_lazy = res.lazy_plan

        # recover the full low rank plan
        factored_solution = factored_solution_lazy[:]

    - **Low rank Sinkhorn [65]** (when ``method='lowrank'``):

    This method calls :any:`ot.lowrank.lowrank_sinkhorn` with the requested
    `rank` and `reg` and returns the resulting plan as a lazy plan.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, reg=1.0, method='lowrank', rank=10)

        # recover the lazy low rank plan
        lowrank_solution_lazy = res.lazy_plan

    - **Gaussian Bures-Wasserstein [2]** (when ``method='gaussian'``):

    This method computes the Gaussian Bures-Wasserstein distance between two
    Gaussian distributions estimated from the empirical distributions

    .. math::
        \mathcal{W}(\mu_s, \mu_t)_2^2= \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}

    where :

    .. math::
        \mathbf{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

    The covariances and means are estimated from the data.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='gaussian')

        # recover the squared Gaussian Bures-Wasserstein distance
        BW_dist = res.value

    - **Wasserstein 1d [1]** (when ``method='1D'``):

    This method computes the Wasserstein distance between two 1d distributions
    estimated from the empirical distributions. For multivariate data the
    distances are computed independently for each dimension.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='1D')

        # recover the squared Wasserstein distances
        W_dists = res.value

    .. _references-solve-sample:
    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
        (2011, December).  Displacement interpolation using Lagrangian mass
        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p.
        158). ACM.

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.

    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport. Proceedings of the Twenty-First International
        Conference on Artificial Intelligence and Statistics (AISTATS).

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021).
        Low-rank Sinkhorn Factorization. In International Conference on
        Machine Learning.

    """

    # Keep the caller's original ``lazy`` request: lazy-only methods force the
    # lazy code path below but still materialize the dense plan when the
    # caller did not ask for a lazy result. Binding it unconditionally avoids
    # an unbound name in branches reached with a method outside
    # ``lst_method_lazy`` (e.g. an unknown ``geomloss_*`` variant).
    lazy0 = lazy
    if method is not None and method.lower() in lst_method_lazy:
        lazy = True

    if not lazy:
        # Default dense path: build the cost matrix and delegate to ``solve``.
        M = dist(X_a, X_b, metric)
        return solve(
            M,
            a=a,
            b=b,
            reg=reg,
            c=c,
            reg_type=reg_type,
            unbalanced=unbalanced,
            unbalanced_type=unbalanced_type,
            method=method,
            n_threads=n_threads,
            max_iter=max_iter,
            plan_init=plan_init,
            potentials_init=potentials_init,
            tol=tol,
            verbose=verbose,
            grad=grad,
        )

    # Detect backend
    nx = get_backend(X_a, X_b, a, b)

    # default values for solutions
    potentials = None
    value = None
    value_linear = None
    plan = None
    lazy_plan = None
    status = None
    log = None

    method = method.lower() if method is not None else ""
    # Case-insensitive metric name for the support checks below; a non-string
    # metric is kept as is so that it is rejected with NotImplementedError
    # instead of failing on ``.lower()``.
    metric_name = metric.lower() if isinstance(metric, str) else metric

    if method == "1d":  # Wasserstein 1d (parallel on all dimensions)
        if metric_name == "sqeuclidean":
            p = 2
        elif metric_name in ("euclidean", "cityblock"):
            p = 1
        else:
            raise NotImplementedError('Not implemented metric="{}"'.format(metric))

        value = wasserstein_1d(X_a, X_b, a, b, p=p)
        value_linear = value

    elif method == "gaussian":  # Gaussian Bures-Wasserstein
        if metric_name != "sqeuclidean":
            raise NotImplementedError('Not implemented metric="{}"'.format(metric))

        if reg is None:
            reg = 1e-6

        value, log = empirical_bures_wasserstein_distance(
            X_a, X_b, reg=reg, log=True
        )
        value = value**2  # return the value (squared bures distance)
        value_linear = value  # return the value

    elif method == "factored":  # Factored OT
        if metric_name != "sqeuclidean":
            raise NotImplementedError('Not implemented metric="{}"'.format(metric))

        # default values for solver
        if max_iter is None:
            max_iter = 100
        if tol is None:
            tol = 1e-7
        if reg is None:
            reg = 0

        _, _, X, log = factored_optimal_transport(
            X_a,
            X_b,
            reg=reg,
            r=rank,
            log=True,
            stopThr=tol,
            numItermax=max_iter,
            verbose=verbose,
        )
        log["X"] = X

        value_linear = log["costa"] + log["costb"]
        value = value_linear  # TODO add reg term
        lazy_plan = log["lazy_plan"]
        if not lazy0:  # store plan if not lazy
            plan = lazy_plan[:]

    elif method == "lowrank":  # Low rank Sinkhorn
        if metric_name != "sqeuclidean":
            raise NotImplementedError('Not implemented metric="{}"'.format(metric))

        # default values for solver
        if max_iter is None:
            max_iter = 2000
        if tol is None:
            tol = 1e-7
        if reg is None:
            reg = 0

        _, _, _, log = lowrank_sinkhorn(
            X_a,
            X_b,
            rank=rank,
            reg=reg,
            a=a,
            b=b,
            numItermax=max_iter,
            stopThr=tol,
            log=True,
        )
        value = log["value"]
        value_linear = log["value_linear"]
        lazy_plan = log["lazy_plan"]
        if not lazy0:  # store plan if not lazy
            plan = lazy_plan[:]

    elif method.startswith("geomloss"):  # Geomloss solver for entropic OT
        split_method = method.split("_")
        if len(split_method) == 2:
            # explicit backend requested, e.g. "geomloss_online"
            backend = split_method[1]
        elif lazy0 is None:
            backend = "auto"
        elif lazy0:
            backend = "online"
        else:
            backend = "tensorized"

        value, log = empirical_sinkhorn2_geomloss(
            X_a,
            X_b,
            reg=reg,
            a=a,
            b=b,
            metric=metric,
            log=True,
            verbose=verbose,
            scaling=scaling,
            backend=backend,
        )

        lazy_plan = log["lazy_plan"]
        if not lazy0:  # store plan if not lazy
            plan = lazy_plan[:]

        # return scaled potentials (to be consistent with other solvers)
        potentials = (
            log["f"] / (lazy_plan.blur**2),
            log["g"] / (lazy_plan.blur**2),
        )

    elif reg is None or reg == 0:  # exact OT
        if unbalanced is None:  # balanced EMD solver not available for lazy
            raise NotImplementedError(
                "Exact OT solver with lazy=True not implemented"
            )
        raise NotImplementedError(
            'Non regularized solver with unbalanced_type="{}" not implemented'.format(
                unbalanced_type
            )
        )

    elif unbalanced is not None:  # regularized unbalanced OT has no lazy solver
        raise NotImplementedError(
            'Not implemented unbalanced_type="{}" with regularization'.format(
                unbalanced_type
            )
        )

    else:  # lazy (batched) Sinkhorn for balanced regularized OT
        # default values for solver
        if max_iter is None:
            max_iter = 1000
        if tol is None:
            tol = 1e-9
        if batch_size is None:
            batch_size = 100

        value_linear, log = empirical_sinkhorn2(
            X_a,
            X_b,
            reg,
            a,
            b,
            metric=metric,
            numIterMax=max_iter,
            stopThr=tol,
            isLazy=True,
            batchSize=batch_size,
            verbose=verbose,
            log=True,
        )
        # compute potentials
        potentials = (log["u"], log["v"])
        lazy_plan = log["lazy_plan"]

    res = OTResult(
        potentials=potentials,
        value=value,
        lazy_plan=lazy_plan,
        value_linear=value_linear,
        plan=plan,
        status=status,
        backend=nx,
        log=log,
    )
    return res