Source code for ot.solvers

# -*- coding: utf-8 -*-
"""
General OT solvers with unified API
"""

# Author: Remi Flamary <remi.flamary@polytechnique.edu>
#
# License: MIT License

from .utils import OTResult, dist
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
                     entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2,
                     semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2,
                     entropic_semirelaxed_fused_gromov_wasserstein2,
                     entropic_semirelaxed_gromov_wasserstein2)
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn
from .optim import cg

# Names of the methods whose solvers work lazily (implicit/low-memory OT plan).
lst_method_lazy = [
    '1d',
    'gaussian',
    'lowrank',
    'factored',
    'geomloss',
    'geomloss_auto',
    'geomloss_tensorized',
    'geomloss_online',
    'geomloss_multiscale',
]


def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
          unbalanced_type='KL', method=None, n_threads=1, max_iter=None,
          plan_init=None, potentials_init=None, tol=None, verbose=False,
          grad='autodiff'):
    r"""Solve the discrete optimal transport problem and return :any:`OTResult` object

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`.
    By default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function solves the
    exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    M : array_like, shape (dim_a, dim_b)
        Loss matrix
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT)
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default
        "KL". A tuple of functions ``(f, df)`` can be provided for the general
        solver (see :any:`cg`). This is only used when ``reg!=None``.
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2",
        "TV", by default "KL".
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection. NOTE: not used by any branch of
        this function's implementation (kept for API uniformity with the other
        ``solve_*`` functions).
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each
        solvers)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    potentials_init : (array_like(dim_a,), array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by
        default None. NOTE: not forwarded to any solver in this implementation.
    tol : float, optional
        Tolerance for solution precision, by default None (default values in
        each solvers)
    verbose : bool, optional
        Print information in the solver, by default False
    grad : str, optional
        Type of gradient computation, either 'autodiff' or 'envelope', used only
        for the Sinkhorn solver. By default 'autodiff' provides gradients wrt all
        outputs (`plan, value, value_linear`) but with important memory cost.
        'envelope' provides gradients only for `value` and other outputs are
        detached. This is useful for memory saving when only the value is needed.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as
        follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Notes
    -----
    The following methods are available for solving the OT problems:

    - **Classical exact OT problem [1]** (default parameters) :

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M, a, b)

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` regularization (``reg_type="KL"``)
        res = ot.solve(M, a, b, reg=1.0)
        # or for original Sinkhorn paper formulation [2]
        res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')

        # Use envelope theorem differentiation for memory saving
        res = ot.solve(M, a, b, reg=1.0, grad='envelope')  # M, a, b are torch tensors
        res.value.backward()  # only the value is differentiable

    Note that by default the Sinkhorn solver uses automatic differentiation to
    compute the gradients of the values and plan. This can be changed with the
    `grad` parameter. The `envelope` mode computes the gradients only for the
    value and the other outputs are detached. This is useful for memory saving
    when only the gradient of value is needed.

    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M, a, b, reg=1.0, reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"``
        res = ot.solve(M, a, b, unbalanced=1.0)
        # quadratic unbalanced OT
        res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='TV')

    - **Regularized unbalanced regularized OT [34]** (when ``unbalanced!=None``
      and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` for both
        res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
        # both quadratic
        res = ot.solve(M, a, b, reg=1.0, reg_type='L2',
                       unbalanced=1.0, unbalanced_type='L2')

    .. _references-solve:
    References
    ----------
    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
        December). Displacement interpolation using Lagrangian mass transport.
        In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26,
        2013

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport. Proceedings of the Twenty-First International
        Conference on Artificial Intelligence and Statistics (AISTATS).

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
        & Peyré, G. (2019, April). Interpolating between optimal transport and
        MMD using Sinkhorn divergences. In The 22nd International Conference on
        Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized linear
        regression. NeurIPS.
    """

    # detect backend from whichever arrays were actually provided
    arr = [M]
    if a is not None:
        arr.append(a)
    if b is not None:
        arr.append(b)
    nx = get_backend(*arr)

    # create uniform weights if not given
    if a is None:
        a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
    if b is None:
        b = nx.ones(M.shape[1], type_as=M) / M.shape[1]

    # default values for solutions (filled by whichever branch runs below)
    potentials = None
    value = None
    value_linear = None
    plan = None
    status = None

    if reg is None or reg == 0:  # exact OT

        if unbalanced is None:  # Exact balanced OT

            # default values for EMD solver
            if max_iter is None:
                max_iter = 1000000

            value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True,
                                     return_matrix=True, numThreads=n_threads)

            value = value_linear
            potentials = (log['u'], log['v'])
            plan = log['G']
            status = log["warning"] if log["warning"] is not None else 'Converged'

        elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT

            # default values for exact unbalanced OT
            if max_iter is None:
                max_iter = 1000
            if tol is None:
                tol = 1e-12

            plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
                                      div=unbalanced_type.lower(),
                                      numItermax=max_iter, stopThr=tol,
                                      log=True, verbose=verbose, G0=plan_init)

            value_linear = log['cost']

            # add the marginal penalization term to the linear cost
            if unbalanced_type.lower() == 'kl':
                value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))
            else:
                err_a = nx.sum(plan, 1) - a
                err_b = nx.sum(plan, 0) - b
                value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)

        elif unbalanced_type.lower() == 'tv':

            if max_iter is None:
                max_iter = 1000000

            # TV-unbalanced OT is solved as partial OT with a Lagrange multiplier
            plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2,
                                                     log=True, numItermax=max_iter)

            value_linear = nx.sum(M * plan)
            err_a = nx.sum(plan, 1) - a
            err_b = nx.sum(plan, 0) - b
            value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) + nx.sum(nx.abs(err_b))))

        else:
            raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

    else:  # regularized OT

        if unbalanced is None:  # Balanced regularized OT

            if isinstance(reg_type, tuple):  # general solver with user-given (f, df)

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = cg(a, b, M, reg=reg, f=reg_type[0], df=reg_type[1],
                               numItermax=max_iter, stopThr=tol, log=True,
                               verbose=verbose, G0=plan_init)

                value_linear = nx.sum(M * plan)
                value = log['loss'][-1]
                potentials = (log['u'], log['v'])

            elif reg_type.lower() in ['entropy', 'kl']:

                if grad == 'envelope':  # if envelope then detach the input
                    M0, a0, b0 = M, a, b
                    M, a, b = nx.detach(M, a, b)

                # default values for sinkhorn
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
                                         stopThr=tol, log=True, verbose=verbose)

                value_linear = nx.sum(M * plan)

                # 'entropy' uses negative entropy, 'kl' the KL to the product measure
                if reg_type.lower() == 'entropy':
                    value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                else:
                    value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])

                potentials = (log['log_u'], log['log_v'])

                if grad == 'envelope':  # set the gradient at convergence
                    value = nx.set_gradients(value, (M0, a0, b0),
                                             (plan, reg * (potentials[0] - potentials[0].mean()), reg * (potentials[1] - potentials[1].mean())))

            elif reg_type.lower() == 'l2':

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter,
                                           stopThr=tol, log=True, verbose=verbose)

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan**2)
                potentials = (log['alpha'], log['beta'])

            else:
                raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

        else:  # unbalanced AND regularized OT

            if not isinstance(reg_type, tuple) and reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg,
                                                      reg_m=unbalanced,
                                                      numItermax=max_iter,
                                                      stopThr=tol, verbose=verbose,
                                                      log=True)

                value_linear = nx.sum(M * plan)

                value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b))

                potentials = (log['logu'], log['logv'])

            elif (isinstance(reg_type, tuple) or reg_type.lower() in ['kl', 'l2', 'entropy']) and unbalanced_type.lower() in ['kl', 'l2', 'tv']:

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-12
                if isinstance(reg_type, str):
                    reg_type = reg_type.lower()

                plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced,
                                              reg_div=reg_type,
                                              regm_div=unbalanced_type.lower(),
                                              numItermax=max_iter, stopThr=tol,
                                              verbose=verbose, log=True,
                                              G0=plan_init)

                value_linear = nx.sum(M * plan)
                value = log['loss']

            else:
                raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

    res = OTResult(potentials=potentials, value=value,
                   value_linear=value_linear, plan=plan, status=status,
                   backend=nx)

    return res
def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
                 alpha=0.5, reg=None, reg_type="entropy", unbalanced=None,
                 unbalanced_type='KL', n_threads=1, method=None, max_iter=None,
                 plan_init=None, tol=None, verbose=False):
    r"""Solve the discrete (Fused) Gromov-Wasserstein and return :any:`OTResult` object

    The function solves the following optimization problem:

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j} +
        \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and
    `reg_type`. By default ``reg=None`` and there is no regularization. The
    unbalanced marginal penalization can be selected with `unbalanced`
    (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None``
    and the function solves the exact optimal transport problem (respecting the
    marginals).

    Parameters
    ----------
    Ca : array_like, shape (dim_a, dim_a)
        Cost matrix in the source domain
    Cb : array_like, shape (dim_b, dim_b)
        Cost matrix in the target domain
    M : array_like, shape (dim_a, dim_b), optional
        Linear cost matrix for Fused Gromov-Wasserstein (default is None).
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    loss : str, optional
        Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"``
    symmetric : bool, optional
        Use symmetric version of the Gromov-Wasserstein problem, by default None
        tests whether the matrices are symmetric or True/False to avoid the test.
    alpha : float, optional
        Weight of the quadratic term (alpha*Gromov) vs the linear term
        ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem, by default
        0.5. Ignored for the pure Gromov problem (when ``M is None``);
        ``alpha=1`` with ``M`` provided also solves the pure Gromov problem and
        ``alpha=0`` solves the pure Wasserstein problem.
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg.,
        exact OT)
    reg_type : str, optional
        Type of regularization :math:`R`, by default "entropy" (only used when
        ``reg!=None``)
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT). For ``unbalanced_type='partial'`` it is the total
        transported mass ``m``.
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL",
        "semirelaxed", "partial", by default "KL" but note that KL/L2 are not
        implemented yet.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection.
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each
        solvers)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in
        each solvers)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as
        follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Notes
    -----
    The following methods are available for solving the Gromov-Wasserstein
    problem:

    - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters):

    .. math::
        \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb)  # uniform weights
        res = ot.solve_gromov(Ca, Cb, a=a, b=b)  # given weights
        res = ot.solve_gromov(Ca, Cb, loss='KL')  # KL loss

        plan = res.plan  # GW plan
        value = res.value  # GW value

    - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, M)  # uniform weights, alpha=0.5 (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1)  # given weights and alpha

        plan = res.plan  # FGW plan
        loss_linear_term = res.value_linear  # Wasserstein part of the loss
        loss_quad_term = res.value_quad  # Gromov part of the loss
        loss = res.value  # FGW value

    - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when
      ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} +
        \lambda_r R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, reg=1.0)  # GW entropy regularization (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1)  # FGW with entropy

        plan = res.plan  # FGW plan
        loss_linear_term = res.value_linear  # Wasserstein part of the loss
        loss_quad_term = res.value_quad  # Gromov part of the loss
        loss = res.value  # FGW value (including regularization)

    - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when
      ``unbalanced='semirelaxed'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed')  # semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, unbalanced='semirelaxed', reg=1)  # entropic semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, M, unbalanced='semirelaxed', alpha=0.1)  # semirelaxed FGW

        plan = res.plan  # FGW plan
        right_marginal = res.marginal_b  # right marginal of the plan

    - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when
      ``unbalanced='partial'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a}

             \mathbf{T}^T \mathbf{1} \leq \mathbf{b}

             \mathbf{T} \geq 0

             \mathbf{1}^T\mathbf{T}\mathbf{1} = m

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8)  # partial GW with m=0.8

    .. _references-solve-gromov:
    References
    ----------
    .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric
        approach to object matching. Foundations of computational mathematics,
        11(4), 417-487.

    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
        Gromov-Wasserstein averaging of kernel and distance matrices
        International Conference on Machine Learning (ICML).

    .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
        (2019). Optimal Transport for structured data with application on
        graphs Proceedings of the 36th International Conference on Machine
        Learning (ICML).

    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer,
        Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and
        applications on graphs. International Conference on Learning
        Representations (ICLR), 2022.

    .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal Transport
        with Applications on Positive-Unlabeled Learning, Advances in Neural
        Information Processing Systems (NeurIPS), 2020.
    """

    # detect backend (get_backend ignores the None entries among M, a, b)
    nx = get_backend(Ca, Cb, M, a, b)

    # create uniform weights if not given (Ca, Cb are square cost matrices)
    if a is None:
        a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0]
    if b is None:
        b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1]

    # default values for solutions (filled by whichever branch runs below)
    potentials = None
    value = None
    value_linear = None
    value_quad = None
    plan = None
    status = None
    log = None

    # map user-facing loss names to the internal gromov loss identifiers
    loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'}

    if loss.lower() not in loss_dict.keys():
        raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss)))
    loss_fun = loss_dict[loss.lower()]

    if reg is None or reg == 0:  # exact OT

        if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Exact balanced OT

            if M is None or alpha == 1:  # Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun,
                                                 log=True, symmetric=symmetric,
                                                 max_iter=max_iter, G0=plan_init,
                                                 tol_rel=tol, tol_abs=tol,
                                                 verbose=verbose)
                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log['T']
                potentials = (log['u'], log['v'])

            elif alpha == 0:  # Wasserstein problem

                # default values for EMD solver
                if max_iter is None:
                    max_iter = 1000000

                value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True,
                                         return_matrix=True,
                                         numThreads=n_threads)

                value = value_linear
                potentials = (log['u'], log['v'])
                plan = log['G']
                status = log["warning"] if log["warning"] is not None else 'Converged'
                value_quad = 0

            else:  # Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b,
                                                       loss_fun=loss_fun,
                                                       alpha=alpha, log=True,
                                                       symmetric=symmetric,
                                                       max_iter=max_iter,
                                                       G0=plan_init,
                                                       tol_rel=tol, tol_abs=tol,
                                                       verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                potentials = (log['u'], log['v'])

        elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed OT

            if M is None or alpha == 1:  # Semi relaxed Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_gromov_wasserstein2(
                    Ca, Cb, a, loss_fun=loss_fun, log=True, symmetric=symmetric,
                    max_iter=max_iter, G0=plan_init, tol_rel=tol, tol_abs=tol,
                    verbose=verbose)

                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # Semi relaxed Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_fused_gromov_wasserstein2(
                    M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True,
                    symmetric=symmetric, max_iter=max_iter, G0=plan_init,
                    tol_rel=tol, tol_abs=tol, verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

        elif unbalanced_type.lower() in ['partial']:  # Partial OT

            if M is None:  # Partial Gromov-Wasserstein problem

                # the transported mass cannot exceed either marginal's total mass
                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError('Partial GW mass given in unbalanced is too large'))
                if loss.lower() != 'l2':
                    raise (NotImplementedError('Partial GW only implemented with L2 loss'))
                if symmetric is not None:
                    raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value, log = partial_gromov_wasserstein2(
                    Ca, Cb, a, b, m=unbalanced, log=True, numItermax=max_iter,
                    G0=plan_init, tol=tol, verbose=verbose)

                value_quad = value
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW
                raise (NotImplementedError('Partial FGW not implemented yet'))

        elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT
            raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type)))

        else:
            raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

    else:  # regularized OT

        if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Balanced regularized OT

            if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = 'PGD'

                value_quad, log = entropic_gromov_wasserstein2(
                    Ca, Cb, a, b, epsilon=reg, loss_fun=loss_fun, log=True,
                    symmetric=symmetric, solver=method, max_iter=max_iter,
                    G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

                plan = log['T']
                value_linear = 0
                # total value adds the entropic regularization term
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
                # potentials = (log['log_u'], log['log_v']) #TODO

            elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0:  # Entropic Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
                                         stopThr=tol, log=True, verbose=verbose)

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                potentials = (log['log_u'], log['log_v'])

            elif reg_type.lower() in ['entropy'] and M is not None:  # Entropic Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = 'PGD'

                value_noreg, log = entropic_fused_gromov_wasserstein2(
                    M, Ca, Cb, a, b, loss_fun=loss_fun, alpha=alpha, log=True,
                    symmetric=symmetric, solver=method, max_iter=max_iter,
                    G0=plan_init, tol_rel=tol, tol_abs=tol, verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                # potentials = (log['u'], log['v'])
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

            else:
                raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

        elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed OT

            if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Semi-relaxed Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_quad, log = entropic_semirelaxed_gromov_wasserstein2(
                    Ca, Cb, a, epsilon=reg, loss_fun=loss_fun, log=True,
                    symmetric=symmetric, max_iter=max_iter, G0=plan_init,
                    tol_rel=tol, tol_abs=tol, verbose=verbose)

                plan = log['T']
                value_linear = 0
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))

            else:  # Entropic Semi-relaxed FGW problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(
                    M, Ca, Cb, a, loss_fun=loss_fun, alpha=alpha, log=True,
                    symmetric=symmetric, max_iter=max_iter, G0=plan_init,
                    tol_rel=tol, tol_abs=tol, verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

        elif unbalanced_type.lower() in ['partial']:  # Partial OT

            if M is None:  # Partial Gromov-Wasserstein problem

                # the transported mass cannot exceed either marginal's total mass
                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError('Partial GW mass given in unbalanced is too large'))
                if loss.lower() != 'l2':
                    raise (NotImplementedError('Partial GW only implemented with L2 loss'))
                if symmetric is not None:
                    raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value_quad, log = entropic_partial_gromov_wasserstein2(
                    Ca, Cb, a, b, reg=reg, m=unbalanced, log=True,
                    numItermax=max_iter, G0=plan_init, tol=tol, verbose=verbose)

                # BUGFIX: was `value_quad = value`, which overwrote the solver's
                # returned loss with the still-None `value`; the assignment must
                # go the other way, as in the non-entropic partial branch.
                value = value_quad
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW
                raise (NotImplementedError('Partial entropic FGW not implemented yet'))

        else:  # unbalanced AND regularized OT
            raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

    res = OTResult(potentials=potentials, value=value,
                   value_linear=value_linear, value_quad=value_quad, plan=plan,
                   status=status, backend=nx, log=log)

    return res
def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
                 unbalanced=None, unbalanced_type='KL', lazy=False, batch_size=None,
                 method=None, n_threads=1, max_iter=None, plan_init=None, rank=100,
                 scaling=0.95, potentials_init=None, X_init=None, tol=None, verbose=False,
                 grad='autodiff'):
    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    where the cost matrix :math:`\mathbf{M}` is computed from the samples in the
    source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where
    :math:`d` is a metric (by default the squared Euclidean distance).

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function solves the
    exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    X_a : array-like, shape (n_samples_a, dim)
        Samples in the source domain
    X_b : array-like, shape (n_samples_b, dim)
        Samples in the target domain
    a : array-like, shape (dim_a,), optional
        Samples weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Samples weights in the target domain (default is uniform)
    metric : str, optional
        Metric used to compute the cost matrix (passed to :any:`dist`),
        by default 'sqeuclidean'
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact OT)
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV",
        by default "KL"
    lazy : bool, optional
        Return :any:`OTResultlazy` object to reduce memory cost when True, by
        default False
    batch_size : int, optional
        Batch size for lazy solver, by default None (default values in each solver)
    method : str, optional
        Method for solving the problem, this can be used to select the solver
        for unbalanced problems (see :any:`ot.solve`), or to select a specific
        large scale solver ('1d', 'gaussian', 'lowrank', 'factored',
        'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online',
        'geomloss_multiscale').
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each solver)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    rank : int, optional
        Rank of the OT matrix for lazy solvers (method='factored' or
        method='lowrank'), by default 100
    scaling : float, optional
        Scaling factor for the epsilon scaling lazy solvers (method='geomloss'),
        by default 0.95
    potentials_init : (array_like(dim_a,), array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    X_init : array_like, optional
        Initialization for the factored OT support, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in each solver)
    verbose : bool, optional
        Print information in the solver, by default False
    grad : str, optional
        Type of gradient computation, either 'autodiff' or 'envelope', used only
        for the Sinkhorn solver. By default 'autodiff' provides gradients wrt all
        outputs (`plan, value, value_linear`) but with important memory cost.
        'envelope' provides gradients only for `value` and other outputs are
        detached. This is useful for memory saving when only the value is needed.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method)

        See :any:`OTResult` for more information.

    References
    ----------
    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
        December). Displacement interpolation using Lagrangian mass transport.
        In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems (NIPS) 26, 2013
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems.
        arXiv preprint arXiv:1607.05816.
    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport. Proceedings of the Twenty-First International
        Conference on Artificial Intelligence and Statistics (AISTATS).
    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A., &
        Peyré, G. (2019, April). Interpolating between optimal transport and MMD
        using Sinkhorn divergences. In The 22nd International Conference on
        Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.
    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger, G.,
        & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized linear
        regression. NeurIPS.
    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). Low-rank Sinkhorn
        Factorization. In International Conference on Machine Learning.
    """

    # A "lazy method" (1d, gaussian, lowrank, factored, geomloss*) always goes
    # through the lazy code path below; lazy0 keeps the user's original request
    # so the solver knows whether to materialize the plan afterwards.
    # NOTE(review): lazy0 is only bound on this branch; the later uses of lazy0
    # appear to be reachable only for methods in lst_method_lazy — confirm.
    if method is not None and method.lower() in lst_method_lazy:
        lazy0 = lazy
        lazy = True

    if not lazy:  # default non lazy solver calls ot.solve
        # compute cost matrix M and use solve function
        M = dist(X_a, X_b, metric)

        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method,
                    n_threads, max_iter, plan_init, potentials_init, tol, verbose,
                    grad)

        return res

    else:
        # Detect backend
        nx = get_backend(X_a, X_b, a, b)

        # default values for solutions
        potentials = None
        value = None
        value_linear = None
        plan = None
        lazy_plan = None
        status = None
        log = None

        # normalize the method name once for the dispatch below
        method = method.lower() if method is not None else ''

        if method == '1d':  # Wasserstein 1d (parallel on all dimensions)

            # map the cost metric to the ground-cost exponent p of W_p
            if metric == 'sqeuclidean':
                p = 2
            elif metric in ['euclidean', 'cityblock']:
                p = 1
            else:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            value = wasserstein_1d(X_a, X_b, a, b, p=p)
            value_linear = value

        elif method == 'gaussian':  # Gaussian Bures-Wasserstein

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            # small regularization of the covariance estimates by default
            if reg is None:
                reg = 1e-6

            value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True)
            value = value**2  # return the value (squared bures distance)
            value_linear = value  # return the value

        elif method == 'factored':  # Factored OT

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            # solver-specific defaults
            if max_iter is None:
                max_iter = 100
            if tol is None:
                tol = 1e-7
            if reg is None:
                reg = 0

            Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank,
                                                      log=True, stopThr=tol,
                                                      numItermax=max_iter,
                                                      verbose=verbose)
            log['X'] = X

            value_linear = log['costa'] + log['costb']
            value = value_linear  # TODO add reg term
            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

        elif method == "lowrank":  # Low-rank Sinkhorn [65]

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            # solver-specific defaults
            if max_iter is None:
                max_iter = 2000
            if tol is None:
                tol = 1e-7
            if reg is None:
                reg = 0

            Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b,
                                            numItermax=max_iter, stopThr=tol, log=True)
            value = log['value']
            value_linear = log['value_linear']
            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

        elif method.startswith('geomloss'):  # Geomloss solver for entropic OT

            # 'geomloss_<backend>' pins the backend; bare 'geomloss' selects it
            # from the user's original lazy request
            split_method = method.split('_')
            if len(split_method) == 2:
                backend = split_method[1]
            else:
                if lazy0 is None:
                    backend = 'auto'
                elif lazy0:
                    backend = 'online'
                else:
                    backend = 'tensorized'

            value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b,
                                                      metric=metric, log=True,
                                                      verbose=verbose,
                                                      scaling=scaling,
                                                      backend=backend)

            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

            # return scaled potentials (to be consistent with other solvers)
            potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2))

        elif reg is None or reg == 0:  # exact OT

            if unbalanced is None:  # balanced EMD solver not available for lazy
                raise (NotImplementedError('Exact OT solver with lazy=True not implemented'))

            else:
                raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type)))

        else:
            if unbalanced is None:  # lazy entropic balanced OT

                # solver-specific defaults
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if batch_size is None:
                    batch_size = 100

                value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b,
                                                        metric=metric,
                                                        numIterMax=max_iter,
                                                        stopThr=tol,
                                                        isLazy=True,
                                                        batchSize=batch_size,
                                                        verbose=verbose, log=True)

                # compute potentials
                potentials = (log["u"], log["v"])
                lazy_plan = log['lazy_plan']

            else:
                raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

        res = OTResult(potentials=potentials, value=value,
                       lazy_plan=lazy_plan, value_linear=value_linear, plan=plan,
                       status=status, backend=nx, log=log)
        return res