# ot.optim

Generic solvers for regularized OT

## Functions

ot.optim.cg(a, b, M, reg, f, df, G0=None, numItermax=200, numItermaxEmd=100000, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False, **kwargs)[source]

Solve the general regularized OT problem with conditional gradient

The function solves the following optimization problem:

$$\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg} \cdot f(\gamma)$$

$$\text{s.t.} \quad \gamma \mathbf{1} = \mathbf{a}, \quad \gamma^T \mathbf{1} = \mathbf{b}, \quad \gamma \geq 0$$

where :

• $$\mathbf{M}$$ is the (ns, nt) metric cost matrix

• $$f$$ is the regularization term (and df is its gradient)

• $$\mathbf{a}$$ and $$\mathbf{b}$$ are source and target weights (sum to 1)

The algorithm used to solve the problem is the conditional gradient, as discussed in [1]

Parameters
• a (array-like, shape (ns,)) – sample weights in the source domain

• b (array-like, shape (nt,)) – sample weights in the target domain

• M (array-like, shape (ns, nt)) – loss matrix

• reg (float) – Regularization term >0

• G0 (array-like, shape (ns, nt), optional) – initial guess (default is the independent joint distribution)

• numItermax (int, optional) – Max number of iterations

• numItermaxEmd (int, optional) – Max number of iterations for emd

• stopThr (float, optional) – Stop threshold on the relative variation (>0)

• stopThr2 (float, optional) – Stop threshold on the absolute variation (>0)

• verbose (bool, optional) – Print information along iterations

• log (bool, optional) – record log if True

• **kwargs (dict) – Parameters for linesearch

Returns

• gamma (ndarray, shape (ns, nt)) – Optimal transportation matrix for the given parameters

• log (dict) – log dictionary return only if log==True in parameters

References

[1] Ferradans, S., Papadakis, N., Peyré, G., & Aujol, J. F. (2014). Regularized discrete optimal transport. SIAM Journal on Imaging Sciences, 7(3), 1853-1882.

ot.lp.emd

Unregularized optimal transport

ot.bregman.sinkhorn

Entropic regularized optimal transport


ot.optim.gcg(a, b, M, reg1, reg2, f, df, G0=None, numItermax=10, numInnerItermax=200, stopThr=1e-09, stopThr2=1e-09, verbose=False, log=False)[source]

Solve the general regularized OT problem with the generalized conditional gradient

The function solves the following optimization problem:

$$\gamma = \mathop{\arg \min}_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F + \mathrm{reg_1}\cdot\Omega(\gamma) + \mathrm{reg_2}\cdot f(\gamma)$$

$$\text{s.t.} \quad \gamma \mathbf{1} = \mathbf{a}, \quad \gamma^T \mathbf{1} = \mathbf{b}, \quad \gamma \geq 0$$

where :

• $$\mathbf{M}$$ is the (ns, nt) metric cost matrix

• $$\Omega$$ is the entropic regularization term $$\Omega(\gamma)=\sum_{i,j} \gamma_{i,j}\log(\gamma_{i,j})$$

• $$f$$ is the regularization term (and df is its gradient)

• $$\mathbf{a}$$ and $$\mathbf{b}$$ are source and target weights (sum to 1)

The algorithm used to solve the problem is the generalized conditional gradient, as discussed in [5, 7]

Parameters
• a (array-like, shape (ns,)) – sample weights in the source domain

• b (array-like, shape (nt,)) – sample weights in the target domain

• M (array-like, shape (ns, nt)) – loss matrix

• reg1 (float) – Entropic Regularization term >0

• reg2 (float) – Second Regularization term >0

• G0 (array-like, shape (ns, nt), optional) – initial guess (default is the independent joint distribution)

• numItermax (int, optional) – Max number of iterations

• numInnerItermax (int, optional) – Max number of iterations of Sinkhorn

• stopThr (float, optional) – Stop threshold on the relative variation (>0)

• stopThr2 (float, optional) – Stop threshold on the absolute variation (>0)

• verbose (bool, optional) – Print information along iterations

• log (bool, optional) – record log if True

Returns

• gamma (ndarray, shape (ns, nt)) – Optimal transportation matrix for the given parameters

• log (dict) – log dictionary return only if log==True in parameters

References

[5] N. Courty, R. Flamary, D. Tuia, A. Rakotomamonjy, "Optimal Transport for Domain Adaptation," in IEEE Transactions on Pattern Analysis and Machine Intelligence, vol. PP, no. 99, pp. 1-1.

[7] Rakotomamonjy, A., Flamary, R., & Courty, N. (2015). Generalized conditional gradient: analysis of convergence and applications. arXiv preprint arXiv:1510.06567.

ot.optim.cg

General regularized OT solver with conditional gradient


ot.optim.line_search_armijo(f, xk, pk, gfk, old_fval, args=(), c1=0.0001, alpha0=0.99, alpha_min=None, alpha_max=None)[source]

Armijo linesearch function that works with matrices

Find an approximate minimum of $$f(x_k + \alpha \cdot p_k)$$ that satisfies the Armijo conditions.

Parameters
• f (callable) – loss function

• xk (array-like) – initial position

• pk (array-like) – descent direction

• gfk (array-like) – gradient of f at $$x_k$$

• old_fval (float) – loss value at $$x_k$$

• args (tuple, optional) – arguments given to f

• c1 (float, optional) – $$c_1$$ constant in the Armijo rule (>0)

• alpha0 (float, optional) – initial step (>0)

• alpha_min (float, optional) – minimum value for alpha

• alpha_max (float, optional) – maximum value for alpha

Returns

• alpha (float) – step that satisfies the Armijo conditions

• fc (int) – number of function calls

• fa (float) – loss value at step alpha

For any convex or non-convex 1d quadratic function f, solve the following problem:

$$\mathop{\arg \min}_{0 \leq x \leq 1} \quad f(x) = ax^{2} + bx + c$$

Parameters
• a (float) – quadratic coefficient of $$f$$

• b (float) – linear coefficient of $$f$$

• c (float) – constant coefficient of $$f$$

Returns

x – The optimal value which leads to the minimal cost

Return type

float
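The closed form behind this 1d solver is simple enough to sketch. The helper below is a hypothetical re-implementation for illustration (not the library code): for a convex parabola the unconstrained minimizer $$-b/2a$$ is clipped to [0, 1]; otherwise the minimum over the interval lies at an endpoint.

```python
import numpy as np

def quad_min_on_unit_interval(a, b, c):
    """Minimize f(x) = a*x**2 + b*x + c over [0, 1] (illustrative sketch)."""
    if a > 0:
        # convex: clip the unconstrained minimizer -b / (2a) into [0, 1]
        return float(np.clip(-b / (2.0 * a), 0.0, 1.0))
    # concave or linear: the minimum is attained at an endpoint,
    # so compare f(0) = c with f(1) = a + b + c
    return 0.0 if c <= a + b + c else 1.0

quad_min_on_unit_interval(1.0, -1.0, 0.0)   # convex case: minimizer 0.5
```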

ot.optim.solve_linesearch(cost, G, deltaG, Mi, f_val, armijo=True, C1=None, C2=None, reg=None, Gc=None, constC=None, M=None, alpha_min=None, alpha_max=None)[source]

Solve the line search in the Frank-Wolfe (FW) iterations

Parameters
• cost (method) – Cost in the FW for the linesearch

• G (array-like, shape (ns, nt)) – The transport map at a given iteration of the FW

• deltaG (array-like, shape (ns, nt)) – Difference between the optimal map found by linearization in the FW algorithm and the value at a given iteration

• Mi (array-like, shape (ns, nt)) – Cost matrix of the linearized transport problem. Corresponds to the gradient of the cost

• f_val (float) – Value of the cost at G

• armijo (bool, optional) – If True, the line-search step is found via an Armijo search; otherwise the closed form is used. If there are convergence issues, use False.

• C1 (array-like, shape (ns, ns), optional) – Structure matrix in the source domain. Only used and necessary when armijo=False

• C2 (array-like, shape (nt, nt), optional) – Structure matrix in the target domain. Only used and necessary when armijo=False

• reg (float, optional) – Regularization parameter. Only used and necessary when armijo=False

• Gc (array-like, shape (ns, nt)) – Optimal map found by linearization in the FW algorithm. Only used and necessary when armijo=False

• constC (array-like, shape (ns, nt)) – Constant for the Gromov cost. See [24]. Only used and necessary when armijo=False

• M (array-like, shape (ns, nt), optional) – Cost matrix between the features. Only used and necessary when armijo=False

• alpha_min (float, optional) – Minimum value for alpha

• alpha_max (float, optional) – Maximum value for alpha

Returns

• alpha (float) – The optimal step size of the FW

• fc (int) – number of function calls (unused here)

• f_val (float) – The value of the cost for the next iteration

References

[24] Vayer, T., Chapel, L., Flamary, R., Tavenard, R., & Courty, N. (2019). Optimal Transport for structured data with application on graphs. International Conference on Machine Learning (ICML).