# Source code for ot.bregman._convolutional

# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized Wasserstein convolutional barycenters
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#

import warnings

from ..utils import list_to_array
from ..backend import get_backend

def convolutional_barycenter2d(A, reg, weights=None, method="sinkhorn", numItermax=10000,
                               stopThr=1e-4, verbose=False, log=False,
                               warn=True, **kwargs):
    r"""Compute the entropic regularized Wasserstein barycenter of distributions :math:`\mathbf{A}`
    where :math:`\mathbf{A}` is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::

        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.sinkhorn`)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
      dimensions of matrix :math:`\mathbf{A}`
    - `reg` is the regularization strength scalar value

    The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling
    algorithm as proposed in :ref:`[21] <references-convolutional-barycenter-2d>`

    This function only selects the solver: `method` is lower-cased and the call is
    forwarded, together with any extra keyword arguments, to the standard-domain
    or the log-domain implementation.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        `n` distributions (2D images) of size `width` x `height`
    reg : float
        Regularization term >0
    weights : array-like, shape (n_hists,)
        Weights of each image on the simplex (barycentric coordinates)
    method : string, optional
        method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
        (matched case-insensitively)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issues
        (passed through ``**kwargs`` to the selected solver)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm does not converge.

    Returns
    -------
    a : array-like, shape (width, height)
        2D Wasserstein barycenter
    log : dict
        log dictionary, returned only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is neither 'sinkhorn' nor 'sinkhorn_log'.


    .. _references-convolutional-barycenter-2d:
    References
    ----------

    .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher,
        A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances:
        Efficient optimal transportation on geometric domains. ACM Transactions
        on Graphics (TOG), 34(4), 66

    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th
        International Conference on Machine Learning, PMLR 119:4692-4701, 2020
    """
    solver_key = method.lower()

    # Reject unknown solvers up front; the message reports the value exactly
    # as the caller supplied it (not the lower-cased form).
    if solver_key not in ('sinkhorn', 'sinkhorn_log'):
        raise ValueError("Unknown method '%s'." % method)

    if solver_key == 'sinkhorn':
        solver = _convolutional_barycenter2d
    else:
        solver = _convolutional_barycenter2d_log

    return solver(
        A, reg,
        weights=weights,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=log,
        warn=warn,
        **kwargs
    )

def _convolutional_barycenter2d(A, reg, weights=None, numItermax=10000,
                                stopThr=1e-9, stabThr=1e-30, verbose=False,
                                log=False, warn=True):
    r"""Compute the entropic regularized Wasserstein barycenter of distributions A
    where A is a collection of 2D images.

    Standard-domain (multiplicative scaling) solver. The transport kernel is
    never built explicitly: it is applied as two separable Gaussian blurs,
    one along each image axis.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Stack of 2D images.
    reg : float
        Regularization term; it divides the squared grid distance in the
        Gaussian kernels.
    weights : array-like, shape (n_hists,), optional
        Barycentric weight of each image. Uniform weights are used when None.
    numItermax : int, optional
        Maximum number of scaling iterations.
    stopThr : float, optional
        The iterations stop as soon as the error (evaluated every 10
        iterations) falls below this value.
    stabThr : float, optional
        Small constant added inside the logarithm to avoid ``log(0)``.
    verbose : bool, optional
        Print the iteration index and error each time the error is evaluated.
    log : bool, optional
        If true, also return a dict with the keys ``'err'`` (list of evaluated
        errors), ``'niter'`` (index of the last iteration run) and ``'U'``
        (last scaling array).
    warn : bool, optional
        Emit a warning when `numItermax` is exhausted without reaching
        `stopThr`.

    Returns
    -------
    bar : array-like, shape (width, height)
        The barycenter image.
    log : dict
        Only returned when `log` is true.
    """

    A = list_to_array(A)

    nx = get_backend(A)

    if weights is None:
        weights = nx.ones((A.shape[0],), type_as=A) / A.shape[0]
    else:
        assert (len(weights) == A.shape[0])

    if log:
        log = {'err': []}

    # start from the uniform image
    bar = nx.ones(A.shape[1:], type_as=A)
    bar /= nx.sum(bar)
    U = nx.ones(A.shape, type_as=A)
    err = 1

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, A.shape[1], type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    K1 = nx.exp(-(X - Y) ** 2 / reg)

    t = nx.linspace(0, 1, A.shape[2], type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    K2 = nx.exp(-(X - Y) ** 2 / reg)

    def convol_imgs(imgs):
        # apply K1 along axis 1, then K2 along axis 2, for every image k
        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
        kxy = nx.einsum("...ij,klj->kli", K2, kx)
        return kxy

    KU = convol_imgs(U)
    for ii in range(numItermax):
        V = bar[None] / KU
        KV = convol_imgs(V)
        U = A / KV
        KU = convol_imgs(U)
        # weighted geometric mean of the blurred scalings
        bar = nx.exp(
            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
        )
        if ii % 10 == 9:
            # spread across images of V * KU, summed over all pixels
            err = nx.sum(nx.std(V * KU, axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # The error is only evaluated when ii % 10 == 9, so the header
                # must be keyed on the same residue; testing ``ii % 200 == 0``
                # here could never succeed and the header was never printed.
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            if err < stopThr:
                break

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        log['U'] = U
        return bar, log
    else:
        return bar

def _convolutional_barycenter2d_log(A, reg, weights=None, numItermax=10000,
                                    stopThr=1e-4, stabThr=1e-30, verbose=False,
                                    log=False, warn=True):
    r"""Compute the entropic regularized Wasserstein barycenter of distributions A
    where A is a collection of 2D images in log-domain.

    Same iteration scheme as the standard-domain solver, but every quantity
    is kept as a logarithm and the separable Gaussian blur is applied with
    ``logsumexp``.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Stack of 2D images.
    reg : float
        Regularization term; it divides the squared grid distance in the
        log-kernels.
    weights : array-like, shape (n_hists,), optional
        Barycentric weight of each image. Uniform weights are used when None.
    numItermax : int, optional
        Maximum number of iterations.
    stopThr : float, optional
        The iterations stop as soon as the error (evaluated every 10
        iterations) falls below this value.
    stabThr : float, optional
        Small constant added to `A` before taking its logarithm.
    verbose : bool, optional
        Print the iteration index and error each time the error is evaluated.
    log : bool, optional
        If true, also return a dict with the keys ``'err'`` (list of evaluated
        errors) and ``'niter'`` (index of the last iteration run).
    warn : bool, optional
        Emit a warning when `numItermax` is exhausted without reaching
        `stopThr`.

    Returns
    -------
    bar : array-like, shape (width, height)
        The barycenter image (exponential of the log-barycenter).
    log : dict
        Only returned when `log` is true.

    Raises
    ------
    NotImplementedError
        If the backend detected for `A` is named "jax" or "tf".
    """

    A = list_to_array(A)

    nx = get_backend(A)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )

    n_hists, width, height = A.shape

    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == n_hists)

    if log:
        log = {'err': []}

    err = 1
    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    M1 = - (X - Y) ** 2 / reg

    t = nx.linspace(0, 1, height, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    M2 = - (X - Y) ** 2 / reg

    def convol_img(log_img):
        # log-domain blur: first along axis 0 with M1, then along axis 1 with
        # M2 (done on the transposed image, then transposed back)
        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
        return log_img

    logA = nx.log(A + stabThr)
    # only two work buffers are needed; a third one was allocated but never used
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    for ii in range(numItermax):
        log_bar = nx.zeros((width, height), type_as=A)
        for k in range(n_hists):
            f = logA[k] - convol_img(G[k])
            log_KU[k] = convol_img(f)
            log_bar = log_bar + weights[k] * log_KU[k]

        if ii % 10 == 9:
            # Go through the backend reductions, as the other solvers in this
            # file do, instead of the array's own ``.std()/.sum()`` methods:
            # the methods' defaults differ between array libraries (numpy's
            # ``ndarray.std`` uses ddof=0, torch's ``Tensor.std`` applies
            # Bessel's correction by default and yields NaN for one image).
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # The error is only evaluated when ii % 10 == 9, so the header
                # must be keyed on the same residue; testing ``ii % 200 == 0``
                # here could never succeed and the header was never printed.
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            if err < stopThr:
                break
        # skipped on the iteration that breaks out above
        G = log_bar[None, :, :] - log_KU

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)

def convolutional_barycenter2d_debiased(A, reg, weights=None, method="sinkhorn",
                                        numItermax=10000, stopThr=1e-3,
                                        verbose=False, log=False, warn=True,
                                        **kwargs):
    r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}`
    where :math:`\mathbf{A}` is a collection of 2D images.

    The function solves the following optimization problem:

    .. math::

        \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i)

    where :

    - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein
      distance (see :py:func:`ot.bregman.barycenter_debiased`)
    - :math:`\mathbf{a}_i` are training distributions (2D images) in the last two
      dimensions of matrix :math:`\mathbf{A}`
    - `reg` is the regularization strength scalar value

    The algorithm used for solving the problem is the debiased Sinkhorn scaling
    algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>`

    This function only selects the solver: `method` is lower-cased and the call is
    forwarded, together with any extra keyword arguments, to the standard-domain
    or the log-domain debiased implementation.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        `n` distributions (2D images) of size `width` x `height`
    reg : float
        Regularization term >0
    weights : array-like, shape (n_hists,)
        Weights of each image on the simplex (barycentric coordinates)
    method : string, optional
        method used for the solver, either 'sinkhorn' or 'sinkhorn_log'
        (matched case-insensitively)
    numItermax : int, optional
        Max number of iterations
    stopThr : float, optional
        Stop threshold on error (> 0)
    stabThr : float, optional
        Stabilization threshold to avoid numerical precision issues
        (passed through ``**kwargs`` to the selected solver)
    verbose : bool, optional
        Print information along iterations
    log : bool, optional
        record log if True
    warn : bool, optional
        if True, raises a warning if the algorithm does not converge.

    Returns
    -------
    a : array-like, shape (width, height)
        2D Wasserstein barycenter
    log : dict
        log dictionary, returned only if log==True in parameters

    Raises
    ------
    ValueError
        If `method` is neither 'sinkhorn' nor 'sinkhorn_log'.


    .. _references-convolutional-barycenter2d-debiased:
    References
    ----------

    .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International
        Conference on Machine Learning, PMLR 119:4692-4701, 2020
    """
    solver_key = method.lower()

    # Reject unknown solvers up front; the message reports the value exactly
    # as the caller supplied it (not the lower-cased form).
    if solver_key not in ('sinkhorn', 'sinkhorn_log'):
        raise ValueError("Unknown method '%s'." % method)

    if solver_key == 'sinkhorn':
        solver = _convolutional_barycenter2d_debiased
    else:
        solver = _convolutional_barycenter2d_debiased_log

    return solver(
        A, reg,
        weights=weights,
        numItermax=numItermax,
        stopThr=stopThr,
        verbose=verbose,
        log=log,
        warn=warn,
        **kwargs
    )

def _convolutional_barycenter2d_debiased(A, reg, weights=None, numItermax=10000,
                                         stopThr=1e-3, stabThr=1e-15, verbose=False,
                                         log=False, warn=True):
    r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.

    Standard-domain solver. It follows the same scaling iterations as the
    non-debiased convolutional solver, with an extra multiplicative term `c`
    applied to the barycenter and refreshed by 10 fixed-point updates
    ``c <- sqrt(c * bar / K(c))`` at every outer iteration, where ``K`` is the
    separable Gaussian blur.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Stack of 2D images.
    reg : float
        Regularization term; it divides the squared grid distance in the
        Gaussian kernels.
    weights : array-like, shape (n_hists,), optional
        Barycentric weight of each image. Uniform weights are used when None.
    numItermax : int, optional
        Maximum number of outer iterations.
    stopThr : float, optional
        The iterations stop once the error (evaluated every 10 iterations)
        falls below this value and more than 20 iterations have been run.
    stabThr : float, optional
        Small constant added inside the logarithm to avoid ``log(0)``.
    verbose : bool, optional
        Print the iteration index and error each time the error is evaluated.
    log : bool, optional
        If true, also return a dict with the keys ``'err'`` (list of evaluated
        errors), ``'niter'`` (index of the last iteration run) and ``'U'``
        (last scaling array).
    warn : bool, optional
        Emit a warning when `numItermax` is exhausted without meeting the
        stopping criterion.

    Returns
    -------
    bar : array-like, shape (width, height)
        The debiased barycenter image.
    log : dict
        Only returned when `log` is true.
    """

    A = list_to_array(A)
    n_hists, width, height = A.shape

    nx = get_backend(A)

    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == n_hists)

    if log:
        log = {'err': []}

    # start from the uniform image and a neutral debiasing term
    bar = nx.ones((width, height), type_as=A)
    bar /= width * height
    U = nx.ones(A.shape, type_as=A)
    c = nx.ones(A.shape[1:], type_as=A)
    err = 1

    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    K1 = nx.exp(-(X - Y) ** 2 / reg)

    t = nx.linspace(0, 1, height, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    K2 = nx.exp(-(X - Y) ** 2 / reg)

    def convol_imgs(imgs):
        # apply K1 along axis 1, then K2 along axis 2, for every image k
        kx = nx.einsum("...ij,kjl->kil", K1, imgs)
        kxy = nx.einsum("...ij,klj->kli", K2, kx)
        return kxy

    KU = convol_imgs(U)
    for ii in range(numItermax):
        V = bar[None] / KU
        KV = convol_imgs(V)
        U = A / KV
        KU = convol_imgs(U)
        # weighted geometric mean of the blurred scalings, times the
        # debiasing term
        bar = c * nx.exp(
            nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)
        )

        # inner fixed-point updates of the debiasing term; c is a single
        # image, so it is given a leading axis for the blur and squeezed back
        for _ in range(10):
            c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5

        if ii % 10 == 9:
            # spread across images of V * KU, summed over all pixels
            err = nx.sum(nx.std(V * KU, axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # The error is only evaluated when ii % 10 == 9, so the header
                # must be keyed on the same residue; testing ``ii % 200 == 0``
                # here could never succeed and the header was never printed.
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))

            # debiased Sinkhorn does not converge monotonically
            # guarantee a few iterations are done before stopping
            if err < stopThr and ii > 20:
                break
    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn("Sinkhorn did not converge. You might want to "
                          "increase the number of iterations numItermax "
                          "or the regularization parameter reg.")
    if log:
        log['niter'] = ii
        log['U'] = U
        return bar, log
    else:
        return bar

def _convolutional_barycenter2d_debiased_log(A, reg, weights=None, numItermax=10000,
                                             stopThr=1e-3, stabThr=1e-30, verbose=False,
                                             log=False, warn=True):
    r"""Compute the debiased barycenter of 2D images in log-domain.

    Log-domain counterpart of the debiased convolutional solver: every
    quantity is kept as a logarithm, the separable Gaussian blur is applied
    with ``logsumexp``, and the additive debiasing term `c` is refreshed by 10
    fixed-point updates ``c <- (c + log_bar - K(c)) / 2`` at every outer
    iteration, where ``K`` is the log-domain blur.

    Parameters
    ----------
    A : array-like, shape (n_hists, width, height)
        Stack of 2D images.
    reg : float
        Regularization term; it divides the squared grid distance in the
        log-kernels.
    weights : array-like, shape (n_hists,), optional
        Barycentric weight of each image. Uniform weights are used when None.
    numItermax : int, optional
        Maximum number of outer iterations.
    stopThr : float, optional
        The iterations stop once the error (evaluated every 10 iterations)
        falls below this value and more than 20 iterations have been run.
    stabThr : float, optional
        Small constant added to `A` before taking its logarithm.
    verbose : bool, optional
        Print the iteration index and error each time the error is evaluated.
    log : bool, optional
        If true, also return a dict with the keys ``'err'`` (list of evaluated
        errors) and ``'niter'`` (index of the last iteration run).
    warn : bool, optional
        Emit a warning when `numItermax` is exhausted without meeting the
        stopping criterion.

    Returns
    -------
    bar : array-like, shape (width, height)
        The debiased barycenter image (exponential of the log-barycenter).
    log : dict
        Only returned when `log` is true.

    Raises
    ------
    NotImplementedError
        If the backend detected for `A` is named "jax" or "tf".
    """

    A = list_to_array(A)
    n_hists, width, height = A.shape
    nx = get_backend(A)
    if nx.__name__ in ("jax", "tf"):
        raise NotImplementedError(
            "Log-domain functions are not yet implemented"
            " for Jax and TF. Use numpy or torch arrays instead."
        )
    if weights is None:
        weights = nx.ones((n_hists,), type_as=A) / n_hists
    else:
        assert (len(weights) == n_hists)

    if log:
        log = {'err': []}

    err = 1
    # build the convolution operator
    # this is equivalent to blurring on horizontal then vertical directions
    t = nx.linspace(0, 1, width, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    M1 = - (X - Y) ** 2 / reg

    t = nx.linspace(0, 1, height, type_as=A)
    [Y, X] = nx.meshgrid(t, t)
    M2 = - (X - Y) ** 2 / reg

    def convol_img(log_img):
        # log-domain blur: first along axis 0 with M1, then along axis 1 with
        # M2 (done on the transposed image, then transposed back)
        log_img = nx.logsumexp(M1[:, :, None] + log_img[None], axis=1)
        log_img = nx.logsumexp(M2[:, :, None] + log_img.T[None], axis=1).T
        return log_img

    logA = nx.log(A + stabThr)
    # log_bar is initialised here so it exists even if the loop body never runs
    log_bar, c = nx.zeros((2, width, height), type_as=A)
    # only two work buffers are needed; a third one was allocated but never used
    log_KU, G = nx.zeros((2, *logA.shape), type_as=A)
    for ii in range(numItermax):
        log_bar = nx.zeros((width, height), type_as=A)
        for k in range(n_hists):
            f = logA[k] - convol_img(G[k])
            log_KU[k] = convol_img(f)
            log_bar = log_bar + weights[k] * log_KU[k]
        # add the debiasing term, then refresh it with inner fixed-point updates
        log_bar += c
        for _ in range(10):
            c = 0.5 * (c + log_bar - convol_img(c))

        if ii % 10 == 9:
            err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0))
            # log and verbose print
            if log:
                log['err'].append(err)

            if verbose:
                # The error is only evaluated when ii % 10 == 9, so the header
                # must be keyed on the same residue; testing ``ii % 200 == 0``
                # here could never succeed and the header was never printed.
                if ii % 200 == 9:
                    print('{:5s}|{:12s}'.format(
                        'It.', 'Err') + '\n' + '-' * 19)
                print('{:5d}|{:8e}|'.format(ii, err))
            # require more than 20 iterations before allowing an early stop
            if err < stopThr and ii > 20:
                break
        # skipped on the iteration that breaks out above
        G = log_bar[None, :, :] - log_KU

    else:
        # the loop ran to completion without meeting the stopping criterion
        if warn:
            warnings.warn("Convolutional Sinkhorn did not converge. "
                          "Try a larger number of iterations numItermax "
                          "or a larger entropy reg.")
    if log:
        log['niter'] = ii
        return nx.exp(log_bar), log
    else:
        return nx.exp(log_bar)