Source code for ot.bregman._convolutional

# -*- coding: utf-8 -*-
"""
Bregman projections solvers for entropic regularized Wasserstein convolutional barycenters
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Nicolas Courty <ncourty@irisa.fr>
#
# License: MIT License

import warnings

from ..backend import get_backend
from ..utils import list_to_array

_warning_msg = (
    "Convolutional Sinkhorn did not converge. "
    "Try a larger number of iterations `numItermax` "
    "or a larger entropy `reg`."
)


def _get_convol_img_fn(nx, width, height, reg, type_as, log_domain=False):
    """Return the convolution operator for 2D images.

    The function constructed is equivalent to blurring on horizontal then vertical directions."""
    t1 = nx.linspace(0, 1, width, type_as=type_as)
    Y1, X1 = nx.meshgrid(t1, t1)
    M1 = -((X1 - Y1) ** 2) / reg

    t2 = nx.linspace(0, 1, height, type_as=type_as)
    Y2, X2 = nx.meshgrid(t2, t2)
    M2 = -((X2 - Y2) ** 2) / reg

    # If normal domain is selected, we can use M1 and M2 to compute the convolution
    if not log_domain:
        K1, K2 = nx.exp(M1), nx.exp(M2)

        def convol_imgs(imgs):
            kx = nx.einsum("...ij,kjl->kil", K1, imgs)
            kxy = nx.einsum("...ij,klj->kli", K2, kx)
            return kxy

    # Else, we can use M1 and M2 to compute the convolution in log-domain
    else:

        def convol_imgs(log_imgs):
            log_imgs = nx.logsumexp(M1[:, :, None] + log_imgs[None], axis=1)
            log_imgs = nx.logsumexp(M2[:, :, None] + log_imgs.T[None], axis=1).T
            return log_imgs

    return convol_imgs


def _print_report(ii, err):
    """Print the report of the iteration."""
    if ii % 200 == 0:
        print("{:5s}|{:12s}".format("It.", "Err") + "\n" + "-" * 19)
    print("{:5d}|{:8e}|".format(ii, err))


[docs] def convolutional_barycenter2d( A, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-4, verbose=False, log=False, warn=True, **kwargs, ): r"""Compute the entropic regularized wasserstein barycenter of distributions :math:`\mathbf{A}` where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i W_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`W_{reg}(\cdot,\cdot)` is the entropic regularized Wasserstein distance (see :py:func:`ot.bregman.sinkhorn`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the Sinkhorn-Knopp matrix scaling algorithm as proposed in :ref:`[21] <references-convolutional-barycenter-2d>` Parameters ---------- A : array-like, shape (n_hists, width, height) `n` distributions (2D images) of size `width` x `height` reg : float Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coordinates) method : string, optional method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. Returns ------- a : array-like, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters .. _references-convolutional-barycenter-2d: References ---------- .. [21] Solomon, J., De Goes, F., Peyré, G., Cuturi, M., Butscher, A., Nguyen, A. & Guibas, L. (2015). Convolutional wasserstein distances: Efficient optimal transportation on geometric domains. ACM Transactions on Graphics (TOG), 34(4), 66 .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ if method.lower() == "sinkhorn": return _convolutional_barycenter2d( A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs, ) elif method.lower() == "sinkhorn_log": return _convolutional_barycenter2d_log( A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs, ) else: raise ValueError("Unknown method '%s'." % method)
def _convolutional_barycenter2d( A, reg, weights=None, numItermax=10000, stopThr=1e-9, stabThr=1e-30, verbose=False, log=False, warn=True, ): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images. """ A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: assert len(weights) == n_hists if log: log = {"err": []} bar = nx.ones((width, height), type_as=A) bar /= nx.sum(bar) U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) err = 1 # build the convolution operator convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): V = bar[None] / KU KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) bar = nx.exp(nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)) if ii % 10 == 9: err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log["err"].append(err) if verbose: _print_report(ii, err) if err < stopThr: break else: if warn: warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U log["V"] = V return bar, log else: return bar def _convolutional_barycenter2d_log( A, reg, weights=None, numItermax=10000, stopThr=1e-4, stabThr=1e-30, verbose=False, log=False, warn=True, ): r"""Compute the entropic regularized wasserstein barycenter of distributions A where A is a collection of 2D images in log-domain. """ A = list_to_array(A) nx = get_backend(A) # This error is raised because we are using mutable assignment in the line # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented" " for Jax and TF. Use numpy or torch arrays instead." ) n_hists, width, height = A.shape if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: assert len(weights) == n_hists if log: log = {"err": []} err = 1 # build the convolution operator convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) err = 1 for ii in range(numItermax): log_bar = nx.zeros((width, height), type_as=A) for k in range(n_hists): f = logA[k] - convol_img(G[k]) log_KU[k] = convol_img(f) log_bar = log_bar + weights[k] * log_KU[k] if ii % 10 == 9: err = nx.exp(G + log_KU).std(axis=0).sum() # log and verbose print if log: log["err"].append(err) if verbose: _print_report(ii, err) if err < stopThr: break G = log_bar[None, :, :] - log_KU else: if warn: warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar)
[docs] def convolutional_barycenter2d_debiased( A, reg, weights=None, method="sinkhorn", numItermax=10000, stopThr=1e-3, verbose=False, log=False, warn=True, **kwargs, ): r"""Compute the debiased sinkhorn barycenter of distributions :math:`\mathbf{A}` where :math:`\mathbf{A}` is a collection of 2D images. The function solves the following optimization problem: .. math:: \mathbf{a} = \mathop{\arg \min}_\mathbf{a} \quad \sum_i S_{reg}(\mathbf{a},\mathbf{a}_i) where : - :math:`S_{reg}(\cdot,\cdot)` is the debiased entropic regularized Wasserstein distance (see :py:func:`ot.bregman.barycenter_debiased`) - :math:`\mathbf{a}_i` are training distributions (2D images) in the mast two dimensions of matrix :math:`\mathbf{A}` - `reg` is the regularization strength scalar value The algorithm used for solving the problem is the debiased Sinkhorn scaling algorithm as proposed in :ref:`[37] <references-convolutional-barycenter2d-debiased>` Parameters ---------- A : array-like, shape (n_hists, width, height) `n` distributions (2D images) of size `width` x `height` reg : float Regularization term >0 weights : array-like, shape (n_hists,) Weights of each image on the simplex (barycentric coordinates) method : string, optional method used for the solver either 'sinkhorn' or 'sinkhorn_log' numItermax : int, optional Max number of iterations stopThr : float, optional Stop threshold on error (> 0) stabThr : float, optional Stabilization threshold to avoid numerical precision issue verbose : bool, optional Print information along iterations log : bool, optional record log if True warn : bool, optional if True, raises a warning if the algorithm doesn't convergence. Returns ------- a : array-like, shape (width, height) 2D Wasserstein barycenter log : dict log dictionary return only if log==True in parameters .. _references-convolutional-barycenter2d-debiased: References ---------- .. [37] Janati, H., Cuturi, M., Gramfort, A. Proceedings of the 37th International Conference on Machine Learning, PMLR 119:4692-4701, 2020 """ if method.lower() == "sinkhorn": return _convolutional_barycenter2d_debiased( A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs, ) elif method.lower() == "sinkhorn_log": return _convolutional_barycenter2d_debiased_log( A, reg, weights=weights, numItermax=numItermax, stopThr=stopThr, verbose=verbose, log=log, warn=warn, **kwargs, ) else: raise ValueError("Unknown method '%s'." % method)
def _convolutional_barycenter2d_debiased( A, reg, weights=None, numItermax=10000, stopThr=1e-3, stabThr=1e-15, verbose=False, log=False, warn=True, ): r"""Compute the debiased barycenter of 2D images via sinkhorn convolutions.""" A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: assert len(weights) == n_hists if log: log = {"err": []} bar = nx.ones((width, height), type_as=A) bar /= width * height U = nx.ones(A.shape, type_as=A) V = nx.ones(A.shape, type_as=A) c = nx.ones((width, height), type_as=A) err = 1 # build the convolution operator convol_imgs = _get_convol_img_fn(nx, width, height, reg, type_as=A) KU = convol_imgs(U) for ii in range(numItermax): V = bar[None] / KU KV = convol_imgs(V) U = A / KV KU = convol_imgs(U) bar = c * nx.exp(nx.sum(weights[:, None, None] * nx.log(KU + stabThr), axis=0)) for _ in range(10): c = (c * bar / nx.squeeze(convol_imgs(c[None]))) ** 0.5 if ii % 10 == 9: err = nx.sum(nx.std(V * KU, axis=0)) # log and verbose print if log: log["err"].append(err) if verbose: _print_report(ii, err) # debiased Sinkhorn does not converge monotonically # guarantee a few iterations are done before stopping if err < stopThr and ii > 20: break else: if warn: warnings.warn(_warning_msg) if log: log["niter"] = ii log["U"] = U log["V"] = V return bar, log else: return bar def _convolutional_barycenter2d_debiased_log( A, reg, weights=None, numItermax=10000, stopThr=1e-3, stabThr=1e-30, verbose=False, log=False, warn=True, ): r"""Compute the debiased barycenter of 2D images in log-domain.""" A = list_to_array(A) n_hists, width, height = A.shape nx = get_backend(A) # This error is raised because we are using mutable assignment in the line # `log_KU[k] = ...` which is not allowed in Jax and TF. if nx.__name__ in ("jax", "tf"): raise NotImplementedError( "Log-domain functions are not yet implemented" " for Jax and TF. Use numpy or torch arrays instead." ) if weights is None: weights = nx.ones((n_hists,), type_as=A) / n_hists else: assert len(weights) == A.shape[0] if log: log = {"err": []} err = 1 # build the convolution operator convol_img = _get_convol_img_fn(nx, width, height, reg, type_as=A, log_domain=True) logA = nx.log(A + stabThr) log_bar, c = nx.zeros((2, width, height), type_as=A) log_KU, G, F = nx.zeros((3, *logA.shape), type_as=A) err = 1 for ii in range(numItermax): log_bar = nx.zeros((width, height), type_as=A) for k in range(n_hists): f = logA[k] - convol_img(G[k]) log_KU[k] = convol_img(f) log_bar = log_bar + weights[k] * log_KU[k] log_bar += c for _ in range(10): c = 0.5 * (c + log_bar - convol_img(c)) if ii % 10 == 9: err = nx.sum(nx.std(nx.exp(G + log_KU), axis=0)) # log and verbose print if log: log["err"].append(err) if verbose: _print_report(ii, err) if err < stopThr and ii > 20: break G = log_bar[None, :, :] - log_KU else: if warn: warnings.warn(_warning_msg) if log: log["niter"] = ii return nx.exp(log_bar), log else: return nx.exp(log_bar)