# -*- coding: utf-8 -*-
"""
Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
import warnings
from .lp import emd
from .bregman import sinkhorn
from .backend import get_backend
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
from scipy.optimize._linesearch import scalar_search_armijo
except ModuleNotFoundError:
# scipy<1.8.0
from scipy.optimize.linesearch import scalar_search_armijo
# The corresponding scipy function does not work for matrices
[docs]
def line_search_armijo(
    f, xk, pk, gfk, old_fval, args=(), c1=1e-4,
    alpha0=0.99, alpha_min=0., alpha_max=None, nx=None, **kwargs
):
    r"""
    Armijo linesearch function that works with matrices

    Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that
    satisfies the armijo conditions.

    .. note:: If the loss function f returns a float (resp. a 1d array) then
        the returned alpha and fa are float (resp. 1d arrays).

    Parameters
    ----------
    f : callable
        loss function
    xk : array-like
        initial position
    pk : array-like
        descent direction
    gfk : array-like
        gradient of `f` at :math:`x_k`
    old_fval : float or 1d array
        loss value at :math:`x_k`. If None, it is computed by evaluating `f`
        at :math:`x_k` (this counts as one function call).
    args : tuple, optional
        arguments given to `f`
    c1 : float, optional
        :math:`c_1` const in armijo rule (>0)
    alpha0 : float, optional
        initial step (>0)
    alpha_min : float, default=0.
        minimum value for alpha
    alpha_max : float, optional
        maximum value for alpha
    nx : backend, optional
        If let to its default value None, a backend test will be conducted.
    **kwargs : dict
        Accepted so that the function can be called with the keyword
        arguments meant for other line-search functions; they are ignored.

    Returns
    -------
    alpha : float or 1d array
        step that satisfy armijo conditions, clipped to
        [`alpha_min`, `alpha_max`]. The python float ``0.`` is returned when
        the underlying scalar search fails to find a step.
    fc : int
        nb of function call
    fa : float or 1d array
        loss value at step alpha (loss value at :math:`x_k` when the search
        fails)
    """
    if nx is None:
        nx = get_backend(xk, pk)

    # Only the direction and the gradient are needed on the numpy side (for
    # the directional derivative). The iterate itself stays on the backend:
    # it is only ever used inside `phi`, which evaluates `f` on backend data.
    pk_np = nx.to_numpy(pk)
    gfk_np = nx.to_numpy(gfk)

    # single-element list so that the closure below can update the counter
    fc = [0]

    def phi(alpha1):
        # it's necessary to check boundary condition here for the coefficient
        # as the callback could be evaluated for negative value of alpha by
        # `scalar_search_armijo` function here:
        #
        # https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
        #
        # see more details https://github.com/PythonOT/POT/issues/502
        alpha1 = np.clip(alpha1, alpha_min, alpha_max)
        # The callable function operates on nx backend
        fc[0] += 1
        alpha10 = nx.from_numpy(alpha1)
        fval = f(xk + alpha10 * pk, *args)
        if isinstance(fval, float):
            # prevent bug from nx.to_numpy that can look for .cpu or .gpu
            return fval
        return nx.to_numpy(fval)

    if old_fval is None:
        phi0 = phi(0.)
    elif isinstance(old_fval, float):
        # prevent bug from nx.to_numpy that can look for .cpu or .gpu
        phi0 = old_fval
    else:
        phi0 = nx.to_numpy(old_fval)

    # Frobenius inner product <pk, gfk>: the slope of phi at alpha = 0
    derphi0 = np.sum(pk_np * gfk_np)  # Quickfix for matrices

    alpha, phi1 = scalar_search_armijo(
        phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min)

    if alpha is None:
        # the scalar search failed: stay at the current point
        return 0., fc[0], nx.from_numpy(phi0, type_as=xk)

    alpha = np.clip(alpha, alpha_min, alpha_max)
    return nx.from_numpy(alpha, type_as=xk), fc[0], nx.from_numpy(phi1, type_as=xk)
[docs]
def generic_conditional_gradient(a, b, M, f, df, reg1, reg2, lp_solver, line_search, G0=None,
                                 numItermax=200, stopThr=1e-9,
                                 stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem or its semi-relaxed version with
    conditional gradient or generalized conditional gradient depending on the
    provided linear program solver.

    The function solves the following optimization problem if set as a conditional gradient:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b} (optional constraint)

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`

    The function solves the following optimization problem if set a generalized conditional gradient:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`

    The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`

    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    f : function
        Regularization function taking a transportation matrix as argument
    df: function
        Gradient of the regularization function taking a transportation matrix as argument
    reg1 : float
        Regularization term >0
    reg2 : float,
        Entropic Regularization term >0. Ignored if set to None.
    lp_solver: function,
        linear program solver for direction finding of the (generalized) conditional gradient.
        It is called as ``lp_solver(a, b, Mi, **kwargs)`` and must return a
        pair ``(Gc, innerlog)`` where `innerlog` is a dict.
        If set to emd will solve the general regularized OT problem using cg.
        If set to lp_semi_relaxed_OT will solve the general regularized semi-relaxed OT problem using cg.
        If set to sinkhorn will solve the general regularized OT problem using generalized cg.
    line_search: function,
        Function to find the optimal step. It is called as
        ``line_search(cost, G, deltaG, Mi, cost_G, **kwargs)`` and must return
        ``(alpha, fc, cost_G)``. Currently used instances are:
        line_search_armijo (generic solver). solve_gromov_linesearch for (F)GW problem.
        solve_semirelaxed_gromov_linesearch for sr(F)GW problem. gcg_linesearch for the Generalized cg.
    G0 : array-like, shape (ns,nt), optional
        initial guess (default is indep joint density)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch. The same keyword arguments are also
        forwarded to `lp_solver`.

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters. It holds the
        list of losses under the key ``'loss'`` and is updated with the inner
        log returned by the last call to `lp_solver`.


    .. _references-cg:
    .. _references_gcg:
    References
    ----------

    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.

    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1

    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport
    """
    # M may be given as a plain scalar, which carries no backend information
    if isinstance(M, (int, float)):
        nx = get_backend(a, b)
    else:
        nx = get_backend(a, b, M)

    loop = 1

    if log:
        log = {'loss': []}

    if G0 is None:
        G = nx.outer(a, b)
    else:
        # to not change G0 in place.
        G = nx.copy(G0)

    if reg2 is None:
        def cost(G):
            return nx.sum(M * G) + reg1 * f(G)
    else:
        def cost(G):
            return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))

    cost_G = cost(G)
    if log:
        log['loss'].append(cost_G)

    it = 0

    if verbose:
        print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
            'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
        print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, 0, 0))

    while loop:

        it += 1
        old_cost_G = cost_G

        # problem linearization
        Mi = M + reg1 * df(G)

        if reg2 is not None:
            # gradient of the entropic term reg2 * sum(G * log(G))
            Mi = Mi + reg2 * (1 + nx.log(G))

        # set M positive: subtracting the minimum makes every entry >= 0.
        # A constant shift of the cost leaves the direction-finding step
        # unchanged as long as the solver returns a plan of fixed total mass.
        Mi = Mi - nx.min(Mi)

        # solve linear program
        Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)

        # line search
        deltaG = Gc - G

        alpha, fc, cost_G = line_search(cost, G, deltaG, Mi, cost_G, **kwargs)

        G = G + alpha * deltaG

        # test convergence
        if it >= numItermax:
            loop = 0

        abs_delta_cost_G = abs(cost_G - old_cost_G)
        # nan compares False against the threshold, so a zero cost never
        # triggers the relative criterion on its own
        relative_delta_cost_G = abs_delta_cost_G / abs(cost_G) if cost_G != 0. else np.nan
        if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
            loop = 0

        if log:
            log['loss'].append(cost_G)

        if verbose:
            if it % 20 == 0:
                print('{:5s}|{:12s}|{:8s}|{:8s}'.format(
                    'It.', 'Loss', 'Relative loss', 'Absolute loss') + '\n' + '-' * 48)
            print('{:5d}|{:8e}|{:8e}|{:8e}'.format(it, cost_G, relative_delta_cost_G, abs_delta_cost_G))

    if log:
        # keep the inner log of the last direction-finding step
        log.update(innerlog_)
        return G, log
    else:
        return G
[docs]
def cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
       numItermax=200, numItermaxEmd=100000, stopThr=1e-9, stopThr2=1e-9,
       verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem with conditional gradient

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`


    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg : float
        Regularization term >0
    f : function
        Regularization function taking a transportation matrix as argument
    df : function
        Gradient of the regularization function taking a transportation matrix as argument
    G0 : array-like, shape (ns,nt), optional
        initial guess (default is indep joint density)
    line_search: function,
        Function to find the optimal step.
        Default is line_search_armijo.
    numItermax : int, optional
        Max number of iterations
    numItermaxEmd : int, optional
        Max number of iterations for emd
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-cg:
    References
    ----------

    .. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.

    See Also
    --------
    ot.lp.emd : Unregularized optimal transport
    ot.bregman.sinkhorn : Entropic regularized optimal transport

    """

    def emd_direction(a, b, M, **kwargs):
        # Direction finding with the exact OT solver on the linearized cost.
        # The line-search keyword arguments received here are not used.
        return emd(a, b, M, numItermaxEmd, log=True)

    # reg2 is None: no entropic term, i.e. plain conditional gradient
    outer_options = dict(
        G0=G0,
        numItermax=numItermax,
        stopThr=stopThr,
        stopThr2=stopThr2,
        verbose=verbose,
        log=log,
    )
    return generic_conditional_gradient(
        a, b, M, f, df, reg, None, emd_direction, line_search,
        **outer_options, **kwargs
    )
[docs]
def semirelaxed_cg(a, b, M, reg, f, df, G0=None, line_search=line_search_armijo,
                   numItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized and semi-relaxed OT problem with conditional gradient

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`


    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        currently estimated samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg : float
        Regularization term >0
    f : function
        Regularization function taking a transportation matrix as argument
    df : function
        Gradient of the regularization function taking a transportation matrix as argument
    G0 : array-like, shape (ns,nt), optional
        initial guess (default is indep joint density)
    line_search: function,
        Function to find the optimal step.
        Default is the armijo line-search.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Parameters for linesearch

    Returns
    -------
    gamma : (ns x nt) ndarray
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-cg:
    References
    ----------

    .. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
            "Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
            International Conference on Learning Representations (ICLR), 2021.

    """

    nx = get_backend(a, b)

    def lp_solver(a, b, Mi, **kwargs):
        # Each row of the direction puts its whole mass a_i on the column(s)
        # where that row of Mi is minimal, shared equally in case of ties.
        row_minima = nx.reshape(nx.min(Mi, axis=1), (-1, 1))
        is_row_min = Mi == row_minima
        Gc = nx.ones(1, type_as=a) * is_row_min
        mass_per_hit = a / nx.sum(Gc, axis=1)
        Gc *= nx.reshape(mass_per_hit, (-1, 1))
        # there is no inner solver here, hence an empty inner log
        return Gc, {}

    # reg2 is None: no entropic term, i.e. plain conditional gradient
    outer_options = dict(
        G0=G0,
        numItermax=numItermax,
        stopThr=stopThr,
        stopThr2=stopThr2,
        verbose=verbose,
        log=log,
    )
    return generic_conditional_gradient(
        a, b, M, f, df, reg, None, lp_solver, line_search,
        **outer_options, **kwargs
    )
[docs]
def gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10,
        numInnerItermax=200, stopThr=1e-9, stopThr2=1e-9, verbose=False, log=False, **kwargs):
    r"""
    Solve the general regularized OT problem with the generalized conditional gradient

    The function solves the following optimization problem:

    .. math::
        \gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)

        s.t. \ \gamma \mathbf{1} &= \mathbf{a}

             \gamma^T \mathbf{1} &= \mathbf{b}

             \gamma &\geq 0

    where :

    - :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
    - :math:`f` is the regularization term (and `df` is its gradient)
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)

    The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`


    Parameters
    ----------
    a : array-like, shape (ns,)
        samples weights in the source domain
    b : array-like, shape (nt,)
        samples weights in the target domain
    M : array-like, shape (ns, nt)
        loss matrix
    reg1 : float
        Entropic Regularization term >0
    reg2 : float
        Second Regularization term >0
    f : function
        Regularization function taking a transportation matrix as argument
    df : function
        Gradient of the regularization function taking a transportation matrix as argument
    G0 : array-like, shape (ns, nt), optional
        initial guess (default is indep joint density)
    numItermax : int, optional
        Max number of iterations
    numInnerItermax : int, optional
        Max number of iterations of Sinkhorn
    stopThr : float, optional
        Stop threshold on the relative variation (>0)
    stopThr2 : float, optional
        Stop threshold on the absolute variation (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs : dict
        Extra keyword arguments, forwarded to both the Sinkhorn solver and
        the armijo line-search.

    Returns
    -------
    gamma : ndarray, shape (ns, nt)
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary return only if log==True in parameters


    .. _references-gcg:
    References
    ----------

    .. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1

    .. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.

    See Also
    --------
    ot.optim.cg : conditional gradient

    """

    def sinkhorn_direction(a, b, Mi, **kwargs):
        # Direction finding with entropic OT on the linearized cost
        return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)

    # The generic solver names its regularization weights the other way
    # round: its `reg1` multiplies f and its `reg2` multiplies the entropy.
    # The generic solver calls the line-search with exactly the positional
    # arguments of line_search_armijo, so it is handed over as is.
    return generic_conditional_gradient(
        a, b, M, f, df, reg2, reg1, sinkhorn_direction, line_search_armijo,
        G0=G0, numItermax=numItermax, stopThr=stopThr, stopThr2=stopThr2,
        verbose=verbose, log=log, **kwargs
    )
[docs]
def solve_1d_linesearch_quad(a, b):
    r"""
    For any convex or non-convex 1d quadratic function `f`, solve the following problem:

    .. math::

        \mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c

    Parameters
    ----------
    a,b : float or tensors (1,)
        The coefficients of the quadratic function

    Returns
    -------
    x : float
        The optimal value which leads to the minimal cost
    """
    if not a > 0:
        # non convex (or linear): the minimum over [0, 1] sits at an endpoint,
        # and f(1) - f(0) = a + b tells which one is lower
        return 1. if a + b < 0 else 0.

    # convex: project the unconstrained minimizer -b / (2a) onto [0, 1]
    stationary = -b / (2.0 * a)
    lower_clipped = stationary if stationary > 0. else 0.
    return lower_clipped if lower_clipped < 1. else 1.