# -*- coding: utf-8 -*-
"""
Generic solvers for regularized OT or its semi-relaxed version.
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Titouan Vayer <titouan.vayer@irisa.fr>
# Cédric Vincent-Cuaz <cedvincentcuaz@gmail.com>
# License: MIT License
import numpy as np
import warnings
from .lp import emd
from .bregman import sinkhorn
from .backend import get_backend
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
from scipy.optimize._linesearch import scalar_search_armijo
except ModuleNotFoundError:
# scipy<1.8.0
from scipy.optimize.linesearch import scalar_search_armijo
# The corresponding scipy function does not work for matrices
def line_search_armijo(
f,
xk,
pk,
gfk,
old_fval,
args=(),
c1=1e-4,
alpha0=0.99,
alpha_min=0.0,
alpha_max=None,
nx=None,
**kwargs,
):
r"""
Armijo linesearch function that works with matrices
Find an approximate minimum of :math:`f(x_k + \alpha \cdot p_k)` that satisfies the
Armijo condition.
.. note:: If the loss function f returns a float (resp. a 1d array) then
the returned alpha and fa are floats (resp. 1d arrays).
Parameters
----------
f : callable
loss function
xk : array-like
initial position
pk : array-like
descent direction
gfk : array-like
gradient of `f` at :math:`x_k`
old_fval : float or 1d array
loss value at :math:`x_k`
args : tuple, optional
arguments given to `f`
c1 : float, optional
:math:`c_1` constant in the Armijo rule (>0)
alpha0 : float, optional
initial step (>0)
alpha_min : float, default=0.
minimum value for alpha
alpha_max : float, optional
maximum value for alpha
nx : backend, optional
If left to its default value None, the backend will be deduced from `xk` and `pk`.
Returns
-------
alpha : float or 1d array
step that satisfies the Armijo condition
fc : int
number of function calls
fa : float or 1d array
loss value at step alpha
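Examples
--------
A minimal sketch on a simple quadratic loss (the arrays and the loss below are
purely illustrative):

>>> import numpy as np
>>> import ot
>>> xk = np.zeros((2, 2))
>>> pk = np.ones((2, 2))  # descent direction
>>> def f(x):
...     return np.sum((x - 1.0) ** 2)
>>> gfk = 2 * (xk - 1.0)  # gradient of f at xk
>>> alpha, fc, fa = ot.optim.line_search_armijo(f, xk, pk, gfk, f(xk))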
"""
if nx is None:
xk0, pk0 = xk, pk
nx = get_backend(xk0, pk0)
else:
xk0, pk0 = xk, pk
if len(xk.shape) == 0:
xk = nx.reshape(xk, (-1,))
xk = nx.to_numpy(xk)
pk = nx.to_numpy(pk)
gfk = nx.to_numpy(gfk)
fc = [0]
def phi(alpha1):
# it's necessary to check the boundary condition here for the coefficient,
# as the callback could be evaluated for negative values of alpha by
# the `scalar_search_armijo` function here:
#
# https://github.com/scipy/scipy/blob/11509c4a98edded6c59423ac44ca1b7f28fba1fd/scipy/optimize/linesearch.py#L686
#
# see more details https://github.com/PythonOT/POT/issues/502
alpha1 = np.clip(alpha1, alpha_min, alpha_max)
# The callable function operates on nx backend
fc[0] += 1
alpha10 = nx.from_numpy(alpha1)
fval = f(xk0 + alpha10 * pk0, *args)
if isinstance(fval, float):
# prevent bug from nx.to_numpy that can look for .cpu or .gpu
return fval
else:
return nx.to_numpy(fval)
if old_fval is None:
phi0 = phi(0.0)
elif isinstance(old_fval, float):
# prevent bug from nx.to_numpy that can look for .cpu or .gpu
phi0 = old_fval
else:
phi0 = nx.to_numpy(old_fval)
derphi0 = np.sum(pk * gfk) # Quickfix for matrices
alpha, phi1 = scalar_search_armijo(
phi, phi0, derphi0, c1=c1, alpha0=alpha0, amin=alpha_min
)
if alpha is None:
return 0.0, fc[0], nx.from_numpy(phi0, type_as=xk0)
else:
alpha = np.clip(alpha, alpha_min, alpha_max)
return (
nx.from_numpy(alpha, type_as=xk0),
fc[0],
nx.from_numpy(phi1, type_as=xk0),
)
def generic_conditional_gradient(
a,
b,
M,
f,
df,
reg1,
reg2,
lp_solver,
line_search,
G0=None,
numItermax=200,
stopThr=1e-9,
stopThr2=1e-9,
verbose=False,
log=False,
nx=None,
**kwargs,
):
r"""
Solve the general regularized OT problem or its semi-relaxed version with
conditional gradient or generalized conditional gradient depending on the
provided linear program solver.
If set up as a conditional gradient solver (`reg2` is None), the function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg_1} \cdot f(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b} \quad \text{(optional constraint)}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`f` is the regularization term (and `df` is its gradient)
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
If set up as a generalized conditional gradient solver (`reg2` is not None), the function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg_1}\cdot f(\gamma) + \mathrm{reg_2}\cdot\Omega(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
sample weights in the source domain
b : array-like, shape (nt,)
sample weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
f : function
Regularization function taking a transportation matrix as argument
df : function
Gradient of the regularization function taking a transportation matrix as argument
reg1 : float
Regularization term >0
reg2 : float or None
Entropic regularization term >0. Ignored if set to None.
lp_solver : function
linear program solver for direction finding of the (generalized) conditional gradient.
This function must take the form `lp_solver(a, b, Mi, **kwargs)` with:
`a` and `b` the sample weights in both domains, `Mi` the gradient of the
regularized objective, and optional arguments passed via `kwargs`.
It must output an admissible transport plan together with a dictionary of
inner-solver information (which may be empty), i.e. return `(Gc, innerlog)`.
For instance, for the general regularized OT problem with conditional gradient :ref:`[1] <references-cg>`:
def lp_solver(a, b, Mi, **kwargs):
    return ot.emd(a, b, Mi, log=True)
or with the generalized conditional gradient instead :ref:`[5, 7] <references-gcg>`:
def lp_solver(a, b, Mi, **kwargs):
    return ot.sinkhorn(a, b, Mi, reg2, log=True)
line_search : function
Function to find the optimal step. This function must take the form
`line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)` with: `cost`
the cost function, `G` the transport plan, `deltaG` the conditional
gradient direction given by `lp_solver`, `Mi` the gradient of the regularized
objective, `cost_G` the cost at `G` and `df_G` the gradient of the regularizer
at `G`. Two types of outputs are supported:
Instances such as `ot.optim.line_search_armijo` (generic solver),
`ot.gromov.solve_gromov_linesearch` (FGW problems),
`solve_semirelaxed_gromov_linesearch` (srFGW problems) and
`gcg_linesearch` (generalized cg) output the line-search step alpha,
the number of iterations used in the solver (if applicable) and the loss
value at step alpha. These can be called e.g. as:
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
    return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
Instances such as `ot.gromov.solve_partial_gromov_linesearch` for partial
(F)GW problems additionally return, as a final output, the gradient at the
next step, expressed as a convex combination of previously computed gradients,
taking advantage of the quadratic form of the regularizer.
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
nx : backend, optional
If left to its default value None, the backend will be deduced from the other inputs.
**kwargs : dict
Parameters for linesearch
Returns
-------
gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
.. _references-cg:
.. _references-gcg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
--------
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
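Examples
--------
A minimal sketch of a conditional gradient call with a quadratic regularizer
(the histograms, cost matrix, regularizer and inner solvers below are
illustrative placeholders):

>>> import numpy as np
>>> import ot
>>> a = ot.unif(3)
>>> b = ot.unif(3)
>>> M = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))
>>> def f(G):
...     return 0.5 * np.sum(G ** 2)
>>> def df(G):
...     return G
>>> def lp_solver(a, b, Mi, **kwargs):
...     return ot.emd(a, b, Mi, log=True)
>>> def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
...     return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
>>> G = ot.optim.generic_conditional_gradient(
...     a, b, M, f, df, 1e-1, None, lp_solver, line_search
... )

Note that `lp_solver` here returns a `(plan, log)` pair because `ot.emd` is
called with `log=True`, matching the expected output described above.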
"""
if nx is None:
if isinstance(M, int) or isinstance(M, float):
nx = get_backend(a, b)
else:
nx = get_backend(a, b, M)
loop = 1
if log:
log = {"loss": []}
if G0 is None:
G = nx.outer(a, b)
else:
# to not change G0 in place.
G = nx.copy(G0)
if reg2 is None:
def cost(G):
return nx.sum(M * G) + reg1 * f(G)
else:
def cost(G):
return nx.sum(M * G) + reg1 * f(G) + reg2 * nx.sum(G * nx.log(G))
cost_G = cost(G)
if log:
log["loss"].append(cost_G)
df_G = None
it = 0
if verbose:
print(
"{:5s}|{:12s}|{:8s}|{:8s}".format(
"It.", "Loss", "Relative loss", "Absolute loss"
)
+ "\n"
+ "-" * 48
)
print("{:5d}|{:8e}|{:8e}|{:8e}".format(it, cost_G, 0, 0))
while loop:
it += 1
old_cost_G = cost_G
# problem linearization
if df_G is None:
df_G = df(G)
Mi = M + reg1 * df_G
if reg2 is not None:
Mi = Mi + reg2 * (1 + nx.log(G))
# solve linear program
Gc, innerlog_ = lp_solver(a, b, Mi, **kwargs)
# line search
deltaG = Gc - G
res_line_search = line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs)
if len(res_line_search) == 3:
# the line-search does not return an updated gradient
alpha, fc, cost_G = res_line_search
df_G = None
else:
# the line-search also returns an updated gradient directly
# e.g. when using quadratic losses such as the Gromov-Wasserstein loss
alpha, fc, cost_G, df_G = res_line_search
G = G + alpha * deltaG
# test convergence
if it >= numItermax:
loop = 0
abs_delta_cost_G = abs(cost_G - old_cost_G)
relative_delta_cost_G = (
abs_delta_cost_G / abs(cost_G) if cost_G != 0.0 else np.nan
)
if relative_delta_cost_G < stopThr or abs_delta_cost_G < stopThr2:
loop = 0
if log:
log["loss"].append(cost_G)
if verbose:
if it % 20 == 0:
print(
"{:5s}|{:12s}|{:8s}|{:8s}".format(
"It.", "Loss", "Relative loss", "Absolute loss"
)
+ "\n"
+ "-" * 48
)
print(
"{:5d}|{:8e}|{:8e}|{:8e}".format(
it, cost_G, relative_delta_cost_G, abs_delta_cost_G
)
)
if log:
log.update(innerlog_)
return G, log
else:
return G
def cg(
a,
b,
M,
reg,
f,
df,
G0=None,
line_search=None,
numItermax=200,
numItermaxEmd=100000,
stopThr=1e-9,
stopThr2=1e-9,
verbose=False,
log=False,
nx=None,
**kwargs,
):
r"""
Solve the general regularized OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg} \cdot f(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`f` is the regularization term (and `df` is its gradient)
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[1] <references-cg>`
Parameters
----------
a : array-like, shape (ns,)
sample weights in the source domain
b : array-like, shape (nt,)
sample weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
f : function
Regularization function taking a transportation matrix as argument
df : function
Gradient of the regularization function taking a transportation matrix as argument
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
line_search : function, optional
Function to find the optimal step.
Default is None, in which case a wrapper around `line_search_armijo` is used.
numItermax : int, optional
Max number of iterations
numItermaxEmd : int, optional
Max number of iterations for emd
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
nx : backend, optional
If left to its default value None, the backend will be deduced from the other inputs.
**kwargs : dict
Parameters for linesearch
Returns
-------
gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
.. _references-cg:
References
----------
.. [1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.
See Also
--------
ot.lp.emd : Unregularized optimal transport
ot.bregman.sinkhorn : Entropic regularized optimal transport
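Examples
--------
A minimal sketch with a quadratic regularizer
:math:`f(\gamma) = \frac{1}{2} \|\gamma\|_F^2` (illustrative data):

>>> import numpy as np
>>> import ot
>>> a = ot.unif(3)
>>> b = ot.unif(3)
>>> M = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))
>>> def f(G):
...     return 0.5 * np.sum(G ** 2)
>>> def df(G):
...     return G
>>> G = ot.optim.cg(a, b, M, 1e-1, f, df)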
"""
if nx is None:
if isinstance(M, int) or isinstance(M, float):
nx = get_backend(a, b)
else:
nx = get_backend(a, b, M)
if line_search is None:
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs)
def lp_solver(a, b, M, **kwargs):
return emd(a, b, M, numItermaxEmd, log=True)
return generic_conditional_gradient(
a,
b,
M,
f,
df,
reg,
None,
lp_solver,
line_search,
G0=G0,
numItermax=numItermax,
stopThr=stopThr,
stopThr2=stopThr2,
verbose=verbose,
log=log,
nx=nx,
**kwargs,
)
def semirelaxed_cg(
a,
b,
M,
reg,
f,
df,
G0=None,
line_search=None,
numItermax=200,
stopThr=1e-9,
stopThr2=1e-9,
verbose=False,
log=False,
nx=None,
**kwargs,
):
r"""
Solve the general regularized and semi-relaxed OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg} \cdot f(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`f` is the regularization term (and `df` is its gradient)
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[48] <references-semirelaxed-cg>`
Parameters
----------
a : array-like, shape (ns,)
sample weights in the source domain
b : array-like, shape (nt,)
currently estimated sample weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
f : function
Regularization function taking a transportation matrix as argument
df : function
Gradient of the regularization function taking a transportation matrix as argument
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
line_search : function, optional
Function to find the optimal step.
Default is None, in which case a wrapper around `line_search_armijo` is used.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
nx : backend, optional
If left to its default value None, the backend will be deduced from the other inputs.
**kwargs : dict
Parameters for linesearch
Returns
-------
gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
.. _references-semirelaxed-cg:
References
----------
.. [48] Cédric Vincent-Cuaz, Rémi Flamary, Marco Corneli, Titouan Vayer, Nicolas Courty.
"Semi-relaxed Gromov-Wasserstein divergence and applications on graphs"
International Conference on Learning Representations (ICLR), 2021.
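Examples
--------
A minimal sketch with a quadratic regularizer (illustrative data); only the
source marginal :math:`\mathbf{a}` is enforced by the solver:

>>> import numpy as np
>>> import ot
>>> a = ot.unif(3)
>>> b = ot.unif(3)  # current estimate of the target marginal
>>> M = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))
>>> def f(G):
...     return 0.5 * np.sum(G ** 2)
>>> def df(G):
...     return G
>>> G = ot.optim.semirelaxed_cg(a, b, M, 1e-1, f, df)
>>> bool(np.allclose(G.sum(axis=1), a))
True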
"""
if nx is None:
if isinstance(M, int) or isinstance(M, float):
nx = get_backend(a, b)
else:
nx = get_backend(a, b, M)
if line_search is None:
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, nx=nx, **kwargs)
def lp_solver(a, b, Mi, **kwargs):
# get minimum by rows as binary mask
min_ = nx.reshape(nx.min(Mi, axis=1), (-1, 1))
# instead of exact elements equal to min_ we consider a small margin (1e-15)
# for float precision issues. Then the mass is split uniformly
# between these elements.
Gc = nx.ones(1, type_as=a) * (Mi <= min_ + 1e-15)
Gc *= nx.reshape((a / nx.sum(Gc, axis=1)), (-1, 1))
# return by default an empty inner_log
return Gc, {}
return generic_conditional_gradient(
a,
b,
M,
f,
df,
reg,
None,
lp_solver,
line_search,
G0=G0,
numItermax=numItermax,
stopThr=stopThr,
stopThr2=stopThr2,
verbose=verbose,
log=log,
nx=nx,
**kwargs,
)
def partial_cg(
a,
b,
a_extended,
b_extended,
M,
reg,
f,
df,
G0=None,
line_search=line_search_armijo,
numItermax=200,
stopThr=1e-9,
stopThr2=1e-9,
warn=True,
verbose=False,
log=False,
**kwargs,
):
r"""
Solve the general regularized partial OT problem with conditional gradient
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg} \cdot f(\gamma)
s.t. \ \gamma \mathbf{1} &\leq \mathbf{a}
\gamma^T \mathbf{1} &\leq \mathbf{b}
\mathbf{1}^T \gamma \mathbf{1} = m &\leq \min\{\|\mathbf{a}\|_1, \|\mathbf{b}\|_1\}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`f` is the regularization term (and `df` is its gradient)
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights
- `m` is the amount of mass to be transported
The algorithm used for solving the problem is conditional gradient as discussed in :ref:`[29] <references-partial-cg>`
Parameters
----------
a : array-like, shape (ns,)
sample weights in the source domain
b : array-like, shape (nt,)
currently estimated sample weights in the target domain
a_extended : array-like, shape (ns + nb_dummies,)
sample weights in the source domain with added dummy nodes
b_extended : array-like, shape (nt + nb_dummies,)
currently estimated sample weights in the target domain with added dummy nodes
M : array-like, shape (ns, nt)
loss matrix
reg : float
Regularization term >0
f : function
Regularization function taking a transportation matrix as argument
df : function
Gradient of the regularization function taking a transportation matrix as argument
G0 : array-like, shape (ns,nt), optional
initial guess (default is indep joint density)
line_search: function,
Function to find the optimal step.
Default is the armijo line-search.
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
Stop threshold on the absolute variation (>0)
warn : bool, optional
Whether to raise an error when the inner EMD solver did not converge.
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
**kwargs : dict
Parameters for linesearch
Returns
-------
gamma : (ns x nt) ndarray
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
.. _references-partial-cg:
References
----------
.. [29] Chapel, L., Alaya, M., Gasso, G. (2020). "Partial Optimal
Transport with Applications on Positive-Unlabeled Learning".
NeurIPS.
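Examples
--------
A minimal sketch transporting a fraction ``m`` of the total mass, with a
quadratic regularizer, a single dummy node and an explicit Armijo line-search
wrapper (all data below is illustrative):

>>> import numpy as np
>>> import ot
>>> a = ot.unif(3)
>>> b = ot.unif(3)
>>> M = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))
>>> m, nb_dummies = 0.7, 1
>>> a_ext = np.append(a, [(np.sum(b) - m) / nb_dummies] * nb_dummies)
>>> b_ext = np.append(b, [(np.sum(a) - m) / nb_dummies] * nb_dummies)
>>> def f(G):
...     return 0.5 * np.sum(G ** 2)
>>> def df(G):
...     return G
>>> def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
...     return ot.optim.line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
>>> G = ot.optim.partial_cg(a, b, a_ext, b_ext, M, 1e-1, f, df,
...                         G0=m * np.outer(a, b), line_search=line_search)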
"""
n, m = a.shape[0], b.shape[0]
n_extended, m_extended = a_extended.shape[0], b_extended.shape[0]
nb_dummies = n_extended - n
def lp_solver(a, b, Mi, **kwargs):
# add dummy nodes to Mi
Mi_extended = np.zeros((n_extended, m_extended), dtype=Mi.dtype)
Mi_extended[:n, :m] = Mi
Mi_extended[-nb_dummies:, -nb_dummies:] = np.max(M) * 1e2
G_extended, log_ = emd(
a_extended, b_extended, Mi_extended, numItermax, log=True
)
Gc = G_extended[:n, :m]
if warn:
if log_["warning"] is not None:
raise ValueError(
"Error in the EMD resolution: try to increase the"
" number of dummy points"
)
return Gc, log_
return generic_conditional_gradient(
a,
b,
M,
f,
df,
reg,
None,
lp_solver,
line_search,
G0=G0,
numItermax=numItermax,
stopThr=stopThr,
stopThr2=stopThr2,
verbose=verbose,
log=log,
**kwargs,
)
def gcg(
a,
b,
M,
reg1,
reg2,
f,
df,
G0=None,
numItermax=10,
numInnerItermax=200,
stopThr=1e-9,
stopThr2=1e-9,
verbose=False,
log=False,
**kwargs,
):
r"""
Solve the general regularized OT problem with the generalized conditional gradient
The function solves the following optimization problem:
.. math::
\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
\mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)
s.t. \ \gamma \mathbf{1} &= \mathbf{a}
\gamma^T \mathbf{1} &= \mathbf{b}
\gamma &\geq 0
where :
- :math:`\mathbf{M}` is the (`ns`, `nt`) metric cost matrix
- :math:`\Omega` is the entropic regularization term :math:`\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})`
- :math:`f` is the regularization term (and `df` is its gradient)
- :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target weights (sum to 1)
The algorithm used for solving the problem is the generalized conditional gradient as discussed in :ref:`[5, 7] <references-gcg>`
Parameters
----------
a : array-like, shape (ns,)
sample weights in the source domain
b : array-like, shape (nt,)
sample weights in the target domain
M : array-like, shape (ns, nt)
loss matrix
reg1 : float
Entropic regularization term >0
reg2 : float
Second regularization term >0
f : function
Regularization function taking a transportation matrix as argument
df : function
Gradient of the regularization function taking a transportation matrix as argument
G0 : array-like, shape (ns, nt), optional
initial guess (default is indep joint density)
numItermax : int, optional
Max number of iterations
numInnerItermax : int, optional
Max number of iterations of Sinkhorn
stopThr : float, optional
Stop threshold on the relative variation (>0)
stopThr2 : float, optional
Stop threshold on the absolute variation (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
Returns
-------
gamma : ndarray, shape (ns, nt)
Optimal transportation matrix for the given parameters
log : dict
log dictionary returned only if log==True in parameters
.. _references-gcg:
References
----------
.. [5] N. Courty; R. Flamary; D. Tuia; A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence , vol.PP, no.99, pp.1-1
.. [7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.
See Also
--------
ot.optim.cg : conditional gradient
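Examples
--------
A minimal sketch combining the entropic term with a quadratic regularizer
(illustrative data; the inner Sinkhorn solver uses `reg1`):

>>> import numpy as np
>>> import ot
>>> a = ot.unif(3)
>>> b = ot.unif(3)
>>> M = ot.dist(np.arange(3, dtype=float).reshape(-1, 1))
>>> def f(G):
...     return 0.5 * np.sum(G ** 2)
>>> def df(G):
...     return G
>>> G = ot.optim.gcg(a, b, M, reg1=1e-1, reg2=1e-1, f=f, df=df)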
"""
def lp_solver(a, b, Mi, **kwargs):
return sinkhorn(a, b, Mi, reg1, numItermax=numInnerItermax, log=True, **kwargs)
def line_search(cost, G, deltaG, Mi, cost_G, df_G, **kwargs):
return line_search_armijo(cost, G, deltaG, Mi, cost_G, **kwargs)
return generic_conditional_gradient(
a,
b,
M,
f,
df,
reg2,
reg1,
lp_solver,
line_search,
G0=G0,
numItermax=numItermax,
stopThr=stopThr,
stopThr2=stopThr2,
verbose=verbose,
log=log,
**kwargs,
)
def solve_1d_linesearch_quad(a, b):
r"""
For any convex or non-convex 1d quadratic function `f`, solve the following problem:
.. math::
\mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c
Parameters
----------
a,b : float or tensors (1,)
The coefficients of the quadratic function
Returns
-------
x : float
The optimal value which leads to the minimal cost
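Examples
--------
Two small illustrative cases: the convex quadratic :math:`2x^2 - 3x` has its
constrained minimum at :math:`x = 0.75`, while the concave :math:`-x^2 + 0.5x`
is minimized on :math:`[0, 1]` at the boundary :math:`x = 1`:

>>> from ot.optim import solve_1d_linesearch_quad
>>> solve_1d_linesearch_quad(2.0, -3.0)
0.75
>>> solve_1d_linesearch_quad(-1.0, 0.5)
1.0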
"""
if a > 0: # convex
minimum = min(1.0, max(0.0, -b / (2.0 * a)))
return minimum
else: # non convex
if a + b < 0:
return 1.0
else:
return 0.0