Source code for ot.unbalanced

# -*- coding: utf-8 -*-
"""
Regularized Unbalanced OT solvers
"""

# Author: Hicham Janati <hicham.janati@inria.fr>
#         Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License

from __future__ import division
import warnings

import numpy as np
from scipy.optimize import minimize, Bounds

from .backend import get_backend
from .utils import list_to_array, get_parameter_pair


def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn',
                        reg_type="entropy", warmstart=None, numItermax=1000,
                        stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the unbalanced entropic regularization optimal transport problem
    and return the OT plan

    This is a dispatcher: it selects one of the unbalanced Sinkhorn solvers
    from `method` and forwards every argument to it.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL
      divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10, 25] <references-sinkhorn-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances
        :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using
        `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays
        (a, b, M).
    method : str
        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized'
        or 'sinkhorn_reg_scaling' (case-insensitive), see those functions
        for specific parameters. 'sinkhorn_reg_scaling' is not implemented:
        a warning is emitted and the classic Sinkhorn-Knopp solver is used.
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})
        - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials
        should be given (that is the logarithm of the u,v sinkhorn scaling
        vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver.


    Returns
    -------
    if n_hists == 1:
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`
    else:
        - ot_distance : (n_hists,) array-like
            the OT distance between :math:`\mathbf{a}` and each of the
            histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])


    .. _references-sinkhorn-unbalanced:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint
        arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015


    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced:
        Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
        Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>`

    """
    solver_name = method.lower()

    if solver_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized_unbalanced
    elif solver_name in ('sinkhorn', 'sinkhorn_reg_scaling'):
        if solver_name == 'sinkhorn_reg_scaling':
            # No dedicated epsilon-scaling solver: warn, then fall back.
            warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
        solver = sinkhorn_knopp_unbalanced
    else:
        raise ValueError("Unknown method '%s'." % method)

    return solver(a, b, M, reg, reg_m, reg_type, warmstart,
                  numItermax=numItermax, stopThr=stopThr,
                  verbose=verbose, log=log, **kwargs)
def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
                         reg_type="entropy", warmstart=None, numItermax=1000,
                         stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem
    and return the loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL
      divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10, 25] <references-sinkhorn-unbalanced2>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances
        :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using
        `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays
        (a, b, M).
    method : str
        method used for the solver either 'sinkhorn', 'sinkhorn_stabilized'
        or 'sinkhorn_reg_scaling' (case-insensitive), see those functions
        for specific parameters. 'sinkhorn_reg_scaling' is not implemented:
        a warning is emitted and the classic Sinkhorn-Knopp solver is used.
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})
        - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials
        should be given (that is the logarithm of the u,v sinkhorn scaling
        vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver.


    Returns
    -------
    ot_distance : (n_hists,) array-like
        the OT distance between :math:`\mathbf{a}` and each of the
        histograms :math:`\mathbf{b}_i`. When `b` is one-dimensional this
        is the single value :math:`\sum_{i,j} M_{i,j} \gamma_{i,j}`
        computed from the plan returned by the solver.
    log : dict
        log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .10]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8)
    0.31912858


    .. _references-sinkhorn-unbalanced2:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint
        arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_reg_scaling_unbalanced: Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced2>`

    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # A 1-D target means the solver returns a plan that still has to be
    # contracted with M; a 2-D target makes the solver return the losses.
    single_target = len(b.shape) < 2

    solver_name = method.lower()
    if solver_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized_unbalanced
    elif solver_name in ('sinkhorn', 'sinkhorn_reg_scaling'):
        if solver_name == 'sinkhorn_reg_scaling':
            warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
        solver = sinkhorn_knopp_unbalanced
    else:
        raise ValueError('Unknown method %s.' % method)

    res = solver(a, b, M, reg, reg_m, reg_type, warmstart,
                 numItermax=numItermax, stopThr=stopThr,
                 verbose=verbose, log=log, **kwargs)

    if not single_target:
        return res

    if log:
        plan, solver_log = res[0], res[1]
        return nx.sum(M * plan), solver_log
    return nx.sum(M * res)
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
                              warmstart=None, numItermax=1000, stopThr=1e-6,
                              verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem
    and return the OT plan

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL
      divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`. If empty, a uniform
        histogram is used.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`
        If many, compute all the OT distances (a, b_i). If empty, a
        uniform histogram is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using
        `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays
        (a, b, M).
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})
        - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials
        should be given (that is the logarithm of the u,v sinkhorn scaling
        vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for signature compatibility with the other solvers; not
        used by this function.


    Returns
    -------
    if `b` is one-dimensional:
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`; it holds the
            error history under 'err' and the log-scalings under 'logu'
            and 'logv'
    else:
        - ot_distance : (n_hists,) array-like
            the transport cost :math:`\langle \gamma_i, \mathbf{M} \rangle`
            between :math:`\mathbf{a}` and each of the histograms
            :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `reg_type` is neither 'entropy' nor 'kl'.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])


    .. _references-sinkhorn-knopp-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty marginals default to uniform histograms.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': []}

    if n_hists:
        # `a` must be a column to broadcast against (dim_a, n_hists) arrays
        # in the updates below; this holds whether or not a warmstart is
        # supplied.
        a = a.reshape(dim_a, 1)

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, 1), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        # warmstart holds dual potentials, i.e. the logs of the scalings
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type == "kl":
        # NOTE(review): with several histograms, b.reshape(-1) has
        # dim_b * n_hists entries and does not match M; 'kl' looks usable
        # only with a single target histogram - confirm.
        K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :]
    elif reg_type == "entropy":
        K = nx.exp(-M / reg)
    else:
        raise ValueError(
            "Unknown reg_type '%s'. Expected 'entropy' or 'kl'." % reg_type
        )

    # Exponents of the relaxed scaling updates; 1 gives the balanced update.
    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    err = 1.

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        u = (a / Kv) ** fi_1
        Ktu = nx.dot(K.T, u)
        v = (b / Ktu) ** fi_2

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            u = uprev
            v = vprev
            break

        # Relative change of the scalings, normalised by at least 1.
        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
        )
        err = 0.5 * (err_u + err_v)
        if log:
            log['err'].append(err)
        if verbose:
            if i % 50 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, err))
        if err < stopThr:
            break

    if log:
        # small offset keeps the logarithm finite for zero scalings
        log['logu'] = nx.log(u + 1e-300)
        log['logv'] = nx.log(v + 1e-300)

    if n_hists:  # return only loss
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix

        if log:
            return u[:, None] * K * v[None, :], log
        else:
            return u[:, None] * K * v[None, :]
def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
                                   warmstart=None, tau=1e5, numItermax=1000,
                                   stopThr=1e-6, verbose=False, log=False,
                                   **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport
    problem and return the OT plan (single target histogram) or the
    transport costs (several target histograms)

    The function solves the following optimization problem using log-domain
    stabilization as proposed in
    :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL
      divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>`


    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`. If empty, a uniform
        histogram is used.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances
        :math:`(\mathbf{a}, \mathbf{b}_i)_i`. If empty, a uniform
        histogram is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using
        `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays
        (a, b, M).
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j})
        - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
        Any value other than 'kl' is handled as 'entropy'.
    warmstart: tuple of arrays, shape (dim_a, dim_b), optional
        Initialization of dual potentials. If provided, the dual potentials
        should be given (that is the logarithm of the u,v sinkhorn scaling
        vectors).
    tau : float
        threshold for max value in u or v for log scaling
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for signature compatibility with the other solvers; not
        used by this function.


    Returns
    -------
    if `b` is one-dimensional:
        - gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
        - log : dict
            log dictionary returned only if `log` is `True`; it holds the
            error history under 'err' and the log-scalings under 'logu'
            and 'logv'
    else:
        - ot_distance : (n_hists,) array-like
            the OT distance between :math:`\mathbf{a}` and each of the
            histograms :math:`\mathbf{b}_i`
        - log : dict
            log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])

    .. _references-sinkhorn-stabilized-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty marginals default to uniform histograms.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': []}

    if n_hists:
        # `a` must be a column to broadcast against (dim_a, n_hists) arrays
        # in the updates below; this holds whether or not a warmstart is
        # supplied.
        a = a.reshape(dim_a, 1)

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        # warmstart holds dual potentials, i.e. the logs of the scalings
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type == "kl":
        # Fold the reference measure a b^T into the cost so that a single
        # kernel exp(-M0 / reg) serves both regularizers.
        log_ab = (nx.log(a + 1e-16).reshape(-1)[:, None]
                  + nx.log(b + 1e-16).reshape(-1)[None, :])
        M0 = M - reg * log_ab
    else:
        M0 = M

    K = nx.exp(-M0 / reg)

    # Exponents of the relaxed scaling updates; 1 gives the balanced update.
    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    cpt = 0
    err = 1.
    # alpha / beta accumulate the part of the potentials absorbed into K.
    alpha = nx.zeros(dim_a, type_as=M)
    beta = nx.zeros(dim_b, type_as=M)
    ones_a = nx.ones(dim_a, type_as=M)
    ones_b = nx.ones(dim_b, type_as=M)

    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        f_alpha = nx.exp(- alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a
        f_beta = nx.exp(- beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b

        if n_hists:
            f_alpha = f_alpha[:, None]
            f_beta = f_beta[:, None]
        u = ((a / (Kv + 1e-16)) ** fi_1) * f_alpha
        Ktu = nx.dot(K.T, u)
        v = ((b / (Ktu + 1e-16)) ** fi_2) * f_beta
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # A scaling exceeded tau: move it into alpha / beta, rebuild the
            # kernel from the absorbed potentials and reset v.
            absorbing = True
            if n_hists:
                alpha = alpha + reg * nx.log(nx.max(u, 1))
                beta = beta + reg * nx.log(nx.max(v, 1))
            else:
                alpha = alpha + reg * nx.log(nx.max(u))
                beta = beta + reg * nx.log(nx.max(v))
            K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg)
            v = nx.ones(v.shape, type_as=v)

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % cpt)
            u = uprev
            v = vprev
            break
        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = nx.max(nx.abs(u - uprev)) / max(
                nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
            )
            if log:
                log['err'].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt = cpt + 1

    if err > stopThr:
        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge. " +
                      "Try a larger entropy `reg` or a lower mass `reg_m`. " +
                      "Or a larger absorption threshold `tau`.")

    # Full log-scalings: absorbed potentials plus the residual scalings.
    if n_hists:
        logu = alpha[:, None] / reg + nx.log(u)
        logv = beta[:, None] / reg + nx.log(v)
    else:
        logu = alpha / reg + nx.log(u)
        logv = beta / reg + nx.log(v)
    if log:
        log['logu'] = logu
        log['logv'] = logv
    if n_hists:  # return only loss
        # sum_ij M_ij * gamma_ij for every histogram, evaluated in log domain
        res = nx.logsumexp(
            nx.log(M + 1e-100)[:, :, None]
            + logu[:, None, :]
            + logv[None, :, :]
            - M0[:, :, None] / reg,
            axis=(0, 1)
        )
        res = nx.exp(res)
        if log:
            return res, log
        else:
            return res

    else:  # return OT matrix
        ot_matrix = nx.exp(logu[:, None] + logv[None, :] - M0 / reg)
        if log:
            return ot_matrix, log
        else:
            return ot_matrix
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
                                     numItermax=1000, stopThr=1e-6,
                                     verbose=False, log=False):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized
      Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10] <references-barycenter-unbalanced-stabilized>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of
        dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m : float
        Marginal relaxation term > 0
    tau : float
        Stabilization threshold for log domain absorption.
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters; it holds
        the error history under 'err', the last iteration index under
        'niter' and the log-scalings under 'logu' and 'logv'


    .. _references-barycenter-unbalanced-stabilized:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    # Exponent of the relaxed scaling updates.
    fi = reg_m / (reg_m + reg)

    u = nx.ones((dim, n_hists), type_as=A) / dim
    v = nx.ones((dim, n_hists), type_as=A) / dim

    K = nx.exp(-M / reg)

    err = 1.
    # alpha / beta accumulate the part of the potentials absorbed into K.
    alpha = nx.zeros(dim, type_as=A)
    beta = nx.zeros(dim, type_as=A)
    q = nx.ones(dim, type_as=A) / dim
    for i in range(numItermax):
        qprev = nx.copy(q)
        Kv = nx.dot(K, v)
        f_alpha = nx.exp(- alpha / (reg + reg_m))
        f_beta = nx.exp(- beta / (reg + reg_m))
        f_alpha = f_alpha[:, None]
        f_beta = f_beta[:, None]
        u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
        Ktu = nx.dot(K.T, u)
        # Barycenter update: weighted power mean of the columns of K^T u.
        q = (Ktu ** (1 - fi)) * f_beta
        q = nx.dot(q, weights) ** (1 / (1 - fi))
        Q = q[:, None]
        v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta
        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # A scaling exceeded tau: move it into alpha / beta, rebuild the
            # kernel from the absorbed potentials and reset v.
            absorbing = True
            alpha = alpha + reg * nx.log(nx.max(u, 1))
            beta = beta + reg * nx.log(nx.max(v, 1))
            K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
            v = nx.ones(v.shape, type_as=v)
        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            q = qprev
            break
        if (i % 10 == 0 and not absorbing) or i == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = nx.max(nx.abs(q - qprev)) / max(
                nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
            )
            if log:
                log['err'].append(err)
            if verbose:
                if i % 50 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(i, err))
            if err < stopThr:
                break

    if err > stopThr:
        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge. " +
                      "Try a larger entropy `reg` or a lower mass `reg_m`. " +
                      "Or a larger absorption threshold `tau`.")
    if log:
        log['niter'] = i
        # small offset keeps the logarithm finite for zero scalings
        log['logu'] = nx.log(u + 1e-300)
        log['logv'] = nx.log(v + 1e-300)
        return q, log
    else:
        return q
def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
                                   numItermax=1000, stopThr=1e-6,
                                   verbose=False, log=False):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.

    The function solves the following optimization problem with :math:`\mathbf{a}`

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized
      Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10] <references-barycenter-unbalanced-sinkhorn>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of
        dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m: float
        Marginal relaxation term > 0
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters; it holds
        the error history under 'err', the last iteration index under
        'niter' and the log-scalings under 'logu' and 'logv'


    .. _references-barycenter-unbalanced-sinkhorn:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is not None:
        assert (len(weights) == A.shape[1])
    else:
        weights = nx.ones(n_hists, type_as=A) / n_hists

    if log:
        log = {'err': []}

    K = nx.exp(-M / reg)

    # Exponent of the relaxed scaling updates.
    fi = reg_m / (reg_m + reg)

    v = nx.ones((dim, n_hists), type_as=A)
    u = nx.ones((dim, 1), type_as=A)
    q = nx.ones(dim, type_as=A)
    err = 1.

    for i in range(numItermax):
        # Snapshots used to roll back if this iteration degenerates.
        u_old, v_old, q_old = nx.copy(u), nx.copy(v), nx.copy(q)

        u = (A / nx.dot(K, v)) ** fi
        Ktu = nx.dot(K.T, u)
        # Barycenter update: weighted power mean of the columns of K^T u.
        q = nx.dot(Ktu ** (1 - fi), weights) ** (1 / (1 - fi))
        v = (q[:, None] / Ktu) ** fi

        degenerate = (
            nx.any(Ktu == 0.)
            or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))
        )
        if degenerate:
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            u, v, q = u_old, v_old, q_old
            break

        # compute change in barycenter
        scale = max(nx.max(nx.abs(q)), nx.max(nx.abs(q_old)), 1.0)
        err = nx.max(nx.abs(q - q_old)) / scale
        if log:
            log['err'].append(err)

        # if barycenter did not change + at least 10 iterations - stop
        if err < stopThr and i > 10:
            break

        if verbose:
            if i % 10 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, err))

    if not log:
        return q

    log['niter'] = i
    # small offset keeps the logarithm finite for zero scalings
    log['logu'] = nx.log(u + 1e-300)
    log['logv'] = nx.log(v + 1e-300)
    return q, log
def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
                          numItermax=1000, stopThr=1e-6,
                          verbose=False, log=False, **kwargs):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.

    This is a dispatcher: it selects one of the unbalanced barycenter
    solvers from `method` and forwards every argument to it.

    The function solves the following optimization problem with :math:`\mathbf{a}`

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized
      Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of
      matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term
      and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in
    :ref:`[10] <references-barycenter-unbalanced>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of
        dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m: float
        Marginal relaxation term > 0
    method : str, optional
        Solver to use, either 'sinkhorn', 'sinkhorn_stabilized' or
        'sinkhorn_reg_scaling' (case-insensitive).
        'sinkhorn_reg_scaling' is not implemented: a warning is emitted
        and the classic Sinkhorn solver is used.
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded unchanged to the selected solver
        (for instance `tau` for 'sinkhorn_stabilized').


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.


    .. _references-barycenter-unbalanced:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    solver_name = method.lower()

    if solver_name == 'sinkhorn':
        return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
                                              weights=weights,
                                              numItermax=numItermax,
                                              stopThr=stopThr, verbose=verbose,
                                              log=log, **kwargs)

    elif solver_name == 'sinkhorn_stabilized':
        return barycenter_unbalanced_stabilized(A, M, reg, reg_m,
                                                weights=weights,
                                                numItermax=numItermax,
                                                stopThr=stopThr,
                                                verbose=verbose,
                                                log=log, **kwargs)
    elif solver_name in ['sinkhorn_reg_scaling']:
        warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
        # Call the classic solver directly instead of re-entering this
        # dispatcher and relying on its default `method`.
        return barycenter_unbalanced_sinkhorn(A, M, reg, reg_m,
                                              weights=weights,
                                              numItermax=numItermax,
                                              stopThr=stopThr, verbose=verbose,
                                              log=log, **kwargs)
    else:
        raise ValueError("Unknown method '%s'." % method)
def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000,
                  stopThr=1e-15, verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT plan.
    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence

    The algorithm used for solving the problem is a maximization-
    minimization algorithm as proposed in :ref:`[41] <references-regpath>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`.
        If empty, a uniform histogram is used.
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`.
        If empty, a uniform histogram is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg : float, optional (default = 0)
        Regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic).
        Any other value triggers a warning and is replaced by 'kl'.
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix.
        If None, :math:`\mathbf{a} \mathbf{b}^T` is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`. It holds
        ``'err'`` (per-iteration change of the plan), ``'G'`` (the plan after
        every iteration) and ``'cost'`` (:math:`\langle \gamma, \mathbf{M} \rangle_F`
        for the returned plan).

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2)
    array([[0.45, 0.  ],
           [0.  , 0.34]])
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2)
    array([[0.4, 0. ],
           [0. , 0.1]])


    .. _references-regpath:
    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.
    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty histograms are replaced by uniform ones.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # The product measure serves both as default start and default reference.
    G = a[:, None] * b[None, :] if G0 is None else G0
    c = a[:, None] * b[None, :] if c is None else c

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': [], 'G': []}

    if div not in ["kl", "l2"]:
        # Adjacent literals instead of a backslash continuation, which would
        # embed the source indentation into the message.
        warnings.warn("The div parameter should be either equal to 'kl' or "
                      "'l2': it has been set to 'kl'.")
        div = 'kl'

    # Iteration-invariant factor of the multiplicative update.
    if div == 'kl':
        sum_r = reg + reg_m1 + reg_m2
        r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r
        K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r)
    elif div == 'l2':
        K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M
        K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M))

    for i in range(numItermax):
        Gprev = G

        # The 1e-16 offsets keep the denominators away from zero.
        if div == 'kl':
            Gd = (nx.sum(G, 1, keepdims=True)**r1) * (nx.sum(G, 0, keepdims=True)**r2) + 1e-16
            G = K * G**(r1 + r2) / Gd
        elif div == 'l2':
            Gd = reg_m1 * nx.sum(G, 1, keepdims=True) + \
                reg_m2 * nx.sum(G, 0, keepdims=True) + reg * G + 1e-16
            G = K * G / Gd

        # Frobenius norm of the change between consecutive plans.
        err = nx.sqrt(nx.sum((G - Gprev) ** 2))
        if log:
            log['err'].append(err)
            log['G'].append(G)
        if verbose:
            print('{:5d}|{:8e}|'.format(i, err))
        if err < stopThr:
            break

    if log:
        log['cost'] = nx.sum(G * M)
        return G, log
    else:
        return G
def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000,
                   stopThr=1e-15, verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT cost.
    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence

    The algorithm used for solving the problem is a maximization-
    minimization algorithm as proposed in :ref:`[41] <references-regpath>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg : float, optional (default = 0)
        Entropy regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    ot_distance : array-like
        the transport cost :math:`\langle \gamma, \mathbf{M} \rangle_F` of the
        plan computed by :py:func:`ot.unbalanced.mm_unbalanced` between
        :math:`\mathbf{a}` and :math:`\mathbf{b}`
    log : dict
        log dictionary returned only if `log` is `True`

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2)
    0.8
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2)
    1.79

    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.
    See Also
    --------
    ot.lp.emd2 : Unregularized OT loss
    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
    """
    if log:
        _, log_mm = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0,
                                  numItermax=numItermax,
                                  stopThr=stopThr,
                                  verbose=verbose,
                                  log=True)
        return log_mm['cost'], log_mm

    # No log requested: do not ask the solver for one, since its log keeps the
    # plan of every iteration. The cost is recomputed from the final plan with
    # the same conversion and backend lookup the solver applies to its inputs.
    G = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0,
                      numItermax=numItermax,
                      stopThr=stopThr,
                      verbose=verbose,
                      log=False)
    M_arr, a_arr, b_arr = list_to_array(M, a, b)
    nx = get_backend(M_arr, a_arr, b_arr)
    return nx.sum(G * M_arr)
def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'):
    """
    Return the loss function (scipy.optimize compatible) for regularized
    unbalanced OT.

    The returned callable takes the flattened plan ``G`` (reshaped internally
    to ``M.shape``) and returns ``(value, gradient.ravel())`` of

        sum(G * M) + reg * reg_fun(G) + regm_fun(G)

    Parameters
    ----------
    a, b : numpy arrays
        Marginals compared against ``G.sum(1)`` and ``G.sum(0)``.
    c : numpy array, same shape as M
        Reference measure used by the 'kl' and 'l2' regularizers.
    M : numpy array (m, n)
        Cost matrix.
    reg : float
        Weight of the regularization term. When it is 0 the regularizer is
        not evaluated at all, so zero entries in `c` cannot turn the loss
        into NaN through ``0 * nan``.
    reg_m1, reg_m2 : float
        Weights of the row and column marginal penalties.
    reg_div : str or tuple
        'kl', 'entropy', or a tuple ``(function, gradient)`` of callables;
        any other value selects the quadratic (l2) regularizer.
    regm_div : str
        'kl' or 'tv'; any other value selects the quadratic (l2) penalty.
    """
    m, n = M.shape

    def kl(p, q):
        return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q)

    # --- regularizers on the plan and their gradients ---
    def reg_l2(G):
        return np.sum((G - c)**2) / 2

    def grad_l2(G):
        return G - c

    def reg_kl(G):
        return kl(G, c)

    def grad_kl(G):
        return np.log(G / c + 1e-16)

    def reg_entropy(G):
        return np.sum(G * np.log(G + 1e-16)) - np.sum(G)

    def grad_entropy(G):
        return np.log(G + 1e-16)

    if reg_div == 'kl':
        reg_fun = reg_kl
        grad_reg_fun = grad_kl
    elif reg_div == 'entropy':
        reg_fun = reg_entropy
        grad_reg_fun = grad_entropy
    elif isinstance(reg_div, tuple):
        reg_fun = reg_div[0]
        grad_reg_fun = reg_div[1]
    else:
        reg_fun = reg_l2
        grad_reg_fun = grad_l2

    # --- penalties on the marginals and their gradients ---
    def marg_l2(G):
        return reg_m1 * 0.5 * np.sum((G.sum(1) - a)**2) + \
            reg_m2 * 0.5 * np.sum((G.sum(0) - b)**2)

    def grad_marg_l2(G):
        return reg_m1 * np.outer((G.sum(1) - a), np.ones(n)) + \
            reg_m2 * np.outer(np.ones(m), (G.sum(0) - b))

    def marg_kl(G):
        return reg_m1 * kl(G.sum(1), a) + reg_m2 * kl(G.sum(0), b)

    def grad_marg_kl(G):
        return reg_m1 * np.outer(np.log(G.sum(1) / a + 1e-16), np.ones(n)) + \
            reg_m2 * np.outer(np.ones(m), np.log(G.sum(0) / b + 1e-16))

    def marg_tv(G):
        return reg_m1 * np.sum(np.abs(G.sum(1) - a)) + \
            reg_m2 * np.sum(np.abs(G.sum(0) - b))

    def grad_marg_tv(G):
        return reg_m1 * np.outer(np.sign(G.sum(1) - a), np.ones(n)) + \
            reg_m2 * np.outer(np.ones(m), np.sign(G.sum(0) - b))

    if regm_div == 'kl':
        regm_fun = marg_kl
        grad_regm_fun = grad_marg_kl
    elif regm_div == 'tv':
        regm_fun = marg_tv
        grad_regm_fun = grad_marg_tv
    else:
        regm_fun = marg_l2
        grad_regm_fun = grad_marg_l2

    def _func(G):
        G = G.reshape((m, n))

        # Terms are accumulated in the order cost, regularizer, marginals.
        val = np.sum(G * M)
        grad = M

        # Skip the regularizer when its weight is zero: evaluating it could
        # yield nan/inf (e.g. G / c with zeros in c) and 0 * nan is still nan.
        if reg != 0:
            val = val + reg * reg_fun(G)
            grad = grad + reg * grad_reg_fun(G)

        val = val + regm_fun(G)
        grad = grad + grad_regm_fun(G)

        return val, grad.ravel()

    return _func
def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
                      stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.
    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is the regularization divergence (`reg_div`) and
      :math:`\mathrm{div_m}` the marginal divergence (`regm_div`)

    The algorithm used for solving the problem is a L-BFGS-B from scipy.optimize

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg: float
        regularization term >=0
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use `\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must be a Numpy array.
    reg_div: string, optional
        Divergence used for regularization.
        Can take three values: 'entropy' (negative entropy), or
        'kl' (Kullback-Leibler) or 'l2' (quadratic) or a tuple
        of two callable functions returning the reg term and its derivative.
        Note that the callable functions should be able to handle numpy arrays
        and not tensors from the backend; if calling them on a numpy array
        raises, they are wrapped with backend conversions (with a warning).
    regm_div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take three values: 'kl' (Kullback-Leibler), 'tv' (total variation)
        or 'l2' (quadratic)
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix. If None, a zero matrix is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0), passed as `tol` to the optimizer
    method : str, optional
        Name of the `scipy.optimize.minimize` method to use
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (dim_a, dim_b) array-like
            Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`. It holds ``'loss'``
        (final objective value) and ``'res'`` (the scipy optimization result).

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2)
    array([[0.45, 0.  ],
           [0.  , 0.34]])
    >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2)
    array([[0.4, 0. ],
           [0. , 0.1]])

    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.
    See Also
    --------
    ot.lp.emd2 : Unregularized OT loss
    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)
    # Keep the backend version of M to cast the outputs back at the end.
    M0 = M

    # convert to numpy
    a, b, M = nx.to_numpy(a, b, M)
    G0 = np.zeros(M.shape) if G0 is None else nx.to_numpy(G0)
    c = a[:, None] * b[None, :] if c is None else nx.to_numpy(c)

    # wrap the callable function to handle numpy arrays
    if isinstance(reg_div, tuple):
        f0, df0 = reg_div
        try:
            # Probe the user callables with a numpy array.
            f0(G0)
            df0(G0)
        except Exception:
            # Only ordinary errors trigger the fallback; KeyboardInterrupt and
            # SystemExit propagate instead of being swallowed.
            warnings.warn("The callable functions should be able to handle numpy arrays, wrapper ar added to handle this which comes with overhead")

            def f(x):
                return nx.to_numpy(f0(nx.from_numpy(x, type_as=M0)))

            def df(x):
                return nx.to_numpy(df0(nx.from_numpy(x, type_as=M0)))

            reg_div = (f, df)

    reg_m1, reg_m2 = get_parameter_pair(reg_m)
    _func = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)

    # jac=True: _func returns (value, gradient); the bounds enforce G >= 0.
    res = minimize(_func, G0.ravel(), method=method, jac=True, bounds=Bounds(0, np.inf),
                   tol=stopThr, options=dict(maxiter=numItermax, disp=verbose))

    G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)

    if log:
        log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
        return G, log
    else:
        return G