Source code for ot.batch._utils

from ot.backend import get_backend


def entropy_batch(T, nx=None, eps=1e-16):
    r"""Computes the entropy of a batch of transport plans :math:`\mathbf{T}`.

    .. math::
        H(T)_b = - \sum_{i,j} T_{b,i,j} \log(T_{b,i,j})

    Parameters
    ----------
    T : array-like, shape (b, n1, n2)
        Batch of `b` transport plans of size `n1` x `n2`.
    nx : Backend, optional
        Backend to perform computations on. If omitted, the backend defaults
        to that of `T`.
    eps : float, optional
        Small constant to avoid numerical issues with log(0). Default is 1e-16.

    Returns
    -------
    H : array-like, shape (b,)
        Entropy of each transport plan in the batch.
    """
    if nx is None:
        nx = get_backend(T)
    return -nx.sum(T * nx.log(T + eps), axis=(1, 2))
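A minimal usage sketch (illustrative, not part of the module): for a uniform plan of size `n1` x `n2` the entropy equals `log(n1 * n2)`.

import numpy as np

T_unif = np.full((3, 4, 5), 1.0 / 20)   # batch of 3 uniform 4x5 plans
H = entropy_batch(T_unif)               # shape (3,)
assert np.allclose(H, np.log(20))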
def norm_batch(u, p=2, nx=None):
    """Computes the lp norm of each vector in a batch `u` of shape (b, n)."""
    if nx is None:
        nx = get_backend(u)
    return nx.sum(u**p, axis=1) ** (1 / p)


def bmv(A, b, nx):
    """Batched matrix-vector multiplication: A of shape (b, n, m), b of shape (b, m)."""
    return nx.einsum("bij,bj->bi", A, b)


def bop(a, b, nx):
    """Batched outer product of vectors a of shape (b, n) and b of shape (b, m)."""
    return nx.einsum("bi,bj->bij", a, b)
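An illustrative sketch of the batched helpers, assuming a NumPy backend obtained through `get_backend` (the shapes below are arbitrary):

import numpy as np

A = np.random.rand(2, 3, 4)
x = np.random.rand(2, 4)
nx = get_backend(A, x)
# bmv matches an explicit batched matrix-vector product
assert np.allclose(bmv(A, x, nx), np.matmul(A, x[:, :, None])[:, :, 0])
# bop builds one outer product per batch element
assert bop(x, x, nx).shape == (2, 4, 4)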
def bregman_projection_batch(
    K, a=None, b=None, nx=None, max_iter=10000, tol=1e-5, grad="detach"
):
    r"""
    Apply Bregman projection to a batch of affinity matrices :math:`\mathbf{K}`.

    The function solves the following optimization problem:

    .. math::
        \begin{aligned}
        \mathbf{T} = \mathop{\arg \min}_\mathbf{T} \quad & \text{KL}(\mathbf{T} \| \mathbf{K}) \\
        \text{s.t.} \quad & \mathbf{T} \mathbf{1} = \mathbf{a} \\
        & \mathbf{T}^T \mathbf{1} = \mathbf{b} \\
        & \mathbf{T} \geq 0
        \end{aligned}

    This is equivalent to:

    .. math::
        \begin{aligned}
        \mathbf{T} = \mathop{\arg \max}_\mathbf{T} \quad & \langle \mathbf{T}, \log(\mathbf{K}) \rangle_F \\
        \text{s.t.} \quad & \mathbf{T} \mathbf{1} = \mathbf{a} \\
        & \mathbf{T}^T \mathbf{1} = \mathbf{b} \\
        & \mathbf{T} \geq 0
        \end{aligned}

    The optimal solution has the form
    :math:`\mathbf{T} = \text{diag}(\mathbf{f}) \mathbf{K} \text{diag}(\mathbf{g})`,
    where the scaling vectors :math:`\mathbf{f}` and :math:`\mathbf{g}` are found
    iteratively using:

    .. math::
        \mathbf{f}^{(k+1)} = \frac{\mathbf{a}}{\mathbf{K} \mathbf{g}^{(k)}}

        \mathbf{g}^{(k+1)} = \frac{\mathbf{b}}{\mathbf{K}^T \mathbf{f}^{(k+1)}}

    Parameters
    ----------
    K : array-like, shape (B, n, m)
        Affinity matrix for each problem in the batch.
    a : array-like, shape (B, n), optional
        Source distribution for each problem. If None, uniform distribution is used.
    b : array-like, shape (B, m), optional
        Target distribution for each problem. If None, uniform distribution is used.
    nx : backend object, optional
        Numerical backend to use for computations. If None, the backend is
        inferred from the input arrays.
    max_iter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Tolerance for convergence. The solver stops when the deviation of the
        row marginals of the current plan from `a` is below this value.
    grad : str, optional
        Gradient computation mode: 'detach', 'autodiff', or 'last_step'.

    Returns
    -------
    dict
        Dictionary containing:

        - 'T' : array-like, shape (B, n, m)
            Optimal transport plan for each problem.
        - 'potentials' : tuple of array-like, shapes ((B, n), (B, m))
            Scaling factors (f, g).
        - 'n_iters' : int
            Number of iterations performed.

    Examples
    --------
    >>> import numpy as np
    >>> from ot.batch import bregman_projection_batch
    >>> # Create batch of (positive) affinity matrices
    >>> K = np.exp(-np.random.rand(5, 10, 15))  # 5 problems, 10x15 affinity matrices
    >>> result = bregman_projection_batch(K)
    >>> T = result['T']  # Shape (5, 10, 15)

    See Also
    --------
    ot.batch.bregman_log_projection_batch : Bregman projection in the log-domain.
    """
    if nx is None:
        nx = get_backend(a, b, K)

    B, n, m = K.shape

    if a is None:
        a = nx.ones((B, n)) / n
    if b is None:
        b = nx.ones((B, m)) / m

    if grad == "detach":
        K = nx.detach(K)
    elif grad == "last_step":
        # Keep a differentiable copy of K and iterate on a detached one
        K_, K = K.clone(), nx.detach(K)

    f = nx.ones((B, n))  # a / nx.sum(K, axis=2)
    g = nx.ones((B, m))  # b / nx.sum(K, axis=1)

    for n_iters in range(max_iter):
        # Sinkhorn scaling updates
        f = a / nx.sum(K * g[:, None, :], axis=2)
        g = b / nx.sum(K * f[:, :, None], axis=1)

        # Check convergence once every 10 iterations
        if n_iters % 10 == 0:
            T = K * f[:, :, None] * g[:, None, :]
            marginal = nx.sum(T, axis=2)
            err = nx.max(norm_batch(marginal - a))
            if err < tol:
                break

    if grad == "last_step":
        # Redo the last update on the differentiable copy of K
        summand_f = nx.sum(K_ * g[:, None, :], axis=2)
        f = a / summand_f
        summand_g = nx.sum(K_ * f[:, :, None], axis=1)
        g = b / summand_g

    T = K * f[:, :, None] * g[:, None, :]
    potentials = (f, g)
    out = {"T": T, "n_iters": n_iters, "potentials": potentials}
    return out
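As a quick sanity check (an illustrative sketch, not part of the module; the sizes and tolerances below are arbitrary), the returned plans should reproduce the prescribed marginals up to the solver tolerance:

import numpy as np

rng = np.random.RandomState(0)
K = np.exp(-rng.rand(4, 6, 7))        # positive affinity matrices
a = np.full((4, 6), 1.0 / 6)
b = np.full((4, 7), 1.0 / 7)
res = bregman_projection_batch(K, a, b, tol=1e-8)
T = res["T"]
assert np.allclose(T.sum(axis=2), a, atol=1e-6)   # row marginals
assert np.allclose(T.sum(axis=1), b, atol=1e-6)   # column marginals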
def bregman_log_projection_batch(
    K, a=None, b=None, nx=None, max_iter=10000, tol=1e-5, grad="detach"
):
    r"""
    Apply Bregman projection to a batch of affinity matrices :math:`\exp(\mathbf{K})`.

    The function solves the following optimization problem:

    .. math::
        \begin{aligned}
        \mathbf{T} = \mathop{\arg \min}_\mathbf{T} \quad & \text{KL}(\mathbf{T} \| \exp(\mathbf{K})) \\
        \text{s.t.} \quad & \mathbf{T} \mathbf{1} = \mathbf{a} \\
        & \mathbf{T}^T \mathbf{1} = \mathbf{b} \\
        & \mathbf{T} \geq 0
        \end{aligned}

    This is equivalent to:

    .. math::
        \begin{aligned}
        \mathbf{T} = \mathop{\arg \max}_\mathbf{T} \quad & \langle \mathbf{T}, \mathbf{K} \rangle_F \\
        \text{s.t.} \quad & \mathbf{T} \mathbf{1} = \mathbf{a} \\
        & \mathbf{T}^T \mathbf{1} = \mathbf{b} \\
        & \mathbf{T} \geq 0
        \end{aligned}

    The optimal solution has the form
    :math:`\mathbf{T} = \text{diag}(\exp(\mathbf{u})) \exp(\mathbf{K}) \text{diag}(\exp(\mathbf{v}))`,
    where the dual variables :math:`\mathbf{u}` and :math:`\mathbf{v}` are found
    iteratively using:

    .. math::
        \mathbf{u}^{(k+1)} = \log(\mathbf{a}) - \text{LSE}(\mathbf{K} + \mathbf{v}^{(k)})

        \mathbf{v}^{(k+1)} = \log(\mathbf{b}) - \text{LSE}(\mathbf{K}^T + \mathbf{u}^{(k+1)})

    where LSE denotes the log-sum-exp operation. The iterations are performed
    using the log-sum-exp trick to avoid numerical issues.

    Parameters
    ----------
    K : array-like, shape (B, n, m)
        Log-affinity matrix for each problem in the batch.
    a : array-like, shape (B, n), optional
        Source distribution for each problem. If None, uniform distribution is used.
    b : array-like, shape (B, m), optional
        Target distribution for each problem. If None, uniform distribution is used.
    nx : backend object, optional
        Numerical backend to use for computations. If None, the backend is
        inferred from the input arrays.
    max_iter : int, optional
        Maximum number of iterations.
    tol : float, optional
        Tolerance for convergence. The solver stops when the deviation of the
        row marginals of the current plan from `a` is below this value.
    grad : str, optional
        Gradient computation mode: 'detach', 'autodiff', or 'last_step'.

    Returns
    -------
    dict
        Dictionary containing:

        - 'T' : array-like, shape (B, n, m)
            Optimal transport plan for each problem.
        - 'log_T' : array-like, shape (B, n, m)
            Logarithm of the optimal transport plan for each problem.
        - 'potentials' : tuple of array-like, shapes ((B, n), (B, m))
            Log-scaling factors (u, v).
        - 'n_iters' : int
            Number of iterations performed.

    Examples
    --------
    >>> import numpy as np
    >>> from ot.batch import bregman_log_projection_batch
    >>> # Create batch of log-affinity matrices
    >>> K = np.random.randn(5, 10, 15)  # 5 problems, 10x15 log-affinity matrices
    >>> result = bregman_log_projection_batch(K)
    >>> T = result['T']  # Shape (5, 10, 15)

    See Also
    --------
    ot.batch.bregman_projection_batch : standard Bregman projection.
    """
    if nx is None:
        nx = get_backend(a, b, K)

    B, n, m = K.shape

    if a is None:
        a = nx.ones((B, n)) / n
    if b is None:
        b = nx.ones((B, m)) / m

    u = nx.zeros((B, n), type_as=K)  # u = nx.log(a) - nx.logsumexp(K, axis=2).squeeze()
    v = nx.zeros((B, m), type_as=K)  # v = nx.log(b) - nx.logsumexp(K, axis=1).squeeze()
    loga = nx.log(a)
    logb = nx.log(b)

    if grad == "detach":
        K = nx.detach(K)
    elif grad == "last_step":
        # Keep a differentiable copy of K and iterate on a detached one
        K_, K = K.clone(), nx.detach(K)

    for n_iters in range(max_iter):
        # Log-domain Sinkhorn updates
        u = loga - nx.logsumexp(K + v[:, None, :], axis=2)
        v = logb - nx.logsumexp(K + u[:, :, None], axis=1)

        # Check convergence once every 10 iterations
        if n_iters % 10 == 0:
            T = nx.exp(K + u[:, :, None] + v[:, None, :])
            marginal = nx.sum(T, axis=2)
            err = nx.max(norm_batch(marginal - a))
            if err < tol:
                break

    if grad == "last_step":
        # Redo the last update on the differentiable copy of K
        summand_u = K_ + v[:, None, :]
        u = loga - nx.logsumexp(summand_u, axis=2)
        summand_v = K_ + u[:, :, None]
        v = logb - nx.logsumexp(summand_v, axis=1)

    log_T = K + u[:, :, None] + v[:, None, :]
    T = nx.exp(log_T)
    log_potentials = (u, v)
    return {"T": T, "log_T": log_T, "n_iters": n_iters, "potentials": log_potentials}