# -*- coding: utf-8 -*-
"""
Dictionary Learning based on Bregman projections for entropic regularized OT
"""
# Author: Remi Flamary <remi.flamary@unice.fr>
# Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License
import warnings
from ..utils import list_to_array
from ..backend import get_backend
from ._utils import projC, projR
def unmix(
a,
D,
M,
M0,
h0,
reg,
reg0,
alpha,
numItermax=1000,
stopThr=1e-3,
verbose=False,
log=False,
warn=True,
):
r"""
    Compute the unmixing of an observation with a given dictionary using Wasserstein distance

    The function solves the following optimization problem:

    .. math::

        \mathbf{h} = \mathop{\arg \min}_\mathbf{h} \quad
        (1 - \alpha) W_{\mathbf{M}, \mathrm{reg}}(\mathbf{a}, \mathbf{Dh}) +
        \alpha W_{\mathbf{M_0}, \mathrm{reg}_0}(\mathbf{h}_0, \mathbf{h})

    where:
- :math:`W_{M,reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance
with :math:`\mathbf{M}` loss matrix (see :py:func:`ot.bregman.sinkhorn`)
- :math:`\mathbf{D}` is a dictionary of `n_atoms` atoms of dimension `dim_a`,
its expected shape is `(dim_a, n_atoms)`
- :math:`\mathbf{h}` is the estimated unmixing of dimension `n_atoms`
- :math:`\mathbf{a}` is an observed distribution of dimension `dim_a`
- :math:`\mathbf{h}_0` is a prior on :math:`\mathbf{h}` of dimension `dim_prior`
    - `reg` and :math:`\mathbf{M}` are respectively the regularization term and the
      cost matrix (`dim_a`, `dim_a`) for OT data fitting
    - `reg`:math:`_0` and :math:`\mathbf{M_0}` are respectively the regularization
      term and the cost matrix (`n_atoms`, `dim_prior`) for the regularization
      with the prior :math:`\mathbf{h}_0`
    - :math:`\alpha` weights the data-fitting term against the regularization term

    The optimization problem is solved following the algorithm described
    in :ref:`[4] <references-unmix>` (a toy usage sketch is given in the
    Examples section below).
Parameters
----------
    a : array-like, shape (dim_a,)
observed distribution (histogram, sums to 1)
D : array-like, shape (dim_a, n_atoms)
dictionary matrix
    M : array-like, shape (dim_a, dim_a)
        loss matrix for the OT data-fitting term
    M0 : array-like, shape (n_atoms, dim_prior)
        loss matrix for the prior regularization term
h0 : array-like, shape (n_atoms,)
prior on the estimated unmixing h
reg : float
Regularization term >0 (Wasserstein data fitting)
reg0 : float
Regularization term >0 (Wasserstein reg with h0)
alpha : float
How much should we trust the prior ([0,1])
numItermax : int, optional
Max number of iterations
stopThr : float, optional
Stop threshold on error (>0)
verbose : bool, optional
Print information along iterations
log : bool, optional
record log if True
warn : bool, optional
        if True, raises a warning if the algorithm doesn't converge.
Returns
-------
    h : array-like, shape (n_atoms,)
        estimated unmixing of :math:`\mathbf{a}` over the dictionary :math:`\mathbf{D}`
    log : dict
        log dictionary, returned only if `log` is True in parameters
.. _references-unmix:
References
----------
.. [4] S. Nakhostin, N. Courty, R. Flamary, D. Tuia, T. Corpetti,
Supervised planetary unmixing with optimal transport, Workshop
on Hyperspectral Image and Signal Processing :
Evolution in Remote Sensing (WHISPERS), 2016.
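
    Examples
    --------
    A minimal synthetic sketch, shown only for illustration: the dictionary,
    cost matrices, and regularization values below are arbitrary choices, not
    values from the reference.

    >>> import numpy as np
    >>> a = np.array([0.2, 0.8])  # observed histogram
    >>> D = np.eye(2)  # dictionary with two trivial atoms
    >>> M = 1 - np.eye(2)  # 0/1 cost between the two bins
    >>> M0 = 1 - np.eye(2)  # 0/1 cost for the prior term
    >>> h0 = np.array([0.5, 0.5])  # uniform prior on the unmixing
    >>> h = unmix(a, D, M, M0, h0, reg=1e-1, reg0=1e-1, alpha=0.5)
    >>> h.shape
    (2,)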
"""
a, D, M, M0, h0 = list_to_array(a, D, M, M0, h0)
nx = get_backend(a, D, M, M0, h0)
# M = M/np.median(M)
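    # Gibbs kernel of the entropic OT data-fitting term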
K = nx.exp(-M / reg)
# M0 = M0/np.median(M0)
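    # Gibbs kernel of the entropic OT prior-regularization term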
K0 = nx.exp(-M0 / reg0)
old = h0
err = 1
if log:
log = {"err": []}
for ii in range(numItermax):
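        # Bregman projections onto the fixed marginals: the observation a
        # for the data-fitting plan, the prior h0 for the prior plan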
K = projC(K, a)
K0 = projC(K0, h0)
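        # row marginal of the prior plan = current unmixing estimate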
new = nx.sum(K0, axis=1)
# we recombine the current selection from dictionary
inv_new = nx.dot(D, new)
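        # row marginal of the data-fitting plan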
other = nx.sum(K, axis=1)
# geometric interpolation
delta = nx.exp(alpha * nx.log(other) + (1 - alpha) * nx.log(inv_new))
K = projR(K, delta)
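        # pull the correction back through the dictionary and rescale
        # the rows of the prior plan accordingly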
K0 = nx.dot(D.T, delta / inv_new)[:, None] * K0
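        # monitor the change of the unmixing estimate between iterations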
err = nx.norm(nx.sum(K0, axis=1) - old)
old = new
if log:
log["err"].append(err)
if verbose:
if ii % 200 == 0:
print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
print("{:5d}|{:8e}|".format(ii, err))
if err < stopThr:
break
else:
if warn:
warnings.warn(
"Unmixing algorithm did not converge. You might want to "
"increase the number of iterations `numItermax` "
"or the regularization parameter `reg`."
)
if log:
log["niter"] = ii
return nx.sum(K0, axis=1), log
else:
return nx.sum(K0, axis=1)