Source code for ot.unbalanced._mm

# -*- coding: utf-8 -*-
"""
Regularized Unbalanced OT solvers
"""

# Author: Hicham Janati <hicham.janati@inria.fr>
#         Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
#         Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License

from ..backend import get_backend
from ..utils import list_to_array, get_parameter_pair


def mm_unbalanced(
    a,
    b,
    M,
    reg_m,
    c=None,
    reg=0,
    div="kl",
    G0=None,
    numItermax=1000,
    stopThr=1e-15,
    verbose=False,
    log=False,
):
    r"""
    Solve the unbalanced optimal transport problem and return the OT plan.

    The function solves the following optimization problem:

    .. math::
        W = \arg \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or
      half-squared :math:`\ell_2` divergence

    The algorithm used for solving the problem is a majorization-minimization
    (MM) algorithm with multiplicative updates, as proposed in
    :ref:`[41] <references-regpath>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term: nonnegative but cannot be infinity.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of
        length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both
        marginal relaxations.
        If :math:`\mathrm{reg_{m}}` is an array, it must have the same backend
        as input arrays `(a, b, M)`.
    reg : float, optional (default = 0)
        Regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
        Ignored when `reg` is not strictly positive.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (half-squared).
        The comparison is case-insensitive.
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix.
        If None, then use :math:`\mathbf{a} \mathbf{b}^T`.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0). The error is the Frobenius norm of the
        difference between two successive plans.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`. It holds the keys
        ``"err"`` (error at each iteration), ``"G"`` (plan at each iteration),
        ``"cost"`` (linear part of the loss) and ``"total_cost"`` (linear part
        plus the marginal and regularization penalties).

    Raises
    ------
    ValueError
        If `div` is neither 'kl' nor 'l2'.

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2)
    array([[0.45, 0.  ],
           [0.  , 0.34]])
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2)
    array([[0.4, 0. ],
           [0. , 0.1]])


    .. _references-regpath:
    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
    """
    # Validate the divergence before any array work so a typo fails fast.
    div = div.lower()
    if div not in ("kl", "l2"):
        raise ValueError("Unknown div = {}. Must be either 'kl' or 'l2'".format(div))
    use_kl = div == "kl"

    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty histograms stand for the uniform distribution.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # G0 and c are array-like too: convert plain lists the same way a, b and M
    # are converted, otherwise the element-wise powers/products below fail.
    # M is passed along only so that both arguments go through the same call.
    if isinstance(G0, list):
        _, G0 = list_to_array(M, G0)

    G = a[:, None] * b[None, :] if G0 is None else G0

    if reg > 0:  # regularized case
        if c is None:
            c = a[:, None] * b[None, :]
        elif isinstance(c, list):
            _, c = list_to_array(M, c)
    else:  # unregularized case: the reference measure plays no role
        c = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    # Keep the boolean flag and the log container in separate names.
    history = {"err": [], "G": []} if log else None

    if use_kl:
        # NOTE: sum_r is zero when reg and both reg_m values are zero, in which
        # case the divisions below are undefined.
        sum_r = reg + reg_m1 + reg_m2
        r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r
        r12 = r1 + r2  # loop-invariant exponent of the multiplicative update
        K = (a[:, None] ** r1) * (b[None, :] ** r2) * (c**r) * nx.exp(-M / sum_r)
    else:
        K = (reg_m1 * a[:, None]) + (reg_m2 * b[None, :]) + reg * c - M
        K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M))

    for i in range(numItermax):
        Gprev = G

        # The 1e-16 offset keeps the denominator strictly positive.
        if use_kl:
            Gd = (nx.sum(G, 1, keepdims=True) ** r1) * (
                nx.sum(G, 0, keepdims=True) ** r2
            ) + 1e-16
            G = K * G**r12 / Gd
        else:
            Gd = (
                reg_m1 * nx.sum(G, 1, keepdims=True)
                + reg_m2 * nx.sum(G, 0, keepdims=True)
                + reg * G
                + 1e-16
            )
            G = K * G / Gd

        err = nx.sqrt(nx.sum((G - Gprev) ** 2))
        if history is not None:
            history["err"].append(err)
            history["G"].append(G)
        if verbose:
            print("{:5d}|{:8e}|".format(i, err))
        if err < stopThr:
            break

    if history is None:
        return G

    linear_cost = nx.sum(G * M)
    history["cost"] = linear_cost

    m1, m2 = nx.sum(G, 1), nx.sum(G, 0)
    if use_kl:
        cost = (
            linear_cost
            + reg_m1 * nx.kl_div(m1, a, mass=True)
            + reg_m2 * nx.kl_div(m2, b, mass=True)
        )
        if reg > 0:
            cost = cost + reg * nx.kl_div(G, c, mass=True)
    else:
        cost = (
            linear_cost
            + reg_m1 * 0.5 * nx.sum((m1 - a) ** 2)
            + reg_m2 * 0.5 * nx.sum((m2 - b) ** 2)
        )
        if reg > 0:
            cost = cost + reg * 0.5 * nx.sum((G - c) ** 2)

    history["total_cost"] = cost

    return G, history
def mm_unbalanced2(
    a,
    b,
    M,
    reg_m,
    c=None,
    reg=0,
    div="kl",
    G0=None,
    returnCost="linear",
    numItermax=1000,
    stopThr=1e-15,
    verbose=False,
    log=False,
):
    r"""
    Solve the unbalanced optimal transport problem and return the OT cost.

    The function solves the following optimization problem:

    .. math::
        \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or
      half-squared :math:`\ell_2` divergence

    The algorithm used for solving the problem is a majorization-minimization
    (MM) algorithm as proposed in :ref:`[41] <references-regpath>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
        If `a` is an empty list or array ([]),
        then `a` is set to uniform distribution.
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
        If `b` is an empty list or array ([]),
        then `b` is set to uniform distribution.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term: nonnegative but cannot be infinity.
        If :math:`\mathrm{reg_{m}}` is a scalar or an indexable object of
        length 1, then the same :math:`\mathrm{reg_{m}}` is applied to both
        marginal relaxations.
        If :math:`\mathrm{reg_{m}}` is an array, it must have the same backend
        as input arrays `(a, b, M)`.
    reg : float, optional (default = 0)
        Regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (half-squared)
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix
    returnCost: string, optional (default = "linear")
        If `returnCost` = "linear", then return the linear part of the
        unbalanced OT loss.
        If `returnCost` = "total", then return the total unbalanced OT loss.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    ot_cost : array-like
        the OT cost between :math:`\mathbf{a}` and :math:`\mathbf{b}`
    log : dict
        log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `returnCost` is neither "linear" nor "total". The check is done
        before the solver runs.

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2)
    0.8
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2)
    1.79

    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    See Also
    --------
    ot.lp.emd2 : Unregularized OT loss
    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
    """
    # Reject an unknown cost selector up front instead of after a full solve.
    if returnCost not in ("linear", "total"):
        raise ValueError("Unknown returnCost = {}".format(returnCost))
    cost_key = "cost" if returnCost == "linear" else "total_cost"

    # The plan itself is not needed; both costs are read from the solver log,
    # so the solver is always run with log=True.
    _, log_mm = mm_unbalanced(
        a,
        b,
        M,
        reg_m,
        c=c,
        reg=reg,
        div=div,
        G0=G0,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=True,
    )

    cost = log_mm[cost_key]

    if log:
        return cost, log_mm
    return cost