# -*- coding: utf-8 -*-
"""
Regularized Unbalanced OT solvers
"""
# Author: Hicham Janati <hicham.janati@inria.fr>
# Laetitia Chapel <laetitia.chapel@univ-ubs.fr>
# Quang Huy Tran <quang-huy.tran@univ-ubs.fr>
#
# License: MIT License
from __future__ import division
import warnings
import numpy as np
from scipy.optimize import minimize, Bounds
from .backend import get_backend
from .utils import list_to_array, get_parameter_pair
[docs]
def sinkhorn_unbalanced(a, b, M, reg, reg_m, method='sinkhorn',
                        reg_type="entropy", warmstart=None, numItermax=1000,
                        stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the unbalanced entropic regularization optimal transport problem
    and return the OT plan

    This is a thin dispatcher: it selects a solver from `method` and forwards
    every argument to it unchanged.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \ \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    method : str
        Solver to use, matched case-insensitively: 'sinkhorn',
        'sinkhorn_stabilized' or 'sinkhorn_reg_scaling'.
        'sinkhorn_reg_scaling' is not implemented: a warning is emitted and
        the classic 'sinkhorn' solver is used instead.
        See the selected solver for its specific parameters.
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded to the selected solver.

    Returns
    -------
    if `b` holds a single histogram (1-D):

    - gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    - log : dict
        log dictionary returned only if `log` is `True`

    else:

    - ot_distance : (n_hists,) array-like
        the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
    - log : dict
        log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.], [1., 0.]]
    >>> ot.sinkhorn_unbalanced(a, b, M, 1, 1)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])


    .. _references-sinkhorn-unbalanced:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015


    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced:
        Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced>`
    ot.unbalanced.sinkhorn_reg_scaling_unbalanced:
        Unbalanced Sinkhorn with epsilon scaling :ref:`[9, 10] <references-sinkhorn-unbalanced>`

    """
    solver_name = method.lower()

    if solver_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized_unbalanced
    elif solver_name in ('sinkhorn', 'sinkhorn_reg_scaling'):
        if solver_name == 'sinkhorn_reg_scaling':
            # epsilon scaling has no dedicated implementation: fall back
            warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
        solver = sinkhorn_knopp_unbalanced
    else:
        raise ValueError("Unknown method '%s'." % method)

    return solver(a, b, M, reg, reg_m, reg_type,
                  warmstart, numItermax=numItermax,
                  stopThr=stopThr, verbose=verbose,
                  log=log, **kwargs)
[docs]
def sinkhorn_unbalanced2(a, b, M, reg, reg_m, method='sinkhorn',
                         reg_type="entropy", warmstart=None, numItermax=1000,
                         stopThr=1e-6, verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem and
    return the loss

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma\geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-unbalanced2>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    method : str
        Solver to use, matched case-insensitively: 'sinkhorn',
        'sinkhorn_stabilized' or 'sinkhorn_reg_scaling'.
        'sinkhorn_reg_scaling' is not implemented: a warning is emitted and
        the classic 'sinkhorn' solver is used instead.
        See the selected solver for its specific parameters.
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors).
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments forwarded to the selected solver.

    Returns
    -------
    ot_distance : scalar or (n_hists,) array-like
        For a 1-D `b`, the sum of :math:`\mathbf{M} \odot \gamma` over the
        plan returned by the solver. For a 2-D `b`, the solver output is
        returned as is (one value per histogram :math:`\mathbf{b}_i`).
    log : dict
        log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.

    Examples
    --------

    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .10]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> np.round(ot.unbalanced.sinkhorn_unbalanced2(a, b, M, 1., 1.), 8)
    0.31912858


    .. _references-sinkhorn-unbalanced2:
    References
    ----------
    .. [2] M. Cuturi, Sinkhorn Distances : Lightspeed Computation of Optimal
        Transport, Advances in Neural Information Processing Systems
        (NIPS) 26, 2013

    .. [9] Schmitzer, B. (2016). Stabilized Sparse Scaling Algorithms for
        Entropy Regularized Transport Problems. arXiv preprint arXiv:1610.06519.

    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.unbalanced.sinkhorn_knopp_unbalanced : Unbalanced Classic Sinkhorn :ref:`[10] <references-sinkhorn-unbalanced2>`
    ot.unbalanced.sinkhorn_stabilized_unbalanced: Unbalanced Stabilized sinkhorn :ref:`[9, 10] <references-sinkhorn-unbalanced2>`

    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # A 1-D `b` means the solver hands back a transport plan that still has
    # to be contracted with M; a 2-D `b` already yields the losses.
    single_histogram = len(b.shape) < 2

    solver_name = method.lower()
    if solver_name == 'sinkhorn_stabilized':
        solver = sinkhorn_stabilized_unbalanced
    elif solver_name in ('sinkhorn', 'sinkhorn_reg_scaling'):
        if solver_name == 'sinkhorn_reg_scaling':
            # epsilon scaling has no dedicated implementation: fall back
            warnings.warn('Method not implemented yet. Using classic Sinkhorn-Knopp')
        solver = sinkhorn_knopp_unbalanced
    else:
        raise ValueError('Unknown method %s.' % method)

    res = solver(a, b, M, reg, reg_m, reg_type,
                 warmstart, numItermax=numItermax,
                 stopThr=stopThr, verbose=verbose,
                 log=log, **kwargs)

    if not single_histogram:
        return res
    if log:
        # res is (plan, log_dict) when logging is requested
        return nx.sum(M * res[0]), res[1]
    return nx.sum(M * res)
[docs]
def sinkhorn_knopp_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
                              warmstart=None, numItermax=1000, stopThr=1e-6,
                              verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport problem and
    return the OT plan

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-knopp-unbalanced>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`.
        If empty, a uniform histogram summing to one is used.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances (a, b_i).
        If empty, a uniform histogram summing to one is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors); they are
        exponentiated to initialize `u` and `v`.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for signature compatibility with the other solvers; unused.

    Returns
    -------
    if `b` holds a single histogram (1-D):

    - gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    - log : dict
        log dictionary returned only if `log` is `True`
        (keys 'err', 'logu', 'logv')

    else:

    - ot_distance : (n_hists,) array-like
        the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
    - log : dict
        log dictionary returned only if `log` is `True`

    Raises
    ------
    ValueError
        If `reg_type` is neither 'entropy' nor 'kl'.

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> ot.unbalanced.sinkhorn_knopp_unbalanced(a, b, M, 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])


    .. _references-sinkhorn-knopp-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty marginals default to uniform histograms.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # n_hists == 0 encodes "single 1-D target histogram".
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': []}

    if n_hists:
        # `a` must be a column so that `a / Kv` broadcasts against
        # (dim_a, n_hists). This has to happen whether or not a warmstart is
        # given, otherwise a 1-D `a` is broadcast along the wrong axis.
        a = a.reshape(dim_a, 1)

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, 1), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type == "kl":
        # NOTE(review): `b` is flattened here, so this kernel only has shape
        # (dim_a, dim_b) when `b` holds a single histogram — confirm whether
        # 'kl' with several histograms is meant to be supported.
        K = nx.exp(-M / reg) * a.reshape(-1)[:, None] * b.reshape(-1)[None, :]
    elif reg_type == "entropy":
        K = nx.exp(-M / reg)
    else:
        # Fail early with a clear message instead of an unbound `K` below.
        raise ValueError(
            "Unknown reg_type '%s'. Use 'entropy' or 'kl'." % reg_type)

    # Exponents of the scaling updates; an infinite relaxation gives 1,
    # i.e. the balanced Sinkhorn update on that marginal.
    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    err = 1.

    for i in range(numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        u = (a / Kv) ** fi_1
        Ktu = nx.dot(K.T, u)
        v = (b / Ktu) ** fi_2

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            u = uprev
            v = vprev
            break

        # Relative change of each scaling vector (denominator floored at 1).
        err_u = nx.max(nx.abs(u - uprev)) / max(
            nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
        )
        err_v = nx.max(nx.abs(v - vprev)) / max(
            nx.max(nx.abs(v)), nx.max(nx.abs(vprev)), 1.
        )
        err = 0.5 * (err_u + err_v)
        if log:
            log['err'].append(err)
        if verbose:
            # progress is printed whether or not logging is enabled
            if i % 50 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, err))
        if err < stopThr:
            break

    if log:
        # small offset keeps the log finite when a scaling entry is zero
        log['logu'] = nx.log(u + 1e-300)
        log['logv'] = nx.log(v + 1e-300)

    if n_hists:  # return only loss
        res = nx.einsum('ik,ij,jk,ij->k', u, K, v, M)
    else:  # return OT matrix
        res = u[:, None] * K * v[None, :]

    if log:
        return res, log
    return res
[docs]
def sinkhorn_stabilized_unbalanced(a, b, M, reg, reg_m, reg_type="entropy",
                                   warmstart=None, tau=1e5,
                                   numItermax=1000, stopThr=1e-6,
                                   verbose=False, log=False, **kwargs):
    r"""
    Solve the entropic regularization unbalanced optimal transport
    problem and return the OT plan (or the losses for several histograms)

    The function solves the following optimization problem using log-domain
    stabilization as proposed in :ref:`[10] <references-sinkhorn-stabilized-unbalanced>`:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \Omega(\gamma) +
        \mathrm{reg_{m1}} \cdot \mathrm{KL}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{KL}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where :

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\Omega` is the entropic regularization term, can be either KL divergence or negative entropy
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target unbalanced distributions
    - KL is the Kullback-Leibler divergence

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10, 25] <references-sinkhorn-stabilized-unbalanced>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`.
        If empty, a uniform histogram summing to one is used.
    b : array-like (dim_b,) or array-like (dim_b, n_hists)
        One or multiple unnormalized histograms of dimension `dim_b`.
        If many, compute all the OT distances :math:`(\mathbf{a}, \mathbf{b}_i)_i`.
        If empty, a uniform histogram summing to one is used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg : float
        Entropy regularization term > 0
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        The entropic balanced OT can be recovered using `reg_m=float("inf")`.
        For semi-relaxed case, use either
        `reg_m=(float("inf"), scalar)` or `reg_m=(scalar, float("inf"))`.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg_type : string, optional
        Regularizer term. Can take two values:
        'entropy' (negative entropy)
        :math:`\Omega(\gamma) = \sum_{i,j} \gamma_{i,j} \log(\gamma_{i,j}) - \sum_{i,j} \gamma_{i,j}`, or
        'kl' (Kullback-Leibler)
        :math:`\Omega(\gamma) = \text{KL}(\gamma, \mathbf{a} \mathbf{b}^T)`.
        Any value other than 'kl' is handled as 'entropy'.
    warmstart: tuple of arrays, optional
        Initialization of dual potentials. If provided, the dual potentials should be given
        (that is the logarithm of the u,v sinkhorn scaling vectors); they are
        exponentiated to initialize `u` and `v`.
    tau : float
        threshold for max value in u or v for log scaling
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (>0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Accepted for signature compatibility with the other solvers; unused.

    Returns
    -------
    if `b` holds a single histogram (1-D):

    - gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    - log : dict
        log dictionary returned only if `log` is `True`
        (keys 'err', 'logu', 'logv')

    else:

    - ot_distance : (n_hists,) array-like
        the OT distance between :math:`\mathbf{a}` and each of the histograms :math:`\mathbf{b}_i`
    - log : dict
        log dictionary returned only if `log` is `True`

    Examples
    --------

    >>> import ot
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[0., 1.],[1., 0.]]
    >>> ot.unbalanced.sinkhorn_stabilized_unbalanced(a, b, M, 1., 1.)
    array([[0.51122814, 0.18807032],
           [0.18807032, 0.51122814]])

    .. _references-sinkhorn-stabilized-unbalanced:
    References
    ----------
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint arXiv:1607.05816.

    .. [25] Frogner C., Zhang C., Mobahi H., Araya-Polo M., Poggio T. :
        Learning with a Wasserstein Loss,  Advances in Neural Information
        Processing Systems (NIPS) 2015

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.optim.cg : General regularized OT

    """
    a, b, M = list_to_array(a, b, M)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty marginals default to uniform histograms.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # n_hists == 0 encodes "single 1-D target histogram".
    if len(b.shape) > 1:
        n_hists = b.shape[1]
    else:
        n_hists = 0

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': []}

    if n_hists:
        # `a` must be a column so that `a / Kv` broadcasts against
        # (dim_a, n_hists). This has to happen whether or not a warmstart is
        # given, otherwise a 1-D `a` is broadcast along the wrong axis.
        a = a.reshape(dim_a, 1)

    # we assume that no distances are null except those of the diagonal of
    # distances
    if warmstart is None:
        if n_hists:
            u = nx.ones((dim_a, n_hists), type_as=M)
            v = nx.ones((dim_b, n_hists), type_as=M)
        else:
            u = nx.ones(dim_a, type_as=M)
            v = nx.ones(dim_b, type_as=M)
    else:
        u, v = nx.exp(warmstart[0]), nx.exp(warmstart[1])

    if reg_type == "kl":
        # Fold log(a b^T) into the cost so the same kernel code serves both
        # regularizers; the small offset guards log(0).
        log_ab = nx.log(a + 1e-16).reshape(-1)[:, None] + nx.log(b + 1e-16).reshape(-1)[None, :]
        M0 = M - reg * log_ab
    else:
        M0 = M
    K = nx.exp(-M0 / reg)

    # Exponents of the scaling updates; an infinite relaxation gives 1.
    fi_1 = reg_m1 / (reg_m1 + reg) if reg_m1 != float("inf") else 1
    fi_2 = reg_m2 / (reg_m2 + reg) if reg_m2 != float("inf") else 1

    cpt = 0
    err = 1.
    # alpha / beta accumulate the part of the potentials absorbed into K.
    alpha = nx.zeros(dim_a, type_as=M)
    beta = nx.zeros(dim_b, type_as=M)
    ones_a = nx.ones(dim_a, type_as=M)
    ones_b = nx.ones(dim_b, type_as=M)

    while (err > stopThr and cpt < numItermax):
        uprev = u
        vprev = v

        Kv = nx.dot(K, v)
        f_alpha = nx.exp(- alpha / (reg + reg_m1)) if reg_m1 != float("inf") else ones_a
        f_beta = nx.exp(- beta / (reg + reg_m2)) if reg_m2 != float("inf") else ones_b

        if n_hists:
            f_alpha = f_alpha[:, None]
            f_beta = f_beta[:, None]
        u = ((a / (Kv + 1e-16)) ** fi_1) * f_alpha
        Ktu = nx.dot(K.T, u)
        v = ((b / (Ktu + 1e-16)) ** fi_2) * f_beta

        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # A scaling grew past tau: move its magnitude into alpha/beta,
            # rebuild the kernel from them and restart v from ones.
            absorbing = True
            if n_hists:
                alpha = alpha + reg * nx.log(nx.max(u, 1))
                beta = beta + reg * nx.log(nx.max(v, 1))
            else:
                alpha = alpha + reg * nx.log(nx.max(u))
                beta = beta + reg * nx.log(nx.max(v))
            K = nx.exp((alpha[:, None] + beta[None, :] - M0) / reg)
            v = nx.ones(v.shape, type_as=v)
            # K.dot(v) is recomputed at the top of the next iteration.

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % cpt)
            u = uprev
            v = vprev
            break
        if (cpt % 10 == 0 and not absorbing) or cpt == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = nx.max(nx.abs(u - uprev)) / max(
                nx.max(nx.abs(u)), nx.max(nx.abs(uprev)), 1.
            )
            if log:
                log['err'].append(err)
            if verbose:
                if cpt % 200 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(cpt, err))
        cpt = cpt + 1

    if err > stopThr:
        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
                      "Try a larger entropy `reg` or a lower mass `reg_m`." +
                      "Or a larger absorption threshold `tau`.")

    # Full log-scalings: absorbed potentials plus the current u, v.
    if n_hists:
        logu = alpha[:, None] / reg + nx.log(u)
        logv = beta[:, None] / reg + nx.log(v)
    else:
        logu = alpha / reg + nx.log(u)
        logv = beta / reg + nx.log(v)

    if log:
        log['logu'] = logu
        log['logv'] = logv

    if n_hists:  # return only loss
        # sum_ij M_ij * plan_ij per histogram, evaluated in the log domain
        res = nx.logsumexp(
            nx.log(M + 1e-100)[:, :, None]
            + logu[:, None, :]
            + logv[None, :, :]
            - M0[:, :, None] / reg,
            axis=(0, 1)
        )
        res = nx.exp(res)
    else:  # return OT matrix
        res = nx.exp(logu[:, None] + logv[None, :] - M0 / reg)

    if log:
        return res, log
    return res
[docs]
def barycenter_unbalanced_stabilized(A, M, reg, reg_m, weights=None, tau=1e3,
                                     numItermax=1000, stopThr=1e-6,
                                     verbose=False, log=False):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}` with stabilization.

    The function solves the following optimization problem:

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-stabilized>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m : float
        Marginal relaxation term > 0
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    tau : float
        Stabilization threshold for log domain absorption.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters
        (keys 'err', 'niter', 'logu', 'logv')


    .. _references-barycenter-unbalanced-stabilized:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré,
        G. (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)

    dim, n_hists = A.shape
    if weights is None:
        weights = nx.ones(n_hists, type_as=A) / n_hists
    else:
        assert (len(weights) == A.shape[1])

    if log:
        log = {'err': []}

    # Exponent of the scaling updates.
    fi = reg_m / (reg_m + reg)

    u = nx.ones((dim, n_hists), type_as=A) / dim
    v = nx.ones((dim, n_hists), type_as=A) / dim

    K = nx.exp(-M / reg)

    err = 1.
    # alpha / beta accumulate the part of the potentials absorbed into K.
    alpha = nx.zeros(dim, type_as=A)
    beta = nx.zeros(dim, type_as=A)
    q = nx.ones(dim, type_as=A) / dim

    for i in range(numItermax):
        qprev = nx.copy(q)
        Kv = nx.dot(K, v)
        f_alpha = nx.exp(- alpha / (reg + reg_m))
        f_beta = nx.exp(- beta / (reg + reg_m))
        f_alpha = f_alpha[:, None]
        f_beta = f_beta[:, None]
        u = ((A / (Kv + 1e-16)) ** fi) * f_alpha
        Ktu = nx.dot(K.T, u)
        # Barycenter update: weighted combination across histograms.
        q = (Ktu ** (1 - fi)) * f_beta
        q = nx.dot(q, weights) ** (1 / (1 - fi))
        Q = q[:, None]
        v = ((Q / (Ktu + 1e-16)) ** fi) * f_beta

        absorbing = False
        if nx.any(u > tau) or nx.any(v > tau):
            # A scaling grew past tau: move its magnitude into alpha/beta,
            # rebuild the kernel from them and restart v from ones.
            absorbing = True
            alpha = alpha + reg * nx.log(nx.max(u, 1))
            beta = beta + reg * nx.log(nx.max(v, 1))
            K = nx.exp((alpha[:, None] + beta[None, :] - M) / reg)
            v = nx.ones(v.shape, type_as=v)
            # K.dot(v) is recomputed at the top of the next iteration.

        if (nx.any(Ktu == 0.)
                or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
                or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            # (report the actual loop index, not a counter stuck at zero)
            warnings.warn('Numerical errors at iteration %s' % i)
            q = qprev
            break
        if (i % 10 == 0 and not absorbing) or i == 0:
            # we can speed up the process by checking for the error only all
            # the 10th iterations
            err = nx.max(nx.abs(q - qprev)) / max(
                nx.max(nx.abs(q)), nx.max(nx.abs(qprev)), 1.
            )
            if log:
                log['err'].append(err)
            if verbose:
                if i % 50 == 0:
                    print(
                        '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(i, err))
            if err < stopThr:
                break

    if err > stopThr:
        warnings.warn("Stabilized Unbalanced Sinkhorn did not converge." +
                      "Try a larger entropy `reg` or a lower mass `reg_m`." +
                      "Or a larger absorption threshold `tau`.")
    if log:
        log['niter'] = i
        # small offset keeps the log finite when a scaling entry is zero
        log['logu'] = nx.log(u + 1e-300)
        log['logv'] = nx.log(v + 1e-300)
        return q, log
    else:
        return q
[docs]
def barycenter_unbalanced_sinkhorn(A, M, reg, reg_m, weights=None,
                                   numItermax=1000, stopThr=1e-6,
                                   verbose=False, log=False):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.

    The function solves the following optimization problem with :math:`\mathbf{a}`

    .. math::
       \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced-sinkhorn>`

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m: float
        Marginal relaxation term > 0
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0); only honoured after the first
        eleven iterations.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True


    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters
        (keys 'err', 'niter', 'logu', 'logv')


    .. _references-barycenter-unbalanced-sinkhorn:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.


    """
    A, M = list_to_array(A, M)
    nx = get_backend(A, M)
    dim, n_hists = A.shape

    if weights is not None:
        assert (len(weights) == A.shape[1])
    else:
        weights = nx.ones(n_hists, type_as=A) / n_hists

    if log:
        log = {'err': []}

    kernel = nx.exp(-M / reg)
    # Exponent of the scaling updates.
    fi = reg_m / (reg_m + reg)

    v = nx.ones((dim, n_hists), type_as=A)
    u = nx.ones((dim, 1), type_as=A)
    q = nx.ones(dim, type_as=A)
    err = 1.

    for i in range(numItermax):
        # Snapshots used to roll back if this iteration breaks down.
        u_old, v_old, q_old = nx.copy(u), nx.copy(v), nx.copy(q)

        u = (A / nx.dot(kernel, v)) ** fi
        Ktu = nx.dot(kernel.T, u)
        # Barycenter update: weighted combination across histograms.
        q = nx.dot(Ktu ** (1 - fi), weights) ** (1 / (1 - fi))
        v = (q[:, None] / Ktu) ** fi

        broke_down = (
            nx.any(Ktu == 0.)
            or nx.any(nx.isnan(u)) or nx.any(nx.isnan(v))
            or nx.any(nx.isinf(u)) or nx.any(nx.isinf(v))
        )
        if broke_down:
            # we have reached the machine precision
            # come back to previous solution and quit loop
            warnings.warn('Numerical errors at iteration %s' % i)
            u, v, q = u_old, v_old, q_old
            break

        # compute change in barycenter
        scale = max(nx.max(nx.abs(q)), nx.max(nx.abs(q_old)), 1.0)
        err = nx.max(nx.abs(q - q_old)) / scale
        if log:
            log['err'].append(err)

        # if barycenter did not change + at least 10 iterations - stop
        if err < stopThr and i > 10:
            break

        if verbose:
            if i % 10 == 0:
                print(
                    '{:5s}|{:12s}'.format('It.', 'Err') + '\n' + '-' * 19)
            print('{:5d}|{:8e}|'.format(i, err))

    if not log:
        return q

    log['niter'] = i
    # small offset keeps the log finite when a scaling entry is zero
    log['logu'] = nx.log(u + 1e-300)
    log['logv'] = nx.log(v + 1e-300)
    return q, log
[docs]
def barycenter_unbalanced(A, M, reg, reg_m, method="sinkhorn", weights=None,
                          numItermax=1000, stopThr=1e-6,
                          verbose=False, log=False, **kwargs):
    r"""Compute the entropic unbalanced wasserstein barycenter of :math:`\mathbf{A}`.

    The function solves the following optimization problem with :math:`\mathbf{a}`

    .. math::
        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{u_{reg}}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{u_{reg}}(\cdot,\cdot)` is the unbalanced entropic regularized Wasserstein distance (see :py:func:`ot.unbalanced.sinkhorn_unbalanced`)
    - :math:`\mathbf{a}_i` are training distributions in the columns of matrix :math:`\mathbf{A}`
    - reg and :math:`\mathbf{M}` are respectively the regularization term and the cost matrix for OT
    - reg_m is the marginal relaxation hyperparameter

    The algorithm used for solving the problem is the generalized
    Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[10] <references-barycenter-unbalanced>`

    This function is only a dispatcher: every argument except `method` is
    forwarded unchanged to the selected solver.

    Parameters
    ----------
    A : array-like (dim, n_hists)
        `n_hists` training distributions :math:`\mathbf{a}_i` of dimension `dim`
    M : array-like (dim, dim)
        ground metric matrix for OT.
    reg : float
        Entropy regularization term > 0
    reg_m : float
        Marginal relaxation term > 0
    method : str, optional
        Solver to use, matched case-insensitively: 'sinkhorn',
        'sinkhorn_stabilized' or 'sinkhorn_reg_scaling'. The last one is
        not implemented; a warning is issued and the classic Sinkhorn
        solver is used instead.
    weights : array-like (n_hists,) optional
        Weight of each distribution (barycentric coordinates)
        If None, uniform weights are used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    **kwargs
        Extra keyword arguments passed through to the selected solver.

    Returns
    -------
    a : (dim,) array-like
        Unbalanced Wasserstein barycenter
    log : dict
        log dictionary return only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is not one of the supported solver names.

    .. _references-barycenter-unbalanced:
    References
    ----------
    .. [3] Benamou, J. D., Carlier, G., Cuturi, M., Nenna, L., & Peyré, G.
        (2015). Iterative Bregman projections for regularized transportation
        problems. SIAM Journal on Scientific Computing, 37(2), A1111-A1138.
    .. [10] Chizat, L., Peyré, G., Schmitzer, B., & Vialard, F. X. (2016).
        Scaling algorithms for unbalanced transport problems. arXiv preprint
        arXiv:1607.05816.
    """
    # Normalise the solver name once instead of on every comparison.
    solver_name = method.lower()

    if solver_name == 'sinkhorn':
        solver = barycenter_unbalanced_sinkhorn
    elif solver_name == 'sinkhorn_stabilized':
        solver = barycenter_unbalanced_stabilized
    elif solver_name == 'sinkhorn_reg_scaling':
        # Not implemented: fall back to the classic solver directly rather
        # than re-entering this dispatcher and relying on its default method.
        warnings.warn('Method not implemented yet. Using classic Sinkhorn Knopp')
        solver = barycenter_unbalanced_sinkhorn
    else:
        raise ValueError("Unknown method '%s'." % method)

    return solver(A, M, reg, reg_m,
                  weights=weights,
                  numItermax=numItermax,
                  stopThr=stopThr, verbose=verbose,
                  log=log, **kwargs)
[docs]
def mm_unbalanced(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000,
                  stopThr=1e-15, verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT plan.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - div is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence

    The algorithm used for solving the problem is a majorization-
    minimization (MM) algorithm as proposed in :ref:`[41] <references-regpath>`

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`.
        If empty, uniform weights `1 / dim_a` are used.
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`.
        If empty, uniform weights `1 / dim_b` are used.
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg : float, optional (default = 0)
        Regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic).
        Any other value triggers a warning and is replaced by 'kl'.
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix.
        If None, :math:`\mathbf{a} \mathbf{b}^T` is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`. It holds the
        per-iteration errors (`'err'`), the successive plans (`'G'`) and
        the linear transport cost :math:`\langle \gamma, \mathbf{M} \rangle_F`
        of the final plan (`'cost'`).

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='kl'), 2)
    array([[0.45, 0.  ],
           [0.  , 0.34]])
    >>> np.round(ot.unbalanced.mm_unbalanced(a, b, M, 5, div='l2'), 2)
    array([[0.4, 0. ],
           [0. , 0.1]])

    .. _references-regpath:
    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    See Also
    --------
    ot.lp.emd : Unregularized OT
    ot.unbalanced.sinkhorn_unbalanced : Entropic regularized OT
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    dim_a, dim_b = M.shape

    # Empty histograms are replaced by uniform weights.
    if len(a) == 0:
        a = nx.ones(dim_a, type_as=M) / dim_a
    if len(b) == 0:
        b = nx.ones(dim_b, type_as=M) / dim_b

    # Both the initial plan and the reference measure default to a b^T.
    if G0 is None:
        G = a[:, None] * b[None, :]
    else:
        G = G0
    if c is None:
        c = a[:, None] * b[None, :]

    reg_m1, reg_m2 = get_parameter_pair(reg_m)

    if log:
        log = {'err': [], 'G': []}

    if div not in ["kl", "l2"]:
        warnings.warn("The div parameter should be either equal to 'kl' or "
                      "'l2': it has been set to 'kl'.")
        div = 'kl'

    # From here on `div` is either 'kl' or 'l2'.
    use_kl = div == 'kl'

    # Iteration-independent factor of the update.
    if use_kl:
        sum_r = reg + reg_m1 + reg_m2
        r1, r2, r = reg_m1 / sum_r, reg_m2 / sum_r, reg / sum_r
        K = (a[:, None]**r1) * (b[None, :]**r2) * (c**r) * nx.exp(- M / sum_r)
    else:
        K = reg_m1 * a[:, None] + reg_m2 * b[None, :] + reg * c - M
        K = nx.maximum(K, nx.zeros((dim_a, dim_b), type_as=M))

    for it in range(numItermax):
        G_old = G
        row_mass = nx.sum(G, 1, keepdims=True)
        col_mass = nx.sum(G, 0, keepdims=True)

        # The 1e-16 offset keeps the denominator away from zero.
        if use_kl:
            denom = (row_mass**r1) * (col_mass**r2) + 1e-16
            G = K * G**(r1 + r2) / denom
        else:
            denom = reg_m1 * row_mass + \
                reg_m2 * col_mass + reg * G + 1e-16
            G = K * G / denom

        # Frobenius norm of the change between two successive plans.
        err = nx.sqrt(nx.sum((G - G_old) ** 2))
        if log:
            log['err'].append(err)
            log['G'].append(G)
        if verbose:
            print('{:5d}|{:8e}|'.format(it, err))
        if err < stopThr:
            break

    if not log:
        return G

    log['cost'] = nx.sum(G * M)
    return G, log
[docs]
def mm_unbalanced2(a, b, M, reg_m, c=None, reg=0, div='kl', G0=None, numItermax=1000,
                   stopThr=1e-15, verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT cost.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg_{m1}} \cdot \mathrm{div}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div}(\gamma^T \mathbf{1}, \mathbf{b}) +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is a divergence, either Kullback-Leibler or :math:`\ell_2` divergence

    The algorithm used for solving the problem is a majorization-
    minimization (MM) algorithm as proposed in :ref:`[41] <references-regpath>`

    All arguments are forwarded to :py:func:`ot.unbalanced.mm_unbalanced`;
    the returned value is the `'cost'` entry of that solver's log.

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must have the same backend as input arrays (a, b, M).
    reg : float, optional (default = 0)
        Regularization term >= 0.
        By default, solve the unregularized problem
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    ot_distance : array-like
        the OT distance between :math:`\mathbf{a}` and :math:`\mathbf{b}`
    log : dict
        log dictionary returned only if `log` is `True`

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='l2'), 2)
    0.8
    >>> np.round(ot.unbalanced.mm_unbalanced2(a, b, M, 5, div='kl'), 2)
    1.79

    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    See Also
    --------
    ot.lp.emd2 : Unregularized OT loss
    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
    """
    # The log is always requested from the plan solver, whatever the
    # caller's `log` flag, because the cost is only available through it.
    _plan, solver_log = mm_unbalanced(a, b, M, reg_m, c=c, reg=reg, div=div, G0=G0,
                                      numItermax=numItermax, stopThr=stopThr,
                                      verbose=verbose, log=True)
    cost = solver_log['cost']

    if not log:
        return cost
    return cost, solver_log
def _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div='kl', regm_div='kl'):
    """
    Build the objective of regularized unbalanced OT as a callable usable
    by scipy.optimize (returns both the value and the gradient).

    The returned function takes a flat vector, reshapes it to the shape of
    `M` and evaluates

        sum(G * M) + reg * R(G) + reg_m1 * D(G.sum(1), a) + reg_m2 * D(G.sum(0), b)

    where R is selected by `reg_div` ('kl' against `c`, 'entropy', anything
    else meaning quadratic against `c`) and D by `regm_div` ('kl', anything
    else meaning quadratic).

    Returns
    -------
    _func : callable
        Maps a flat array of size `M.size` to the tuple
        `(value, flattened gradient)`.
    """
    n_rows, n_cols = M.shape

    def _gen_kl(p, q):
        # generalized KL divergence; 1e-16 guards the logarithm against log(0)
        return np.sum(p * np.log(p / q + 1e-16)) - np.sum(p) + np.sum(q)

    # ---- penalty on the plan itself ---------------------------------------
    if reg_div == 'kl':
        def plan_penalty(G):
            return _gen_kl(G, c)

        def plan_penalty_grad(G):
            return np.log(G / c + 1e-16)
    elif reg_div == 'entropy':
        def plan_penalty(G):
            return np.sum(G * np.log(G + 1e-16)) - np.sum(G)

        def plan_penalty_grad(G):
            return np.log(G + 1e-16)
    else:
        # any other value selects the quadratic penalty
        def plan_penalty(G):
            return np.sum((G - c)**2) / 2

        def plan_penalty_grad(G):
            return G - c

    # ---- penalty on the two marginals -------------------------------------
    if regm_div == 'kl':
        def marginal_penalty(G):
            return reg_m1 * _gen_kl(G.sum(1), a) + reg_m2 * _gen_kl(G.sum(0), b)

        def marginal_penalty_grad(G):
            row_term = np.log(G.sum(1) / a + 1e-16)
            col_term = np.log(G.sum(0) / b + 1e-16)
            # replicate the per-row / per-column terms over the full matrix
            return reg_m1 * np.outer(row_term, np.ones(n_cols)) + \
                reg_m2 * np.outer(np.ones(n_rows), col_term)
    else:
        # any other value selects the quadratic penalty
        def marginal_penalty(G):
            row_gap = G.sum(1) - a
            col_gap = G.sum(0) - b
            return reg_m1 * 0.5 * np.sum(row_gap**2) + \
                reg_m2 * 0.5 * np.sum(col_gap**2)

        def marginal_penalty_grad(G):
            row_gap = G.sum(1) - a
            col_gap = G.sum(0) - b
            return reg_m1 * np.outer(row_gap, np.ones(n_cols)) + \
                reg_m2 * np.outer(np.ones(n_rows), col_gap)

    def _func(G):
        plan = G.reshape((n_rows, n_cols))

        # objective value
        val = np.sum(plan * M) + reg * plan_penalty(plan) + marginal_penalty(plan)
        # gradient with respect to the plan, flattened back for the optimizer
        grad = M + reg * plan_penalty_grad(plan) + marginal_penalty_grad(plan)

        return val, grad.ravel()

    return _func
[docs]
def lbfgsb_unbalanced(a, b, M, reg, reg_m, c=None, reg_div='kl', regm_div='kl', G0=None, numItermax=1000,
                      stopThr=1e-15, method='L-BFGS-B', verbose=False, log=False):
    r"""
    Solve the unbalanced optimal transport problem and return the OT plan using L-BFGS-B.

    The function solves the following optimization problem:

    .. math::
        W = \min_\gamma \quad \langle \gamma, \mathbf{M} \rangle_F +
        \mathrm{reg} \cdot \mathrm{div}(\gamma, \mathbf{c}) +
        \mathrm{reg_{m1}} \cdot \mathrm{div_m}(\gamma \mathbf{1}, \mathbf{a}) +
        \mathrm{reg_{m2}} \cdot \mathrm{div_m}(\gamma^T \mathbf{1}, \mathbf{b})

        s.t.
             \gamma \geq 0

    where:

    - :math:`\mathbf{M}` is the (`dim_a`, `dim_b`) metric cost matrix
    - :math:`\mathbf{a}` and :math:`\mathbf{b}` are source and target
      unbalanced distributions
    - :math:`\mathbf{c}` is a reference distribution for the regularization
    - :math:`\mathrm{div}` is the divergence selected by `reg_div` and
      :math:`\mathrm{div_m}` the one selected by `regm_div`

    The problem is solved with `scipy.optimize.minimize` (L-BFGS-B by
    default) on NumPy copies of the inputs; the plan is converted back to
    the type of `M` before being returned.

    Parameters
    ----------
    a : array-like (dim_a,)
        Unnormalized histogram of dimension `dim_a`
    b : array-like (dim_b,)
        Unnormalized histogram of dimension `dim_b`
    M : array-like (dim_a, dim_b)
        loss matrix
    reg: float
        regularization term >=0
    c : array-like (dim_a, dim_b), optional (default = None)
        Reference measure for the regularization.
        If None, then use :math:`\mathbf{c} = \mathbf{a} \mathbf{b}^T`.
    reg_m: float or indexable object of length 1 or 2
        Marginal relaxation term >= 0, but cannot be infinity.
        If reg_m is a scalar or an indexable object of length 1,
        then the same reg_m is applied to both marginal relaxations.
        If reg_m is an array, it must be a Numpy array.
    reg_div: string, optional
        Divergence used for regularization.
        Can take three values: 'entropy' (negative entropy), or
        'kl' (Kullback-Leibler) or 'l2' (quadratic).
    regm_div: string, optional
        Divergence to quantify the difference between the marginals.
        Can take two values: 'kl' (Kullback-Leibler) or 'l2' (quadratic)
    G0: array-like (dim_a, dim_b)
        Initialization of the transport matrix.
        If None, a matrix of zeros is used.
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    method : str, optional
        Solver name handed to `scipy.optimize.minimize`.
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True

    Returns
    -------
    gamma : (dim_a, dim_b) array-like
        Optimal transportation matrix for the given parameters
    log : dict
        log dictionary returned only if `log` is `True`. It holds the
        final objective value (`'loss'`) and the scipy result object
        (`'res'`).

    Examples
    --------
    >>> import ot
    >>> import numpy as np
    >>> a=[.5, .5]
    >>> b=[.5, .5]
    >>> M=[[1., 36.],[9., 4.]]
    >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='kl', regm_div='kl'), 2)
    array([[0.45, 0.  ],
           [0.  , 0.34]])
    >>> np.round(ot.unbalanced.lbfgsb_unbalanced(a, b, M, reg=0, reg_m=5, reg_div='l2', regm_div='l2'), 2)
    array([[0.4, 0. ],
           [0. , 0.1]])

    References
    ----------
    .. [41] Chapel, L., Flamary, R., Wu, H., Févotte, C., and Gasso, G. (2021).
        Unbalanced optimal transport through non-negative penalized
        linear regression. NeurIPS.

    See Also
    --------
    ot.lp.emd2 : Unregularized OT loss
    ot.unbalanced.sinkhorn_unbalanced2 : Entropic regularized OT loss
    """
    M, a, b = list_to_array(M, a, b)
    nx = get_backend(M, a, b)

    # Keep the backend-typed cost matrix: it is the `type_as` reference
    # when results are converted back from NumPy.
    M0 = M

    # scipy works on NumPy arrays only
    a, b, M = nx.to_numpy(a, b, M)

    if G0 is None:
        G0 = np.zeros(M.shape)
    else:
        G0 = nx.to_numpy(G0)

    if c is None:
        c = a[:, None] * b[None, :]
    else:
        c = nx.to_numpy(c)

    reg_m1, reg_m2 = get_parameter_pair(reg_m)
    loss_and_grad = _get_loss_unbalanced(a, b, c, M, reg, reg_m1, reg_m2, reg_div, regm_div)

    # jac=True: the objective returns (value, gradient) in a single call.
    # The bounds enforce the non-negativity of every entry of the plan.
    solver_options = dict(maxiter=numItermax, disp=verbose)
    res = minimize(loss_and_grad, G0.ravel(), method=method, jac=True,
                   bounds=Bounds(0, np.inf), tol=stopThr,
                   options=solver_options)

    G = nx.from_numpy(res.x.reshape(M.shape), type_as=M0)

    if not log:
        return G

    log = {'loss': nx.from_numpy(res.fun, type_as=M0), 'res': res}
    return G, log