Source code for ot.optim

# -*- coding: utf-8 -*-
"""
Generic solvers for regularized OT or its semi-relaxed version.
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Titouan Vayer <titouan.vayer@irisa.fr>
#         Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License

import numpy as np
import warnings
from .lp import emd
from .bregman import sinkhorn
from .utils import list_to_array
from .backend import get_backend

with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    try:
        from scipy.optimize import scalar_search_armijo
    except ImportError:
        from scipy.optimize.linesearch import scalar_search_armijo

# The corresponding scipy function does not work for matrices


def line_search_armijo(
    f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
    alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
):
    r"""
    Armijo linesearch function that works with matrices.

    Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
    Armijo conditions.

    .. note:: If the loss function f returns a float (resp. a 1d array) then
        the returned alpha and fa are float (resp. 1d arrays).

    Parameters
    ----------
    f : callable
        loss function
    xk : array-like
        initial position
    pk : array-like
        descent direction
    gfk : array-like
        gradient of `f` at :math:`x_k`
    old_fval : float or 1d array
        loss value at :math:`x_k`
    args : tuple, optional
        arguments given to `f`
    c1 : float, optional
        :math:`c_1` constant in the Armijo rule (>0)
    alpha0 : float, optional
        initial step (>0)
    alpha_min : float, default=0.
        minimum value for alpha
    alpha_max : float, optional
        maximum value for alpha
    nx : backend, optional
        If left to its default value None, a backend test will be conducted.

    Returns
    -------
    alpha : float or 1d array
        step that satisfies the Armijo conditions
    fc : int
        number of function calls
    fa : float or 1d array
        loss value at step alpha
    """
    if nx is None:
        xk, pk, gfk = list_to_array(xk, pk, gfk)
        xk0, pk0 = xk, pk
        nx = get_backend(xk0, pk0)
    else:
        xk0, pk0 = xk, pk

    if len(xk.shape) == 0:
        xk = nx.reshape(xk, (-1,))

    xk = nx.to_numpy(xk)
    pk = nx.to_numpy(pk)
    gfk = nx.to_numpy(gfk)

    fc = [0]

    def phi(alpha1):
        # it's necessary to check boundary condition here for the coefficient
        # as the callback could be evaluated for negative value of alpha by
        # `scalar_search_armijo` function here:
        #
        # https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
        #
        # see more details https://github.com/PythonOT/POT/issues/502
        alpha1 = np.clip(alpha1, alpha_min, alpha_max)
        # The callable function operates on nx backend
        fc[0] += 1
        alpha10 = nx.from_numpy(alpha1)
        fval = f(xk0 + alpha10 * pk0, *args)
        if isinstance(fval, float):
            # prevent bug from nx.to_numpy that can look for .cpu or .gpu
            return fval
        else:
            return nx.to_numpy(fval)

    if old_fval is None:
        phi0 = phi(0.)
    elif isinstance(old_fval, float):
        # prevent bug from nx.to_numpy that can look for .cpu or .gpu
        phi0 = old_fval
    else:
        phi0 = nx.to_numpy(old_fval)

    derphi0 = np.sum(pk * gfk)  # Quickfix for matrices
    alpha, phi1 = scalar_search_armijo(
        phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)

    if alpha is None:
        return 0., fc[0], nx.from_numpy(phi0, type_as=xk0)
    else:
        alpha = np.clip(alpha, alpha_min, alpha_max)
        return nx.from_numpy(alpha, type_as=xk0), fc[0], nx.from_numpy(phi1, type_as=xk0)

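A minimal usage sketch for the line search above (not part of the module): it minimizes a toy quadratic loss along a descent direction. The arrays and the loss function are made up for the illustration; only the call signature of line_search_armijo comes from the code above.

    import numpy as np
    from ot.optim import line_search_armijo

    def toy_loss(G):
        # simple quadratic loss, chosen only for the example
        return 0.5 * np.sum(G ** 2)

    G = np.ones((3, 3))       # current point
    grad = G                  # gradient of toy_loss at G
    direction = -grad         # descent direction
    alpha, n_calls, fval = line_search_armijo(toy_loss, G, direction, grad, toy_loss(G))
    print(alpha, n_calls, fval)
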
def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
                                 numItermax=200, stopThr=1e-9,
                                 stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem or its semi-relaxed version with
    conditional gradient or generalized conditional gradient depending on the
    provided linear program solver.

    The function solves the following optimization problem if set as a conditional gradient:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b} \quad \text{(optional constraint)}

             \gamma &\geq 0

    where:

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`.

    The function solves the following optimization problem if set as a generalized conditional gradient:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1} \cdot f(\gamma) + \mathrm{reg_2} \cdot \Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where:

    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`

    The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`.

    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    f : function
        Regularization function taking a transportation matrix as argument
    df : function
        Gradient of the regularization function taking a transportation matrix as argument
    reg1 : float
        Regularization term >0
    reg2 : float
        Entropic Regularization term >0. Ignored if set to None.
    lp_solver : function
        linear program solver for direction finding of the (generalized) conditional gradient.
        If set to emd, solves the general regularized OT problem with cg.
        If set to lp_semi_relaxed_OT, solves the general regularized semi-relaxed OT problem with cg.
        If set to sinkhorn, solves the general regularized OT problem with generalized cg.
    line_search : function
        Function to find the optimal step. Currently used instances are
        line_search_armijo (generic solver), solve_gromov_linesearch for (F)GW problems,
        solve_semirelaxed_gromov_linesearch for sr(F)GW problems and gcg_linesearch for the generalized cg.
    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log is True in parameters

    .. _references-cg:
    .. _references-gcg:
    References
    ----------
    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
        Regularized discrete optimal transport. SIAM Journal on Imaging
        Sciences, 7(3), 1853-1882.

    .. [5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport
        for Domain Adaptation," IEEE Transactions on Pattern Analysis and
        Machine Intelligence, vol. PP, no. 99, pp. 1-1.

    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
        conditional gradient: analysis of convergence and applications.
        arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport
    """
    a, b, M, G0 = list_to_array(a, b, M, G0)
    if isinstance(M, (int, float)):
        nx = get_backend(a, b)
    else:
        nx = get_backend(a, b, M)

    loop = 1

    if log:
        log = {'loss': []}

    if G0 is None:
        G = nx.outer(a, b)
    else:
        # to not change G0 in place.
        G = nx.copy(G0)

    if reg2 is None:
        def cost(G):
            return nx.sum(M * G) + reg1 * f(G)
    else:
        def cost(G):
            return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))

    cost_G = cost(G)
    if log:
        log['loss'].append(cost_G)

    it = 0

    if verbose:
        print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
            'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
        print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))

    while loop:

        it += 1
        old_cost_G = cost_G

        # problem linearization
        Mi = M + reg1 * df(G)

        if reg2 is not None:
            Mi = Mi + reg2 * (1 + nx.log(G))
        # set M positive
        Mi = Mi + nx.min(Mi)

        # solve linear program
        Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)

        # line search
        deltaG = Gc - G

        alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)

        G = G + alpha * deltaG

        # test convergence
        if it >= numItermax:
            loop = 0

        abs_delta_cost_G = abs(cost_G - old_cost_G)
        relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. else np.nan
        if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
            loop = 0

        if log:
            log['loss'].append(cost_G)

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
                    'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))

    if log:
        log.update(innerlog_)
        return G, log
    else:
        return G

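The solver above is normally reached through cg, semirelaxed_cg or gcg below, but it can also be called directly. The sketch that follows uses made-up toy data and an illustrative squared Frobenius regularizer, and wires the solver to the exact emd direction finder and the Armijo line search, which is essentially what cg does.

    import numpy as np
    import ot
    from ot.optim import generic_conditional_gradient, line_search_armijo

    rng = np.random.RandomState(0)
    a, b = ot.unif(4), ot.unif(5)        # uniform marginals
    M = rng.rand(4, 5)                   # toy cost matrix

    def f(G):                            # illustrative quadratic regularizer
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    def lp_solver(a, b, Mi, **kwargs):   # exact OT direction, with inner log
        return ot.emd(a, b, Mi, log=True)

    G = generic_conditional_gradient(a, b, M, f, df, reg1=1e-1, reg2=None,
                                     lp_solver=lp_solver, line_search=line_search_armijo)
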
def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
       numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
       verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem with conditional gradient.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where:

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`.

    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg : float
        Regularization term >0
    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
    line_search : function
        Function to find the optimal step. Default is line_search_armijo.
    numItermax : int, optional
        Max number of iterations
    numItermaxEmd : int, optional
        Max number of iterations for emd
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log is True in parameters

    .. _references-cg:
    References
    ----------
    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014).
        Regularized discrete optimal transport. SIAM Journal on Imaging
        Sciences, 7(3), 1853-1882.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport
    """
    def lp_solver(a, b, M, **kwargs):
        return emd(a, b, M, numItermaxEmd, log=True)

    return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
                                        numItermax=numItermax, stopThr=stopThr,
                                        stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)

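A short usage sketch for cg on toy data, assuming the illustrative squared Frobenius regularizer f(G) = 0.5 * ||G||_F^2 (one possible choice, not prescribed by the solver):

    import numpy as np
    import ot

    a = ot.unif(3)
    b = ot.unif(3)
    M = np.array([[0., 1., 4.],
                  [1., 0., 1.],
                  [4., 1., 0.]])   # squared distances between points 0, 1, 2

    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = ot.optim.cg(a, b, M, reg=1e-1, f=f, df=df)
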
def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
                   numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized and semi-relaxed OT problem with conditional gradient.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma &\geq 0

    where:

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-cg>`.

    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        currently estimated samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg : float
        Regularization term >0
    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
    line_search : function
        Function to find the optimal step. Default is the Armijo line search.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log is True in parameters

    .. _references-semirelaxed-cg:
    References
    ----------
    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
        "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs."
        International Conference on Learning Representations (ICLR), 2021.
    """
    nx = get_backend(a, b)

    def lp_solver(a, b, Mi, **kwargs):
        # get minimum by rows as binary mask
        Gc = nx.ones(1, type_as=a) * (Mi == nx.reshape(nx.min(Mi, axis=1), (-1, 1)))
        Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
        # return by default an empty inner_log
        return Gc, {}

    return generic_conditional_gradient(a, b, M, f, df, reg, None, lp_solver, line_search, G0=G0,
                                        numItermax=numItermax, stopThr=stopThr,
                                        stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)

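A usage sketch for the semi-relaxed solver with the same illustrative quadratic regularizer and toy data; here only the source marginal is enforced, and b enters only through the default initialization G0 = outer(a, b).

    import numpy as np
    import ot
    from ot.optim import semirelaxed_cg

    rng = np.random.RandomState(0)
    a, b = ot.unif(4), ot.unif(5)
    M = rng.rand(4, 5)

    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = semirelaxed_cg(a, b, M, reg=1e-1, f=f, df=df)
    # only the source marginal is enforced: G.sum(axis=1) stays equal to a
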
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
        numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem with the generalized conditional gradient.

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1} \cdot \Omega(\gamma) + \mathrm{reg_2} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where:

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`.

    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg1 : float
        Entropic Regularization term >0
    reg2 : float
        Second Regularization term >0
    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations of Sinkhorn
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if log is True in parameters

    .. _references-gcg:
    References
    ----------
    .. [5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport
        for Domain Adaptation," IEEE Transactions on Pattern Analysis and
        Machine Intelligence, vol. PP, no. 99, pp. 1-1.

    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized
        conditional gradient: analysis of convergence and applications.
        arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.optim.cg : conditional gradient
    """
    def lp_solver(a, b, Mi, **kwargs):
        return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)

    def line_search(cost, G, deltaG, Mi, cost_G, **kwargs):
        return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)

    return generic_conditional_gradient(a, b, M, f, df, reg2, reg1, lp_solver, line_search, G0=G0,
                                        numItermax=numItermax, stopThr=stopThr,
                                        stopThr2=stopThr2, verbose=verbose, log=log, **kwargs)

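A usage sketch for gcg on toy data, combining entropic regularization (weight reg1) with the same illustrative quadratic term (weight reg2); the data and regularizer are made up for the example.

    import numpy as np
    import ot

    rng = np.random.RandomState(0)
    a, b = ot.unif(4), ot.unif(5)
    M = rng.rand(4, 5)

    def f(G):
        return 0.5 * np.sum(G ** 2)

    def df(G):
        return G

    G = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)
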
def solve_1d_linesearch_quad(a, b):
    r"""
    For any convex or non-convex 1d quadratic function `f`, solve the following problem:

    .. math::

        \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c

    Parameters
    ----------
    a, b : float or tensors (1,)
        The coefficients of the quadratic function

    Returns
    -------
    x : float
        The optimal value which leads to the minimal cost
    """
    if a > 0:  # convex
        minimum = min(1., max(0., -b / (2.0 * a)))
        return minimum
    else:  # non convex
        if a + b < 0:
            return 1.
        else:
            return 0.

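Two worked cases for the helper above (coefficient values chosen purely for illustration):

    from ot.optim import solve_1d_linesearch_quad

    # convex case: f(x) = 2x^2 - x has its unconstrained minimum at x = 0.25, inside [0, 1]
    print(solve_1d_linesearch_quad(2.0, -1.0))   # 0.25

    # concave case (a < 0): the minimum over [0, 1] lies at an endpoint;
    # here a + b = -0.8 < 0, i.e. f(1) < f(0), so the helper returns 1.0
    print(solve_1d_linesearch_quad(-1.0, 0.2))
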