# -*- coding: utf-8 -*-
"""
Batch operations for linear optimal transport.
"""

# Author: Remi Flamary <remi.flamary@unice.fr>
#         Paul Krzakala <paul.krzakala@gmail.com>
#
# License: MIT License

from ..backend import get_backend
from ..utils import OTResult
from ._utils import (
    bregman_log_projection_batch,
    bregman_projection_batch,
    entropy_batch,
)


def dist_lp_batch(X, Y, p=2, q=1, nx=None):
    r"""Computes the cost matrix for a batch of samples using the Lp norm.

    .. math::
        M_{bij} = ( \sum_{d} |x_{bid} - y_{bjd}|^p )^{q/p} = ||x_{bi} - y_{bj}||_p^q

    Parameters
    ----------
    X : array-like, shape (B, ns, d)
        Samples from source distribution
    Y : array-like, shape (B, nt, d)
        Samples from target distribution
    p : float, optional
        The order of the norm to use. Default is 2 (Euclidean distance).
    q : float, optional
        Power applied to the Lp norm, returning :math:`||x_{bi} - y_{bj}||_p^q`. Default is 1, which gives the Lp norm itself.
    nx : backend, optional
        Backend to use for computations. If None, it will be inferred from the inputs.

    Returns
    -------
    M : array-like, shape (B, ns, nt)
        Cost matrix where M[bij] is the cost between sample i in batch b and sample j in batch b.
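
    Examples
    --------
    A minimal sketch on random inputs (the import path assumes this private
    module is used directly; shapes are illustrative):

    >>> import numpy as np
    >>> from ot.batch._linear import dist_lp_batch
    >>> X = np.random.randn(5, 10, 3)
    >>> Y = np.random.randn(5, 15, 3)
    >>> M = dist_lp_batch(X, Y, p=3)  # pairwise L3 norms within each batch
    >>> M.shape
    (5, 10, 15)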
    """

    if nx is None:
        nx = get_backend(X, Y)
    M = nx.abs(X[:, :, None] - Y[:, None, :]) ** p
    M = M.sum(axis=-1)
    if q != p:
        M = M ** (q / p)
    return M


def dist_euclidean_batch(X, Y, squared=True, nx=None):
    r"""Computes the squared Euclidean cost matrix for a batch of samples.

    .. math::
        M_{bij} = \sum_{d} (x_{bid} - y_{bjd})^2 = ||x_{bi} - y_{bj}||_2^2

    Parameters
    ----------
    X : array-like, shape (B, ns, d)
        Samples from source distribution
    Y : array-like, shape (B, nt, d)
        Samples from target distribution
    squared : bool, optional
        If True, returns the squared Euclidean distance. Default is True.
    nx : backend, optional
        Backend to use for computations. If None, it will be inferred from the inputs.

    Returns
    -------
    M : array-like, shape (B, ns, nt)
        Cost matrix where M[bij] is the cost between sample i in batch b and sample j in batch b.
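
    Examples
    --------
    A minimal sketch on random inputs (the import path assumes this private
    module is used directly; shapes are illustrative):

    >>> import numpy as np
    >>> from ot.batch._linear import dist_euclidean_batch
    >>> X = np.random.randn(5, 10, 3)
    >>> Y = np.random.randn(5, 15, 3)
    >>> M = dist_euclidean_batch(X, Y, squared=False)  # plain Euclidean distances
    >>> M.shape
    (5, 10, 15)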
    """

    if nx is None:
        nx = get_backend(X, Y)
    XX = nx.sum(X**2, axis=-1, keepdims=True)
    YY = nx.sum(Y**2, axis=-1, keepdims=True)
    YY = nx.transpose(YY, axes=(0, 2, 1))
    M = XX + YY - 2 * nx.einsum("bid,bjd->bij", X, Y)
    if not squared:
        # clip to zero to avoid NaNs from small negative values due to rounding
        M = nx.sqrt(nx.clip(M, 0.0, None))
    return M


def dist_kl_batch(X, Y, logits_X=False, nx=None, eps=1e-10):
    r"""Computes the KL divergence cost matrix for a batch of samples.

    .. math::
        M_{bij} = \sum_{d} y_{bjd} \log(y_{bjd} / x_{bid}) = KL(y_{bj} || x_{bi})

    Parameters
    ----------
    X : array-like, shape (B, ns, d)
        Samples from source distribution
    Y : array-like, shape (B, nt, d)
        Samples from target distribution
    logits_X : bool, optional
        If True, X is assumed to be in log space (logits). Default is False.
    nx : backend, optional
        Backend to use for computations. If None, it will be inferred from the inputs.
    eps : float, optional
        Small constant added inside the logarithms for numerical stability. Default is 1e-10.

    Returns
    -------
    M : array-like, shape (B, ns, nt)
        Cost matrix where M[bij] is the cost between sample i in batch b and sample j in batch b.
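
    Examples
    --------
    A minimal sketch with rows normalized to probability vectors (the import
    path assumes this private module is used directly; shapes are illustrative):

    >>> import numpy as np
    >>> from ot.batch._linear import dist_kl_batch
    >>> X = np.random.rand(5, 10, 3)
    >>> X = X / X.sum(axis=-1, keepdims=True)
    >>> Y = np.random.rand(5, 15, 3)
    >>> Y = Y / Y.sum(axis=-1, keepdims=True)
    >>> M = dist_kl_batch(X, Y)  # M[b, i, j] = KL(y_bj || x_bi)
    >>> M.shape
    (5, 10, 15)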
    """

    if nx is None:
        nx = get_backend(X, Y)
    entr_y = nx.sum(Y * nx.log(Y + eps), axis=-1)  # B x m
    if logits_X:
        # reduce over the feature axis d; X already holds log-probabilities
        M = entr_y[:, None, :] - nx.sum(Y[:, None, :] * X[:, :, None], axis=-1)
    else:
        M = entr_y[:, None, :] - nx.sum(
            Y[:, None, :] * nx.log(X + eps)[:, :, None], axis=-1
        )
    return M


def loss_linear_batch(M, T, nx=None):
    r"""Computes the linear optimal transport loss given a batch of cost matrices and transport plans.

    .. math::
        L(T, M)_b = \langle T_b, M_b \rangle_F

    Parameters
    ----------
    M : array-like, shape (B, ns, nt)
        Cost matrix
    T : array-like, shape (B, ns, nt)
        Transport plan
    nx : backend, optional
        Backend to use for computations. If None, it will be inferred from the inputs.

    Returns
    -------
    loss : array-like, shape (B,)
        Loss value for each batch element

    See Also
    --------
    ot.batch.dist_batch : batched cost matrix computation for computing M.
    ot.batch.solve_batch : solver for computing the optimal T.
    """
    if nx is None:
        nx = get_backend(M, T)
    return nx.sum(M * T, axis=(1, 2))


def loss_linear_samples_batch(X, Y, T, metric="sqeuclidean"):
    r"""Computes the linear optimal transport loss given samples and a transport plan.

    This is equivalent to calling `dist_batch` and then `loss_linear_batch`.

    Parameters
    ----------
    X : array-like, shape (B, ns, d)
        Samples from source distribution
    Y : array-like, shape (B, nt, d)
        Samples from target distribution
    T : array-like, shape (B, ns, nt)
        Transport plan
    metric : str, optional
        'sqeuclidean', 'euclidean', 'minkowski' or 'kl'

    Returns
    -------
    loss : array-like, shape (B,)
        Loss value for each batch element

    See Also
    --------
    ot.batch.dist_batch : batched cost matrix computation for computing M.
    ot.batch.solve_batch : solver for computing the optimal T.
    """
    M = dist_batch(X, Y, metric=metric)
    return loss_linear_batch(M, T)
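

# A doctest-style sketch of the equivalence stated in the docstring above
# (the private import path and the shapes are illustrative assumptions):
# >>> import numpy as np
# >>> from ot.batch._linear import dist_batch, loss_linear_batch, loss_linear_samples_batch
# >>> X = np.random.randn(4, 8, 2)
# >>> Y = np.random.randn(4, 9, 2)
# >>> T = np.ones((4, 8, 9)) / (8 * 9)  # a feasible (uniform) transport plan
# >>> l1 = loss_linear_samples_batch(X, Y, T, metric="sqeuclidean")
# >>> l2 = loss_linear_batch(dist_batch(X, Y, metric="sqeuclidean"), T)
# >>> bool(np.allclose(l1, l2))
# True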


def dist_batch(
    X1,
    X2=None,
    metric="sqeuclidean",
    p=2,
    nx=None,
):
    r"""Batched version of ot.dist; use it to compute many distance matrices in parallel.

    Parameters
    ----------
    X1 : array-like, shape (b, n1, d)
        `b` matrices with `n1` samples of size `d`
    X2 : array-like, shape (b, n2, d), optional
        `b` matrices with `n2` samples of size `d` (if None then :math:`\mathbf{X_2} = \mathbf{X_1}`)
    metric : str, optional
        'sqeuclidean', 'euclidean', 'minkowski' or 'kl'
    p : float, optional
        p-norm for the Minkowski metric. Default value is 2.
    nx : Backend, optional
        Backend to perform computations on. If omitted, the backend defaults to that of `X1`.

    Returns
    -------
    M : array-like, shape (`b`, `n1`, `n2`)
        distance matrix computed with given metric

    Examples
    --------
    >>> import numpy as np
    >>> from ot.batch import dist_batch
    >>> X1 = np.random.randn(5, 10, 3)
    >>> X2 = np.random.randn(5, 15, 3)
    >>> M = dist_batch(X1, X2, metric="euclidean")
    >>> M.shape
    (5, 10, 15)

    See Also
    --------
    ot.dist : equivalent non-batched function.
    """
    X2 = X2 if X2 is not None else X1
    metric = metric.lower()
    if metric == "sqeuclidean":
        M = dist_euclidean_batch(X1, X2, squared=True, nx=nx)
    elif metric == "euclidean":
        M = dist_euclidean_batch(X1, X2, squared=False, nx=nx)
    elif metric == "minkowski":
        M = dist_lp_batch(X1, X2, p=p, q=1, nx=nx)
    elif metric == "kl":
        M = dist_kl_batch(X1, X2, logits_X=False, nx=nx)
    else:
        raise ValueError(f"Unknown metric: {metric}")
    return M


def solve_batch(
    M,
    reg,
    a=None,
    b=None,
    max_iter=1000,
    tol=1e-5,
    solver="log_sinkhorn",
    reg_type="entropy",
    grad="envelope",
):
    r"""Batched version of ot.solve; use it to solve many entropic OT problems in parallel.

    Parameters
    ----------
    M : array-like, shape (B, ns, nt)
        Cost matrix
    reg : float
        Regularization parameter for entropic regularization
    a : array-like, shape (B, ns), optional
        Source distribution. If None, uniform distribution is used.
    b : array-like, shape (B, nt), optional
        Target distribution. If None, uniform distribution is used.
    max_iter : int
        Maximum number of iterations
    tol : float
        Tolerance for convergence
    solver : str
        Solver to use, either 'log_sinkhorn' or 'sinkhorn'. Default is "log_sinkhorn", which is more stable.
    reg_type : str, optional
        Type of regularization :math:`R`, either "KL" or "entropy". Default is "entropy".
    grad : str, optional
        Type of gradient computation, either 'autodiff', 'envelope', 'last_step' or 'detach'.
        Default is 'envelope'. 'autodiff' provides gradients wrt all outputs
        (`plan, value, value_linear`) but with important memory cost. 'envelope' provides
        gradients only for `value`; other outputs are detached. This is useful for memory
        saving when only the value is needed. 'last_step' provides gradients only for the
        last iteration of the Sinkhorn solver, but for both the OT plan and the objective
        values. 'detach' does not compute the gradients for the Sinkhorn solver.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    Examples
    --------
    >>> import numpy as np
    >>> from ot.batch import solve_batch, dist_batch
    >>> X = np.random.randn(5, 10, 3)  # 5 batches of 10 samples in 3D
    >>> Y = np.random.randn(5, 15, 3)  # 5 batches of 15 samples in 3D
    >>> M = dist_batch(X, Y, metric="euclidean")  # Compute cost matrices
    >>> reg = 0.1
    >>> result = solve_batch(M, reg)
    >>> result.plan.shape  # Optimal transport plans for each batch
    (5, 10, 15)
    >>> result.value.shape  # Optimal transport values for each batch
    (5,)

    See Also
    --------
    ot.batch.dist_batch : batched cost matrix computation for computing M.
    ot.solve : non-batched version of the OT solver.
    """
    nx = get_backend(a, b, M)
    B, n, m = M.shape
    if a is None:
        a = nx.ones((B, n)) / n
    if b is None:
        b = nx.ones((B, m)) / m

    if solver == "log_sinkhorn":
        K = -M / reg
        out = bregman_log_projection_batch(
            K, a, b, nx=nx, max_iter=max_iter, tol=tol, grad=grad
        )
    elif solver == "sinkhorn":
        K = nx.exp(-M / reg)
        out = bregman_projection_batch(
            K, a, b, nx=nx, max_iter=max_iter, tol=tol, grad=grad
        )
    else:
        raise ValueError(f"Unknown solver: {solver}")

    T = out["T"]
    if grad is None or grad == "detach":
        T = nx.detach(T)
        M = nx.detach(M)
    elif grad == "envelope":
        T = nx.detach(T)

    value_linear = loss_linear_batch(M, T)
    if reg_type.lower() == "entropy":
        entr = -entropy_batch(T, nx=nx)
        value = value_linear + reg * entr
    elif reg_type.lower() == "kl":
        ref = nx.einsum("bi,bj->bij", a, b)
        kl = nx.sum(T * nx.log(T / ref + 1e-16), axis=(1, 2))
        value = value_linear + reg * kl
    else:
        raise ValueError(f"Unknown reg_type: {reg_type}")

    log = {"n_iter": out["n_iters"]}

    res = OTResult(
        value=value,
        value_linear=value_linear,
        potentials=out["potentials"],
        plan=T,
        backend=nx,
        log=log,
    )
    return res
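

# A doctest-style sketch showing explicit marginals a and b (values illustrative):
# >>> import numpy as np
# >>> from ot.batch import solve_batch, dist_batch
# >>> M = dist_batch(np.random.randn(5, 10, 3), np.random.randn(5, 15, 3))
# >>> a = np.ones((5, 10)) / 10  # source marginals, shape (B, ns)
# >>> b = np.ones((5, 15)) / 15  # target marginals, shape (B, nt)
# >>> res = solve_batch(M, reg=0.1, a=a, b=b)
# >>> res.plan.shape
# (5, 10, 15)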


def solve_sample_batch(
    X_a,
    X_b,
    reg,
    a=None,
    b=None,
    metric="sqeuclidean",
    p=2,
    max_iter=1000,
    tol=1e-5,
    solver="log_sinkhorn",
    reg_type="entropy",
    grad="envelope",
):
    r"""Batched version of ot.solve_sample; use it to solve many entropic OT problems in parallel, directly from samples.

    Parameters
    ----------
    X_a : array-like, shape (B, ns, d)
        Samples from source distribution
    X_b : array-like, shape (B, nt, d)
        Samples from target distribution
    reg : float
        Regularization parameter for entropic regularization
    a : array-like, shape (B, ns), optional
        Source distribution. If None, uniform distribution is used.
    b : array-like, shape (B, nt), optional
        Target distribution. If None, uniform distribution is used.
    metric : str, optional
        'sqeuclidean', 'euclidean', 'minkowski' or 'kl'
    p : float, optional
        p-norm for the Minkowski metric. Default value is 2.
    max_iter : int
        Maximum number of iterations
    tol : float
        Tolerance for convergence
    solver : str
        Solver to use, either 'log_sinkhorn' or 'sinkhorn'. Default is "log_sinkhorn", which is more stable.
    reg_type : str, optional
        Type of regularization :math:`R`, either "KL" or "entropy". Default is "entropy".
    grad : str, optional
        Type of gradient computation, either 'autodiff', 'envelope', 'last_step' or 'detach'.
        Default is 'envelope'. 'autodiff' provides gradients wrt all outputs
        (`plan, value, value_linear`) but with important memory cost. 'envelope' provides
        gradients only for `value`; other outputs are detached. This is useful for memory
        saving when only the value is needed. 'last_step' provides gradients only for the
        last iteration of the Sinkhorn solver, but for both the OT plan and the objective
        values. 'detach' does not compute the gradients for the Sinkhorn solver.

    Returns
    -------
    res : OTResult()
        Result of the optimization problem. The information can be obtained as follows:

        - res.plan : OT plan :math:`\mathbf{T}`
        - res.potentials : OT dual potentials
        - res.value : Optimal value of the optimization problem
        - res.value_linear : Linear OT loss with the optimal OT plan

        See :any:`OTResult` for more information.

    See Also
    --------
    ot.batch.solve_batch : solver for computing the optimal T from an arbitrary cost matrix M.
    """
    M = dist_batch(X_a, X_b, metric=metric, p=p)
    return solve_batch(
        M,
        reg,
        a=a,
        b=b,
        max_iter=max_iter,
        tol=tol,
        solver=solver,
        reg_type=reg_type,
        grad=grad,
    )
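

# A doctest-style sketch solving directly from samples (the private import
# path and the shapes are illustrative assumptions):
# >>> import numpy as np
# >>> from ot.batch._linear import solve_sample_batch
# >>> X = np.random.randn(5, 10, 3)
# >>> Y = np.random.randn(5, 15, 3)
# >>> res = solve_sample_batch(X, Y, reg=0.1, metric="sqeuclidean")
# >>> res.value.shape
# (5,)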