Source code for ot.solvers

# -*- coding: utf-8 -*-
"""
General OT solvers with unified API
"""

# Author: Remi Flamary <>
# License: MIT License

from .utils import OTResult, dist
from .lp import emd2, wasserstein_1d
from .backend import get_backend
from .unbalanced import mm_unbalanced, sinkhorn_knopp_unbalanced, lbfgsb_unbalanced
from .bregman import sinkhorn_log, empirical_sinkhorn2, empirical_sinkhorn2_geomloss
from .partial import partial_wasserstein_lagrange
from .smooth import smooth_ot_dual
from .gromov import (gromov_wasserstein2, fused_gromov_wasserstein2,
                     entropic_gromov_wasserstein2, entropic_fused_gromov_wasserstein2,
                     semirelaxed_gromov_wasserstein2, semirelaxed_fused_gromov_wasserstein2,
                     entropic_semirelaxed_gromov_wasserstein2,
                     entropic_semirelaxed_fused_gromov_wasserstein2)
from .partial import partial_gromov_wasserstein2, entropic_partial_gromov_wasserstein2
from .gaussian import empirical_bures_wasserstein_distance
from .factored import factored_optimal_transport
from .lowrank import lowrank_sinkhorn

lst_method_lazy = ['1d', 'gaussian', 'lowrank', 'factored', 'geomloss', 'geomloss_auto', 'geomloss_tensorized', 'geomloss_online', 'geomloss_multiscale']

def solve(M, a=None, b=None, reg=None, reg_type="KL", unbalanced=None,
          unbalanced_type='KL', method=None, n_threads=1, max_iter=None, plan_init=None,
          potentials_init=None, tol=None, verbose=False):
    r"""Solve the discrete optimal transport problem and return :any:`OTResult` object

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).

    Parameters
    ----------
    M : array_like, shape (dim_a, dim_b)
        Loss matrix
    a : array-like, shape (dim_a,), optional
        Sample weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Sample weights in the target domain (default is uniform)
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values of each solver)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    potentials_init : (array_like(dim_a,),array_like(dim_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values of each solver)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Notes
    -----

    The following methods are available for solving the OT problems:

    - **Classical exact OT problem [1]** (default parameters) :

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M, a, b)

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` regularization (``reg_type="KL"``)
        res = ot.solve(M, a, b, reg=1.0)
        # or for original Sinkhorn paper formulation [2]
        res = ot.solve(M, a, b, reg=1.0, reg_type='entropy')

    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve(M, a, b, reg=1.0, reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"``
        res = ot.solve(M, a, b, unbalanced=1.0)
        # quadratic unbalanced OT
        res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve(M, a, b, unbalanced=1.0, unbalanced_type='TV')

    - **Regularized unbalanced OT [34]** (when ``unbalanced!=None`` and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    can be solved with the following code:

    .. code-block:: python

        # default is ``"KL"`` for both
        res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve(M, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
        # both quadratic
        res = ot.solve(M, a, b, reg=1.0, reg_type='L2', unbalanced=1.0, unbalanced_type='L2')

    .. _references-solve:
    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W.
        (2011, December). Displacement interpolation using Lagrangian mass
        transport. In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6,
        p. 158). ACM.

    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation
        of Optimal Transport, Advances in Neural Information Processing
        Systems (NIPS) 26, 2013.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport. Proceedings of the Twenty-First International
        Conference on Artificial Intelligence and Statistics (AISTATS).

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé,
        A., & Peyré, G. (2019, April). Interpolating between optimal transport
        and MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.
    """

    # detect backend
    arr = [M]
    if a is not None:
        arr.append(a)
    if b is not None:
        arr.append(b)
    nx = get_backend(*arr)

    # create uniform weights if not given
    if a is None:
        a = nx.ones(M.shape[0], type_as=M) / M.shape[0]
    if b is None:
        b = nx.ones(M.shape[1], type_as=M) / M.shape[1]

    # default values for solutions
    potentials = None
    value = None
    value_linear = None
    plan = None
    status = None

    if reg is None or reg == 0:  # exact OT

        if unbalanced is None:  # Exact balanced OT

            # default values for EMD solver
            if max_iter is None:
                max_iter = 1000000

            value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True,
                                     return_matrix=True, numThreads=n_threads)

            value = value_linear
            potentials = (log['u'], log['v'])
            plan = log['G']
            status = log["warning"] if log["warning"] is not None else 'Converged'

        elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT

            # default values for exact unbalanced OT
            if max_iter is None:
                max_iter = 1000
            if tol is None:
                tol = 1e-12

            plan, log = mm_unbalanced(a, b, M, reg_m=unbalanced,
                                      div=unbalanced_type.lower(), numItermax=max_iter,
                                      stopThr=tol, log=True,
                                      verbose=verbose, G0=plan_init)

            value_linear = log['cost']

            if unbalanced_type.lower() == 'kl':
                value = value_linear + unbalanced * (nx.kl_div(nx.sum(plan, 1), a) +
                                                     nx.kl_div(nx.sum(plan, 0), b))
            else:
                err_a = nx.sum(plan, 1) - a
                err_b = nx.sum(plan, 0) - b
                value = value_linear + unbalanced * nx.sum(err_a**2) + unbalanced * nx.sum(err_b**2)

        elif unbalanced_type.lower() == 'tv':

            if max_iter is None:
                max_iter = 1000000

            plan, log = partial_wasserstein_lagrange(a, b, M, reg_m=unbalanced**2, log=True,
                                                     numItermax=max_iter)

            value_linear = nx.sum(M * plan)
            err_a = nx.sum(plan, 1) - a
            err_b = nx.sum(plan, 0) - b
            value = value_linear + nx.sqrt(unbalanced**2 / 2.0 * (nx.sum(nx.abs(err_a)) +
                                                                    nx.sum(nx.abs(err_b))))

        else:
            raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

    else:  # regularized OT

        if unbalanced is None:  # Balanced regularized OT

            if reg_type.lower() in ['entropy', 'kl']:

                # default values for sinkhorn
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
                                         stopThr=tol, log=True,
                                         verbose=verbose)

                value_linear = nx.sum(M * plan)

                if reg_type.lower() == 'entropy':
                    value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                else:
                    value = value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :])

                potentials = (log['log_u'], log['log_v'])

            elif reg_type.lower() == 'l2':

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = smooth_ot_dual(a, b, M, reg=reg, numItermax=max_iter, stopThr=tol,
                                           log=True, verbose=verbose)

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan**2)
                potentials = (log['alpha'], log['beta'])

            else:
                raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

        else:  # unbalanced AND regularized OT

            if reg_type.lower() in ['kl'] and unbalanced_type.lower() == 'kl':

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_knopp_unbalanced(a, b, M, reg=reg, reg_m=unbalanced,
                                                      numItermax=max_iter, stopThr=tol,
                                                      verbose=verbose, log=True)

                value_linear = nx.sum(M * plan)

                value = (value_linear + reg * nx.kl_div(plan, a[:, None] * b[None, :]) +
                         unbalanced * (nx.kl_div(nx.sum(plan, 1), a) + nx.kl_div(nx.sum(plan, 0), b)))

                potentials = (log['logu'], log['logv'])

            elif reg_type.lower() in ['kl', 'l2', 'entropy'] and unbalanced_type.lower() in ['kl', 'l2']:

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-12

                plan, log = lbfgsb_unbalanced(a, b, M, reg=reg, reg_m=unbalanced,
                                              reg_div=reg_type.lower(),
                                              regm_div=unbalanced_type.lower(),
                                              numItermax=max_iter, stopThr=tol,
                                              verbose=verbose, log=True)

                value_linear = nx.sum(M * plan)

                value = log['loss']

            else:
                raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

    res = OTResult(potentials=potentials, value=value,
                   value_linear=value_linear, plan=plan, status=status, backend=nx)

    return res
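In the balanced regularized branch of ``solve``, both ``reg_type="KL"`` and ``reg_type="entropy"`` are sent to ``sinkhorn_log``; only the formula for ``value`` changes. The standalone NumPy sketch below (not POT code; the helpers ``kl`` and ``neg_entropy`` are hypothetical names) checks why this is consistent: for any plan whose marginals are exactly ``a`` and ``b``, :math:`\mathrm{KL}(\mathbf{T}\,\|\,\mathbf{a}\mathbf{b}^T) = \sum T\log T - \sum a\log a - \sum b\log b`, so the two objectives differ by a constant and should share the same minimizer.

```python
import numpy as np


def kl(p, q):
    """Plain Kullback-Leibler divergence sum(p * log(p / q)) for positive arrays."""
    return float(np.sum(p * np.log(p / q)))


def neg_entropy(p):
    """Negative entropy sum(p * log(p)) for a positive array."""
    return float(np.sum(p * np.log(p)))


# Illustrative marginals (any positive vectors summing to 1 would do).
a = np.array([0.2, 0.3, 0.5])
b = np.array([0.1, 0.2, 0.3, 0.4])

# Build a coupling other than a b^T: add a perturbation whose rows and
# columns sum to zero, so the marginals of T stay exactly a and b.
u = np.array([1.0, -1.0, 0.0])       # sums to 0
v = np.array([1.0, 0.0, 0.0, -1.0])  # sums to 0
T = np.outer(a, b) + 0.01 * np.outer(u, v)

# KL to the product coupling vs. negative entropy of T minus a constant
# that depends only on the marginals.
lhs = kl(T, np.outer(a, b))
rhs = neg_entropy(T) - (neg_entropy(a) + neg_entropy(b))

print(lhs, rhs)  # equal up to floating-point error
```

Because the constant does not depend on ``T``, minimizing either regularized objective under the same marginal constraints should give the same plan; only the reported ``value`` differs.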


def solve_gromov(Ca, Cb, M=None, a=None, b=None, loss='L2', symmetric=None,
                 alpha=0.5, reg=None,
                 reg_type="entropy", unbalanced=None, unbalanced_type='KL',
                 n_threads=1, method=None, max_iter=None, plan_init=None, tol=None,
                 verbose=False):
    r""" Solve the discrete (Fused) Gromov-Wasserstein problem and return :any:`OTResult` object

    The function solves the following optimization problem:

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} +
        \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    The regularization is selected with `reg` (:math:`\lambda_r`) and
    `reg_type`. By default ``reg=None`` and there is no regularization. The
    unbalanced marginal penalization can be selected with `unbalanced`
    (:math:`\lambda_u`) and `unbalanced_type`. By default ``unbalanced=None``
    and the function solves the exact optimal transport problem (respecting the
    marginals).

    Parameters
    ----------
    Ca : array_like, shape (dim_a, dim_a)
        Cost matrix in the source domain
    Cb : array_like, shape (dim_b, dim_b)
        Cost matrix in the target domain
    M : array_like, shape (dim_a, dim_b), optional
        Linear cost matrix for Fused Gromov-Wasserstein (default is None).
    a : array-like, shape (dim_a,), optional
        Sample weights in the source domain (default is uniform)
    b : array-like, shape (dim_b,), optional
        Sample weights in the target domain (default is uniform)
    loss : str, optional
        Type of loss function, either ``"L2"`` or ``"KL"``, by default ``"L2"``
    symmetric : bool, optional
        Use the symmetric version of the Gromov-Wasserstein problem. By default
        None, which tests whether the matrices are symmetric; set it to True or
        False to skip the test.
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R`, by default "entropy" (only used when
        ``reg!=None``)
    alpha : float, optional
        Weight of the quadratic term (alpha*Gromov) and of the linear term
        ((1-alpha)*Wass) in the Fused Gromov-Wasserstein problem, by default
        ``alpha=0.5``. Not used for the Gromov problem (when M is not
        provided), which corresponds to ``alpha=1``.
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u` (transported mass
        :math:`m` when ``unbalanced_type='partial'``), by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL",
        "semirelaxed", "partial", by default "KL". Only "semirelaxed" and
        "partial" are implemented; the "KL" unbalanced penalization is not
        implemented yet.
    n_threads : int, optional
        Number of OMP threads for exact OT solver, by default 1
    method : str, optional
        Method for solving the problem when multiple algorithms are available,
        default None for automatic selection.
    max_iter : int, optional
        Maximum number of iterations, by default None (default values of each solver)
    plan_init : array_like, shape (dim_a, dim_b), optional
        Initialization of the OT plan for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values of each solver)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.value_quad : Quadratic (GW) part of the OT loss with the optimal OT plan

        See :any:`OTResult` for more information.
    Notes
    -----

    The following methods are available for solving the Gromov-Wasserstein
    problem:

    - **Classical Gromov-Wasserstein (GW) problem [3]** (default parameters):

    .. math::
        \min_{\mathbf{T}\geq 0} \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb)  # uniform weights
        res = ot.solve_gromov(Ca, Cb, a=a, b=b)  # given weights
        res = ot.solve_gromov(Ca, Cb, loss='KL')  # KL loss

        plan = res.plan  # GW plan
        value = res.value  # GW value

    - **Fused Gromov-Wasserstein (FGW) problem [24]** (when ``M!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, M)  # uniform weights, alpha=0.5 (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, alpha=0.1)  # given weights and alpha

        plan = res.plan  # FGW plan
        loss_linear_term = res.value_linear  # Wasserstein part of the loss
        loss_quad_term = res.value_quad  # Gromov part of the loss
        loss = res.value  # FGW value

    - **Regularized (Fused) Gromov-Wasserstein (GW) problem [12]** (when ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l} +
        \lambda_r R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, reg=1.0)  # GW entropy regularization (default)
        res = ot.solve_gromov(Ca, Cb, M, a=a, b=b, reg=10, alpha=0.1)  # FGW with entropy

        plan = res.plan  # FGW plan
        loss_linear_term = res.value_linear  # Wasserstein part of the loss
        loss_quad_term = res.value_quad  # Gromov part of the loss
        loss = res.value  # FGW value (including regularization)

    - **Semi-relaxed (Fused) Gromov-Wasserstein (GW) [48]** (when ``unbalanced_type='semirelaxed'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T} \geq 0

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed')  # semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, unbalanced_type='semirelaxed', reg=1)  # entropic semirelaxed GW
        res = ot.solve_gromov(Ca, Cb, M, unbalanced_type='semirelaxed', alpha=0.1)  # semirelaxed FGW

        plan = res.plan  # FGW plan
        right_marginal = res.marginal_b  # right marginal of the plan

    - **Partial (Fused) Gromov-Wasserstein (GW) problem [29]** (when ``unbalanced_type='partial'``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad (1 - \alpha) \langle \mathbf{T}, \mathbf{M} \rangle_F +
        \alpha \sum_{i,j,k,l} L(\mathbf{C_1}_{i,k}, \mathbf{C_2}_{j,l}) \mathbf{T}_{i,j}\mathbf{T}_{k,l}

        s.t. \ \mathbf{T} \mathbf{1} \leq \mathbf{a}

             \mathbf{T}^T \mathbf{1} \leq \mathbf{b}

             \mathbf{T} \geq 0

             \mathbf{1}^T\mathbf{T}\mathbf{1} = m

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_gromov(Ca, Cb, unbalanced_type='partial', unbalanced=0.8)  # partial GW with m=0.8

    .. _references-solve-gromov:
    References
    ----------

    .. [3] Mémoli, F. (2011). Gromov–Wasserstein distances and the metric
        approach to object matching. Foundations of Computational
        Mathematics, 11(4), 417-487.

    .. [12] Gabriel Peyré, Marco Cuturi, and Justin Solomon (2016),
        Gromov-Wasserstein averaging of kernel and distance matrices.
        International Conference on Machine Learning (ICML).

    .. [24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R. and Courty, N.
        (2019). Optimal Transport for structured data with application on
        graphs. Proceedings of the 36th International Conference on Machine
        Learning (ICML).

    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer,
        Nicolas Courty (2022). Semi-relaxed Gromov-Wasserstein divergence and
        applications on graphs. International Conference on Learning
        Representations (ICLR), 2022.

    .. [29] Chapel, L., Alaya, M., Gasso, G. (2020). Partial Optimal
        Transport with Applications on Positive-Unlabeled Learning, Advances
        in Neural Information Processing Systems (NeurIPS), 2020.
    """

    # detect backend
    nx = get_backend(Ca, Cb, M, a, b)

    # create uniform weights if not given
    if a is None:
        a = nx.ones(Ca.shape[0], type_as=Ca) / Ca.shape[0]
    if b is None:
        b = nx.ones(Cb.shape[1], type_as=Cb) / Cb.shape[1]

    # default values for solutions
    potentials = None
    value = None
    value_linear = None
    value_quad = None
    plan = None
    status = None
    log = None

    loss_dict = {'l2': 'square_loss', 'kl': 'kl_loss'}

    if loss.lower() not in loss_dict.keys():
        raise (NotImplementedError('Not implemented GW loss="{}"'.format(loss)))
    loss_fun = loss_dict[loss.lower()]

    if reg is None or reg == 0:  # exact OT

        if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Exact balanced OT

            if M is None or alpha == 1:  # Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = gromov_wasserstein2(Ca, Cb, a, b, loss_fun=loss_fun, log=True,
                                                 symmetric=symmetric, max_iter=max_iter,
                                                 G0=plan_init, tol_rel=tol, tol_abs=tol,
                                                 verbose=verbose)

                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log['T']
                potentials = (log['u'], log['v'])

            elif alpha == 0:  # Wasserstein problem

                # default values for EMD solver
                if max_iter is None:
                    max_iter = 1000000

                value_linear, log = emd2(a, b, M, numItermax=max_iter, log=True,
                                         return_matrix=True, numThreads=n_threads)

                value = value_linear
                potentials = (log['u'], log['v'])
                plan = log['G']
                status = log["warning"] if log["warning"] is not None else 'Converged'
                value_quad = 0

            else:  # Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun,
                                                       alpha=alpha, log=True,
                                                       symmetric=symmetric, max_iter=max_iter,
                                                       G0=plan_init, tol_rel=tol, tol_abs=tol,
                                                       verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                potentials = (log['u'], log['v'])

        elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed OT

            if M is None or alpha == 1:  # Semi relaxed Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_gromov_wasserstein2(Ca, Cb, a, loss_fun=loss_fun, log=True,
                                                             symmetric=symmetric, max_iter=max_iter,
                                                             G0=plan_init, tol_rel=tol, tol_abs=tol,
                                                             verbose=verbose)

                value_quad = value
                if alpha == 1:  # set to 0 for FGW with alpha=1
                    value_linear = 0
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # Semi relaxed Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 10000
                if tol is None:
                    tol = 1e-9

                value, log = semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a, loss_fun=loss_fun,
                                                                   alpha=alpha, log=True,
                                                                   symmetric=symmetric,
                                                                   max_iter=max_iter, G0=plan_init,
                                                                   tol_rel=tol, tol_abs=tol,
                                                                   verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

        elif unbalanced_type.lower() in ['partial']:  # Partial OT

            if M is None:  # Partial Gromov-Wasserstein problem

                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError('Partial GW mass given in unbalanced is too large'))
                if loss.lower() != 'l2':
                    raise (NotImplementedError('Partial GW only implemented with L2 loss'))
                if symmetric is not None:
                    raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value, log = partial_gromov_wasserstein2(Ca, Cb, a, b, m=unbalanced, log=True,
                                                         numItermax=max_iter, G0=plan_init,
                                                         tol=tol, verbose=verbose)

                value_quad = value
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW

                raise (NotImplementedError('Partial FGW not implemented yet'))

        elif unbalanced_type.lower() in ['kl', 'l2']:  # unbalanced exact OT

            raise (NotImplementedError('Unbalanced_type="{}"'.format(unbalanced_type)))

        else:
            raise (NotImplementedError('Unknown unbalanced_type="{}"'.format(unbalanced_type)))

    else:  # regularized OT

        if unbalanced is None and unbalanced_type.lower() not in ['semirelaxed']:  # Balanced regularized OT

            if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = 'PGD'

                value_quad, log = entropic_gromov_wasserstein2(Ca, Cb, a, b, epsilon=reg,
                                                               loss_fun=loss_fun, log=True,
                                                               symmetric=symmetric, solver=method,
                                                               max_iter=max_iter, G0=plan_init,
                                                               tol_rel=tol, tol_abs=tol,
                                                               verbose=verbose)

                plan = log['T']
                value_linear = 0
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))
                # potentials = (log['log_u'], log['log_v']) #TODO

            elif reg_type.lower() in ['entropy'] and M is not None and alpha == 0:  # Entropic Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                plan, log = sinkhorn_log(a, b, M, reg=reg, numItermax=max_iter,
                                         stopThr=tol, log=True,
                                         verbose=verbose)

                value_linear = nx.sum(M * plan)
                value = value_linear + reg * nx.sum(plan * nx.log(plan + 1e-16))
                potentials = (log['log_u'], log['log_v'])

            elif reg_type.lower() in ['entropy'] and M is not None:  # Entropic Fused Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if method is None:
                    method = 'PGD'

                value_noreg, log = entropic_fused_gromov_wasserstein2(M, Ca, Cb, a, b, loss_fun=loss_fun,
                                                                      alpha=alpha, log=True,
                                                                      symmetric=symmetric, solver=method,
                                                                      max_iter=max_iter, G0=plan_init,
                                                                      tol_rel=tol, tol_abs=tol,
                                                                      verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                # potentials = (log['u'], log['v'])
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))
            else:
                raise (NotImplementedError('Not implemented reg_type="{}"'.format(reg_type)))

        elif unbalanced_type.lower() in ['semirelaxed']:  # Semi-relaxed OT

            if reg_type.lower() in ['entropy'] and (M is None or alpha == 1):  # Entropic Semi-relaxed Gromov-Wasserstein problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_quad, log = entropic_semirelaxed_gromov_wasserstein2(Ca, Cb, a, epsilon=reg,
                                                                           loss_fun=loss_fun, log=True,
                                                                           symmetric=symmetric,
                                                                           max_iter=max_iter, G0=plan_init,
                                                                           tol_rel=tol, tol_abs=tol,
                                                                           verbose=verbose)

                plan = log['T']
                value_linear = 0
                value = value_quad + reg * nx.sum(plan * nx.log(plan + 1e-16))

            else:  # Entropic Semi-relaxed FGW problem

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9

                value_noreg, log = entropic_semirelaxed_fused_gromov_wasserstein2(M, Ca, Cb, a,
                                                                                  loss_fun=loss_fun,
                                                                                  alpha=alpha, log=True,
                                                                                  symmetric=symmetric,
                                                                                  max_iter=max_iter,
                                                                                  G0=plan_init,
                                                                                  tol_rel=tol, tol_abs=tol,
                                                                                  verbose=verbose)

                value_linear = log['lin_loss']
                value_quad = log['quad_loss']
                plan = log['T']
                value = value_noreg + reg * nx.sum(plan * nx.log(plan + 1e-16))

        elif unbalanced_type.lower() in ['partial']:  # Partial OT

            if M is None:  # Partial Gromov-Wasserstein problem

                if unbalanced > nx.sum(a) or unbalanced > nx.sum(b):
                    raise (ValueError('Partial GW mass given in unbalanced is too large'))
                if loss.lower() != 'l2':
                    raise (NotImplementedError('Partial GW only implemented with L2 loss'))
                if symmetric is not None:
                    raise (NotImplementedError('Partial GW only implemented with symmetric=True'))

                # default values for solver
                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-7

                value_quad, log = entropic_partial_gromov_wasserstein2(Ca, Cb, a, b, reg=reg,
                                                                       m=unbalanced, log=True,
                                                                       numItermax=max_iter,
                                                                       G0=plan_init, tol=tol,
                                                                       verbose=verbose)

                value = value_quad
                plan = log['T']
                # potentials = (log['u'], log['v']) TODO

            else:  # partial FGW

                raise (NotImplementedError('Partial entropic FGW not implemented yet'))

        else:  # unbalanced AND regularized OT

            raise (NotImplementedError('Not implemented reg_type="{}" and unbalanced_type="{}"'.format(reg_type, unbalanced_type)))

    res = OTResult(potentials=potentials, value=value,
                   value_linear=value_linear, value_quad=value_quad, plan=plan,
                   status=status, backend=nx, log=log)

    return res
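For ``loss='L2'`` the quadratic term in the (Fused) Gromov-Wasserstein objective sums over four indices, which costs :math:`O(n^2 m^2)` if evaluated directly. Expanding :math:`(x - y)^2 = x^2 + y^2 - 2xy` reduces it to matrix products, the kind of factorization described in [12]. The standalone NumPy sketch below (not POT's internal implementation; ``gw_square_loss_naive`` and ``gw_square_loss_fast`` are hypothetical names) checks that identity on random data for an arbitrary nonnegative ``T``.

```python
import numpy as np


def gw_square_loss_naive(C1, C2, T):
    """Direct O(n^2 m^2) evaluation of sum_{ijkl} (C1[i,k] - C2[j,l])**2 T[i,j] T[k,l]."""
    # L[i, j, k, l] = (C1[i, k] - C2[j, l]) ** 2
    L = (C1[:, None, :, None] - C2[None, :, None, :]) ** 2
    return float(np.einsum('ijkl,ij,kl->', L, T, T))


def gw_square_loss_fast(C1, C2, T):
    """Same quantity via (x - y)^2 = x^2 + y^2 - 2xy, using only matrix products."""
    p = T.sum(axis=1)  # row marginal of T
    q = T.sum(axis=0)  # column marginal of T
    term1 = p @ (C1 ** 2) @ p            # sum_{ik} C1[i,k]^2 p_i p_k
    term2 = q @ (C2 ** 2) @ q            # sum_{jl} C2[j,l]^2 q_j q_l
    cross = np.sum((C1 @ T @ C2.T) * T)  # sum_{ijkl} C1[i,k] C2[j,l] T[i,j] T[k,l]
    return float(term1 + term2 - 2.0 * cross)


rng = np.random.default_rng(0)
n, m = 5, 4
X, Y = rng.normal(size=(n, 2)), rng.normal(size=(m, 3))
C1 = np.linalg.norm(X[:, None] - X[None, :], axis=-1)  # (n, n) pairwise distances
C2 = np.linalg.norm(Y[:, None] - Y[None, :], axis=-1)  # (m, m) pairwise distances

# Any nonnegative T works for the identity; here a random normalized matrix.
T = rng.random((n, m))
T /= T.sum()

naive = gw_square_loss_naive(C1, C2, T)
fast = gw_square_loss_fast(C1, C2, T)
```

The factorized form needs only :math:`O(n^2 m + n m^2)` operations, which is why solvers typically avoid building the four-index tensor.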


def solve_sample(X_a, X_b, a=None, b=None, metric='sqeuclidean', reg=None, reg_type="KL",
                 unbalanced=None,
                 unbalanced_type='KL', lazy=False, batch_size=None, method=None, n_threads=1,
                 max_iter=None, plan_init=None, rank=100, scaling=0.95,
                 potentials_init=None, X_init=None, tol=None, verbose=False):
    r"""Solve the discrete optimal transport problem using the samples in the source and target domains.

    The function solves the following general optimal transport problem

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) +
        \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) +
        \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    where the cost matrix :math:`\mathbf{M}` is computed from the samples in the
    source and target domains such that :math:`M_{i,j} = d(x_i,y_j)` where
    :math:`d` is a metric (by default the squared Euclidean distance).

    The regularization is selected with `reg` (:math:`\lambda_r`) and `reg_type`. By
    default ``reg=None`` and there is no regularization. The unbalanced marginal
    penalization can be selected with `unbalanced` (:math:`\lambda_u`) and
    `unbalanced_type`. By default ``unbalanced=None`` and the function
    solves the exact optimal transport problem (respecting the marginals).
    Parameters
    ----------
    X_a : array-like, shape (n_samples_a, dim)
        samples in the source domain
    X_b : array-like, shape (n_samples_b, dim)
        samples in the target domain
    a : array-like, shape (n_samples_a,), optional
        Sample weights in the source domain (default is uniform)
    b : array-like, shape (n_samples_b,), optional
        Sample weights in the target domain (default is uniform)
    metric : str, optional
        Metric used to compute the cost matrix :math:`\mathbf{M}`, by default
        "sqeuclidean" (squared Euclidean distance)
    reg : float, optional
        Regularization weight :math:`\lambda_r`, by default None (no reg., exact
        OT)
    reg_type : str, optional
        Type of regularization :math:`R` either "KL", "L2", "entropy", by default "KL"
    unbalanced : float, optional
        Unbalanced penalization weight :math:`\lambda_u`, by default None
        (balanced OT)
    unbalanced_type : str, optional
        Type of unbalanced penalization function :math:`U` either "KL", "L2", "TV", by default "KL"
    lazy : bool, optional
        Return :any:`OTResultLazy` object to reduce memory cost when True, by
        default False
    batch_size : int, optional
        Batch size for lazy solver, by default None (default values of each
        solver)
    method : str, optional
        Method for solving the problem. It can be used to select the solver
        for unbalanced problems (see :any:`ot.solve`) or a specific large-scale
        solver.
    n_threads : int, optional
        Number of OMP threads for the exact OT solver, by default 1
    max_iter : int, optional
        Maximum number of iterations, by default None (default values in each solver)
    plan_init : array_like, shape (n_samples_a, n_samples_b), optional
        Initialization of the OT plan for iterative methods, by default None
    rank : int, optional
        Rank of the OT matrix for lazy solvers (``method='factored'`` or
        ``method='lowrank'`` [65]_), by default 100
    scaling : float, optional
        Scaling factor for the epsilon scaling lazy solvers (``method='geomloss'``),
        by default 0.95
    potentials_init : (array_like(n_samples_a,),array_like(n_samples_b,)), optional
        Initialization of the OT dual potentials for iterative methods, by default None
    tol : float, optional
        Tolerance for solution precision, by default None (default values in each solver)
    verbose : bool, optional
        Print information in the solver, by default False

    Returns
    -------

    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan
        - res.lazy_plan : Lazy OT plan (when ``lazy=True`` or lazy method)

        See :any:`OTResult` for more information.

    Notes
    -----

    The following methods are available for solving the OT problems:

    - **Classical exact OT problem [1]** (default parameters):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_sample(xa, xb, a, b)

        # for uniform weights
        res = ot.solve_sample(xa, xb)

    - **Entropic regularized OT [2]** (when ``reg!=None``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda_r R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        # default is "KL" regularization (reg_type="KL")
        res = ot.solve_sample(xa, xb, a, b, reg=1.0)
        # or for original Sinkhorn paper formulation [2]
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='entropy')

        # lazy solver of memory complexity O(n)
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, lazy=True, batch_size=100)
        # lazy OT plan
        lazy_plan = res.lazy_plan

    We also have a very efficient solver with compiled CPU/CUDA code using
    geomloss/PyKeOps that can be used with the following code:

    .. code-block:: python

        # automatic solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss')

        # force O(n) memory efficient solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_online')

        # force pre-computed cost matrix
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_tensorized')

        # use multiscale solver
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss_multiscale')

        # trade speed (small scaling factor) against precision (scaling close to 1)
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, method='geomloss', scaling=0.5)

    - **Quadratic regularized OT [17]** (when ``reg!=None`` and ``reg_type="L2"``):

    .. math::
        \min_\mathbf{T} \quad \langle \mathbf{T}, \mathbf{M} \rangle_F + \lambda_r R(\mathbf{T})

        s.t. \ \mathbf{T} \mathbf{1} = \mathbf{a}

             \mathbf{T}^T \mathbf{1} = \mathbf{b}

             \mathbf{T} \geq 0, M_{i,j} = d(x_i,y_j)

    can be solved with the following code:

    .. code-block:: python

        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2')

    - **Unbalanced OT [41]** (when ``unbalanced!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    with :math:`M_{i,j} = d(x_i,y_j)`

    can be solved with the following code:

    .. code-block:: python

        # default is "KL"
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0)
        # quadratic unbalanced OT
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0, unbalanced_type='L2')
        # TV = partial OT
        res = ot.solve_sample(xa, xb, a, b, unbalanced=1.0, unbalanced_type='TV')

    - **Regularized unbalanced OT [34]** (when ``unbalanced!=None`` and ``reg!=None``):

    .. math::
        \min_{\mathbf{T}\geq 0} \quad \sum_{i,j} T_{i,j}M_{i,j} + \lambda_r R(\mathbf{T}) + \lambda_u U(\mathbf{T}\mathbf{1},\mathbf{a}) + \lambda_u U(\mathbf{T}^T\mathbf{1},\mathbf{b})

    with :math:`M_{i,j} = d(x_i,y_j)`

    can be solved with the following code:

    .. code-block:: python

        # default is "KL" for both
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0)
        # quadratic unbalanced OT with KL regularization
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, unbalanced=1.0, unbalanced_type='L2')
        # both quadratic
        res = ot.solve_sample(xa, xb, a, b, reg=1.0, reg_type='L2',
                              unbalanced=1.0, unbalanced_type='L2')

    - **Factored OT [40]** (when ``method='factored'``):

    This method solves the following OT problem [40]_

    .. math::
        \mathop{\arg \min}_\mu \quad W_2^2(\mu_a,\mu)+ W_2^2(\mu,\mu_b)

    where :math:`\mu` is a uniformly weighted empirical distribution with
    ``rank`` support points, :math:`\mu_a` and :math:`\mu_b` are the empirical
    measures associated with the samples in the source and target domains, and
    :math:`W_2` is the Wasserstein distance. This problem is solved using exact
    OT solvers for ``reg=None`` and the Sinkhorn solver for ``reg!=None``. The
    solution provides two transport plans that can be used to recover a
    low-rank OT plan between the two distributions.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='factored', rank=10)

        # recover the lazy low rank plan
        factored_solution_lazy = res.lazy_plan

        # recover the full low rank plan
        factored_solution = factored_solution_lazy[:]

    - **Gaussian Bures-Wasserstein** (when ``method='gaussian'``):

    This method computes the Gaussian Bures-Wasserstein distance between two
    Gaussian distributions estimated from the empirical distributions

    .. math::
        \mathcal{W}_2^2(\mu_s, \mu_t) = \left\lVert \mathbf{m}_s - \mathbf{m}_t \right\rVert^2 + \mathcal{B}(\Sigma_s, \Sigma_t)^{2}

    where:

    .. math::
        \mathcal{B}(\Sigma_s, \Sigma_t)^{2} = \text{Tr}\left(\Sigma_s + \Sigma_t - 2 \sqrt{\Sigma_s^{1/2}\Sigma_t\Sigma_s^{1/2}} \right)

    The covariances and means are estimated from the data, and only
    ``metric='sqeuclidean'`` is supported.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='gaussian')

        # recover the squared Gaussian Bures-Wasserstein distance
        BW_dist = res.value

    - **Wasserstein 1d [1]** (when ``method='1D'``):

    This method computes the Wasserstein distance between two 1d distributions
    estimated from the empirical distributions. For multivariate data the
    distances are computed independently for each dimension. The exponent is
    set from ``metric``: :math:`p=2` for ``'sqeuclidean'`` and :math:`p=1` for
    ``'euclidean'`` or ``'cityblock'``; other metrics raise ``NotImplementedError``.

    .. code-block:: python

        res = ot.solve_sample(xa, xb, method='1D')

        # recover the squared Wasserstein distances (metric='sqeuclidean')
        W_dists = res.value

    .. _references-solve-sample:

    References
    ----------

    .. [1] Bonneel, N., Van De Panne, M., Paris, S., & Heidrich, W. (2011,
        December). Displacement interpolation using Lagrangian mass transport.
        In ACM Transactions on Graphics (TOG) (Vol. 30, No. 6, p. 158). ACM.

    .. [2] Cuturi, M. (2013). Sinkhorn Distances: Lightspeed Computation of
        Optimal Transport. Advances in Neural Information Processing Systems
        (NIPS) 26.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [17] Blondel, M., Seguy, V., & Rolet, A. (2018). Smooth and Sparse
        Optimal Transport.
        Proceedings of the Twenty-First International Conference on
        Artificial Intelligence and Statistics (AISTATS).

    .. [34] Feydy, J., Séjourné, T., Vialard, F. X., Amari, S. I., Trouvé, A.,
        & Peyré, G. (2019, April). Interpolating between optimal transport and
        MMD using Sinkhorn divergences. In The 22nd International Conference
        on Artificial Intelligence and Statistics (pp. 2681-2690). PMLR.

    .. [40] Forrow, A., Hütter, J. C., Nitzan, M., Rigollet, P., Schiebinger,
        G., & Weed, J. (2019, April). Statistical optimal transport via factored
        couplings. In The 22nd International Conference on Artificial
        Intelligence and Statistics (pp. 2454-2465). PMLR.

    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., & Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized linear
        regression. NeurIPS.

    .. [65] Scetbon, M., Cuturi, M., & Peyré, G. (2021). Low-rank Sinkhorn
        Factorization. In International Conference on Machine Learning.

    """

    if method is not None and method.lower() in lst_method_lazy:
        lazy0 = lazy
        lazy = True

    if not lazy:  # default non lazy solver calls ot.solve

        # compute cost matrix M and use solve function
        M = dist(X_a, X_b, metric)

        res = solve(M, a, b, reg, reg_type, unbalanced, unbalanced_type, method, n_threads, max_iter, plan_init, potentials_init, tol, verbose)

        return res

    else:

        # Detect backend
        nx = get_backend(X_a, X_b, a, b)

        # default values for solutions
        potentials = None
        value = None
        value_linear = None
        plan = None
        lazy_plan = None
        status = None
        log = None

        method = method.lower() if method is not None else ''

        if method == '1d':  # Wasserstein 1d (parallel on all dimensions)
            if metric == 'sqeuclidean':
                p = 2
            elif metric in ['euclidean', 'cityblock']:
                p = 1
            else:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            value = wasserstein_1d(X_a, X_b, a, b, p=p)
            value_linear = value

        elif method == 'gaussian':  # Gaussian Bures-Wasserstein

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            if reg is None:
                reg = 1e-6

            value, log = empirical_bures_wasserstein_distance(X_a, X_b, reg=reg, log=True)
            value = value**2  # return the value (squared bures distance)
            value_linear = value  # return the value

        elif method == 'factored':  # Factored OT

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            if max_iter is None:
                max_iter = 100
            if tol is None:
                tol = 1e-7
            if reg is None:
                reg = 0

            Q, R, X, log = factored_optimal_transport(X_a, X_b, reg=reg, r=rank, log=True, stopThr=tol, numItermax=max_iter, verbose=verbose)
            log['X'] = X

            value_linear = log['costa'] + log['costb']
            value = value_linear  # TODO add reg term
            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

        elif method == "lowrank":

            if not metric.lower() in ['sqeuclidean']:
                raise (NotImplementedError('Not implemented metric="{}"'.format(metric)))

            if max_iter is None:
                max_iter = 2000
            if tol is None:
                tol = 1e-7
            if reg is None:
                reg = 0

            Q, R, g, log = lowrank_sinkhorn(X_a, X_b, rank=rank, reg=reg, a=a, b=b, numItermax=max_iter, stopThr=tol, log=True)
            value = log['value']
            value_linear = log['value_linear']
            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

        elif method.startswith('geomloss'):  # Geomloss solver for entropic OT

            split_method = method.split('_')
            if len(split_method) == 2:
                backend = split_method[1]
            else:
                if lazy0 is None:
                    backend = 'auto'
                elif lazy0:
                    backend = 'online'
                else:
                    backend = 'tensorized'

            value, log = empirical_sinkhorn2_geomloss(X_a, X_b, reg=reg, a=a, b=b, metric=metric, log=True, verbose=verbose, scaling=scaling, backend=backend)

            lazy_plan = log['lazy_plan']
            if not lazy0:  # store plan if not lazy
                plan = lazy_plan[:]

            # return scaled potentials (to be consistent with other solvers)
            potentials = (log['f'] / (lazy_plan.blur**2), log['g'] / (lazy_plan.blur**2))

        elif reg is None or reg == 0:  # exact OT

            if unbalanced is None:  # balanced EMD solver not available for lazy
                raise (NotImplementedError('Exact OT solver with lazy=True not implemented'))

            else:
                raise (NotImplementedError('Non regularized solver with unbalanced_type="{}" not implemented'.format(unbalanced_type)))

        else:
            if unbalanced is None:

                if max_iter is None:
                    max_iter = 1000
                if tol is None:
                    tol = 1e-9
                if batch_size is None:
                    batch_size = 100

                value_linear, log = empirical_sinkhorn2(X_a, X_b, reg, a, b, metric=metric, numIterMax=max_iter, stopThr=tol,
                                                        isLazy=True, batchSize=batch_size, verbose=verbose, log=True)
                # compute potentials
                potentials = (log["u"], log["v"])
                lazy_plan = log['lazy_plan']

            else:
                raise (NotImplementedError('Not implemented unbalanced_type="{}" with regularization'.format(unbalanced_type)))

        res = OTResult(potentials=potentials, value=value, lazy_plan=lazy_plan,
                       value_linear=value_linear, plan=plan, status=status, backend=nx, log=log)

        return res
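For uniform weights and equal sample sizes, the classical exact problem in `solve_sample` (the default, non-lazy path, which builds `M = dist(X_a, X_b, metric)` and hands it to `ot.solve`) always admits an optimal plan that is a permutation matrix scaled by 1/n. The brute-force sketch below makes that concrete for a handful of points with the squared Euclidean cost. `exact_ot_uniform_bruteforce` is a hypothetical helper written for intuition only, not POT's exact solver, and its O(n!) enumeration rules out anything beyond tiny n.

```python
import itertools

import numpy as np


def exact_ot_uniform_bruteforce(xa, xb):
    """Exact OT between two uniformly weighted point clouds of equal size n.

    Hypothetical illustration: with weights 1/n on both sides, some optimal
    plan is a permutation matrix divided by n, so we enumerate all n!
    permutations. Only usable for very small n.
    """
    xa = np.asarray(xa, dtype=float)
    xb = np.asarray(xb, dtype=float)
    n = xa.shape[0]
    if xb.shape[0] != n:
        raise ValueError("this sketch assumes the same number of samples")
    # Squared Euclidean cost M_ij = ||x_i - y_j||^2
    M = ((xa[:, None, :] - xb[None, :, :]) ** 2).sum(axis=-1)
    rows = np.arange(n)
    best_perm, best_cost = None, np.inf
    for perm in itertools.permutations(range(n)):
        cost = M[rows, list(perm)].sum() / n  # <T, M>_F for this permutation plan
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    plan = np.zeros((n, n))
    plan[rows, list(best_perm)] = 1.0 / n
    return plan, float(best_cost)
```

Translating a cloud by a fixed vector `v` gives the identity matching with cost `||v||^2`, which is an easy sanity check on the sketch.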
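The lazy entropic branch of `solve_sample` delegates to `empirical_sinkhorn2` with `isLazy=True` and `batchSize=batch_size`, so the kernel is never stored in full. The scaling iterations of [2] behind it are easier to follow on a small dense cost matrix. `sinkhorn_dense` below is a hypothetical helper for illustration (not POT's implementation): it stores the whole Gibbs kernel `K = exp(-M / reg)`, so it uses O(nm) memory, and it stops when the row-marginal violation drops below `tol`.

```python
import numpy as np


def sinkhorn_dense(a, b, M, reg, n_iter=1000, tol=1e-9):
    """Classical Sinkhorn scaling iterations on a dense cost matrix.

    Returns the entropic plan T = diag(u) K diag(v) with K = exp(-M / reg)
    and its linear cost <T, M>_F.
    """
    a = np.asarray(a, dtype=float)
    b = np.asarray(b, dtype=float)
    M = np.asarray(M, dtype=float)
    K = np.exp(-M / reg)  # Gibbs kernel, stored densely
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)    # rescale rows toward marginal a
        v = b / (K.T @ u)  # rescale columns; column marginals are exact after this
        # the remaining violation is on the row marginals
        if np.abs(u * (K @ v) - a).sum() < tol:
            break
    T = u[:, None] * K * v[None, :]
    return T, float(np.sum(T * M))
```

For small `reg`, `np.exp(-M / reg)` underflows to zero; a log-domain variant such as `sinkhorn_log` (imported at the top of this module) is the usual remedy.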
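The `method='gaussian'` branch fits one Gaussian per sample set, calls `empirical_bures_wasserstein_distance` with `reg` defaulting to `1e-6`, and squares the result. The closed form given in the docstring can also be evaluated directly with NumPy. The sketch below does so under stated assumptions: the helper names `_sqrtm_psd` and `gaussian_bw2` are hypothetical, covariances use the biased 1/n estimator, and `reg` is assumed to be added to the covariance diagonals (POT's `empirical_bures_wasserstein_distance` is the reference for its exact conventions).

```python
import numpy as np


def _sqrtm_psd(S):
    """Square root of a symmetric PSD matrix via an eigendecomposition."""
    w, V = np.linalg.eigh(S)
    w = np.clip(w, 0.0, None)  # clip tiny negative eigenvalues caused by round-off
    return (V * np.sqrt(w)) @ V.T


def gaussian_bw2(xs, xt, reg=1e-6):
    """Squared Bures-Wasserstein distance between Gaussians fitted to xs and xt.

    Sketch assumptions: biased (1/n) covariance estimates and `reg` added to
    the covariance diagonals.
    """
    xs = np.asarray(xs, dtype=float)
    xt = np.asarray(xt, dtype=float)
    d = xs.shape[1]
    ms, mt = xs.mean(axis=0), xt.mean(axis=0)
    # reshape keeps the 1-d case as a (1, 1) matrix
    Cs = np.cov(xs, rowvar=False, bias=True).reshape(d, d) + reg * np.eye(d)
    Ct = np.cov(xt, rowvar=False, bias=True).reshape(d, d) + reg * np.eye(d)
    Cs_half = _sqrtm_psd(Cs)
    cross = _sqrtm_psd(Cs_half @ Ct @ Cs_half)
    bures2 = np.trace(Cs + Ct - 2.0 * cross)  # B(Sigma_s, Sigma_t)^2
    return float(np.sum((ms - mt) ** 2) + bures2)
```

In one dimension the Bures term reduces to the squared difference of standard deviations, and for a pure translation it vanishes, leaving only the squared distance between the means.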
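The `method='1d'` branch maps `metric` to an exponent `p` (2 for `'sqeuclidean'`, 1 for `'euclidean'` or `'cityblock'`) and calls `wasserstein_1d` on every dimension at once, passing the weights `a` and `b`. In one dimension the optimal coupling is monotone, so with uniform weights and equal sample sizes the value reduces to matching sorted samples. `wasserstein_1d_uniform` is a hypothetical helper that covers only that special case; it returns the p-th power of the distance per column, matching the docstring's reading of `res.value` as squared distances when `p=2`.

```python
import numpy as np


def wasserstein_1d_uniform(xa, xb, p=2):
    """Per-dimension W_p^p between two uniformly weighted samples of equal size.

    The monotone coupling matches the i-th smallest source value with the
    i-th smallest target value, independently in each column.
    """
    xa = np.sort(np.asarray(xa, dtype=float), axis=0)  # sort each column independently
    xb = np.sort(np.asarray(xb, dtype=float), axis=0)
    if xa.shape != xb.shape:
        raise ValueError("this sketch assumes equal sample sizes and dimensions")
    return np.mean(np.abs(xa - xb) ** p, axis=0)
```

Calling it with `p=2` mirrors `metric='sqeuclidean'`, and `p=1` mirrors `'euclidean'` or `'cityblock'`, following the mapping in the `'1d'` branch.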